// about summary refs log tree commit diff
// path: root/tvix/castore/src/import/blobs.rs
use std::{
    io::{Cursor, Write},
    sync::Arc,
};

use tokio::{
    io::AsyncRead,
    sync::Semaphore,
    task::{JoinError, JoinSet},
};
use tokio_util::io::InspectReader;

use crate::{blobservice::BlobService, B3Digest, Path, PathBuf};

/// Size limit, in bytes, below which a blob is buffered in memory and handed to a
/// background task for upload to the [BlobService]. Blobs of this size or larger are
/// uploaded directly instead.
///
/// The type is deliberately u32: the buffer budget is tracked with a weighted
/// semaphore, and [Semaphore::acquire_many_owned] takes the weight as a u32. Keeping
/// the threshold a u32 guarantees that the size of every buffered blob fits into the
/// permit count without overflowing.
const CONCURRENT_BLOB_UPLOAD_THRESHOLD: u32 = 1 << 20; // 1 MiB

/// Upper bound, in bytes, on the blob data held in memory at any one time while waiting
/// for background uploads to finish. Used as the permit count of the upload semaphore.
const MAX_BUFFER_SIZE: usize = 128 << 20; // 128 MiB

/// Errors that can occur while uploading blobs, either directly or through the
/// background upload tasks.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Copying the contents of the blob at the given path failed.
    ///
    /// This is produced whenever the copy out of the provided reader fails, both when
    /// buffering a small blob in memory and when streaming into the blob writer. In the
    /// streaming case the underlying I/O error may also originate from the write side.
    #[error("unable to read blob contents for {0}: {1}")]
    BlobRead(PathBuf, std::io::Error),

    // FUTUREWORK: proper error for blob finalize
    /// Closing the blob writer for the blob at the given path failed.
    #[error("unable to finalize blob {0}: {1}")]
    BlobFinalize(PathBuf, std::io::Error),

    /// The number of bytes read for `path` (`got`) differs from the size the caller
    /// announced (`wanted`).
    #[error("unexpected size for {path} wanted: {wanted} got: {got}")]
    UnexpectedSize {
        path: PathBuf,
        wanted: u64,
        got: u64,
    },

    /// Joining a background upload task failed; converted automatically from
    /// tokio's [JoinError].
    #[error("blob upload join error: {0}")]
    JoinError(#[from] JoinError),
}

/// The concurrent blob uploader provides a mechanism for concurrently uploading small blobs.
/// This is useful when ingesting from sources like tarballs and archives, where each blob
/// entry must be read sequentially. Ingesting many small blobs sequentially becomes slow due
/// to the round trip time with the blob service. The concurrent blob uploader will buffer
/// small blobs in memory and upload them to the blob service in the background.
///
/// Once all blobs have been uploaded, make sure to call [ConcurrentBlobUploader::join] to wait
/// for all background jobs to complete and check for any errors.
pub struct ConcurrentBlobUploader<BS> {
    /// The blob service all blobs are written to.
    blob_service: BS,
    /// Background upload tasks spawned for small blobs; drained by `join`.
    upload_tasks: JoinSet<Result<(), Error>>,
    /// Weighted semaphore limiting how many bytes may be buffered in memory for
    /// background uploads. Each buffered blob holds as many permits as its size in bytes.
    upload_semaphore: Arc<Semaphore>,
}

impl<BS> ConcurrentBlobUploader<BS>
where
    BS: BlobService + Clone + 'static,
{
    /// Creates a new concurrent blob uploader which uploads blobs to the provided
    /// blob service.
    ///
    /// The uploader starts without any background tasks and with a buffer budget of
    /// `MAX_BUFFER_SIZE` bytes.
    pub fn new(blob_service: BS) -> Self {
        Self {
            blob_service,
            upload_tasks: JoinSet::new(),
            upload_semaphore: Arc::new(Semaphore::new(MAX_BUFFER_SIZE)),
        }
    }

    /// Uploads a blob to the blob service and returns its digest.
    ///
    /// If `expected_size` is below `CONCURRENT_BLOB_UPLOAD_THRESHOLD`, the blob is read
    /// into an in-memory buffer, hashed while reading, and uploaded by a background
    /// task; the digest is returned as soon as the reader has been consumed. Larger
    /// blobs are streamed to the blob service directly and this function only returns
    /// once that upload has finished.
    ///
    /// This will read the entirety of the provided reader unless an error occurs, even
    /// if blobs are uploaded in the background.
    ///
    /// `path` is only used to label errors.
    ///
    /// # Errors
    ///
    /// - [Error::BlobRead] if copying from `r` fails.
    /// - [Error::UnexpectedSize] if the number of bytes read differs from
    ///   `expected_size`.
    /// - [Error::BlobFinalize] if closing the blob writer fails (direct uploads only).
    ///
    /// Failures of background uploads are not reported here; they surface from
    /// [ConcurrentBlobUploader::join].
    ///
    /// # Panics
    ///
    /// Panics if the upload semaphore has been closed or if writing to the hasher
    /// fails; neither is expected to happen. A background task additionally asserts
    /// that the blob service computed the same digest as the local hasher.
    pub async fn upload<R>(
        &mut self,
        path: &Path,
        expected_size: u64,
        mut r: R,
    ) -> Result<B3Digest, Error>
    where
        R: AsyncRead + Unpin,
    {
        // Blobs at or above the threshold are not buffered: stream them to the blob
        // service directly and wait for the upload to finish.
        if expected_size >= u64::from(CONCURRENT_BLOB_UPLOAD_THRESHOLD) {
            return upload_blob(&self.blob_service, path, expected_size, r).await;
        }

        // Reserve the buffer budget before allocating, so the allocation below is
        // already accounted for by the semaphore.
        let permit = self
            .upload_semaphore
            .clone()
            // This cast cannot truncate: expected_size is below
            // CONCURRENT_BLOB_UPLOAD_THRESHOLD, which is a u32.
            .acquire_many_owned(expected_size as u32)
            .await
            // Acquiring only fails on a closed semaphore, and this type never closes it.
            .expect("Tvix bug: upload semaphore closed");

        // expected_size is below the (u32) threshold here, so it fits into a usize on
        // the platforms this can reasonably run on.
        let mut buffer = Vec::with_capacity(expected_size as usize);
        let mut hasher = blake3::Hasher::new();
        // Hash the bytes on the fly while they are copied into the buffer.
        let mut reader = InspectReader::new(&mut r, |bytes| {
            hasher
                .write_all(bytes)
                .expect("Tvix bug: writing to a blake3 hasher failed");
        });

        let size = tokio::io::copy(&mut reader, &mut buffer)
            .await
            .map_err(|e| Error::BlobRead(path.into(), e))?;
        let digest: B3Digest = hasher.finalize().as_bytes().into();

        if size != expected_size {
            return Err(Error::UnexpectedSize {
                path: path.into(),
                wanted: expected_size,
                got: size,
            });
        }

        self.upload_tasks.spawn({
            let blob_service = self.blob_service.clone();
            let expected_digest = digest.clone();
            let path = path.to_owned();
            let buffered = Cursor::new(buffer);
            async move {
                let digest = upload_blob(&blob_service, &path, expected_size, buffered).await?;

                assert_eq!(digest, expected_digest, "Tvix bug: blob digest mismatch");

                // Make sure we hold the permit until we finish writing the blob
                // to the [BlobService].
                drop(permit);
                Ok(())
            }
        });

        Ok(digest)
    }

    /// Waits for all background upload jobs to complete, returning any upload errors.
    ///
    /// Returns the first failure encountered, which is either the [Error] produced by
    /// an upload task or an [Error::JoinError] if a task could not be joined. On
    /// failure the uploader is dropped with the remaining tasks still in its task set.
    pub async fn join(mut self) -> Result<(), Error> {
        while let Some(result) = self.upload_tasks.join_next().await {
            // Outer `?`: the task could not be joined; inner `?`: the upload itself failed.
            result??;
        }
        Ok(())
    }
}

async fn upload_blob<BS, R>(
    blob_service: &BS,
    path: &Path,
    expected_size: u64,
    mut r: R,
) -> Result<B3Digest, Error>
where
    BS: BlobService,
    R: AsyncRead + Unpin,
{
    let mut writer = blob_service.open_write().await;

    let size = tokio::io::copy(&mut r, &mut writer)
        .await
        .map_err(|e| Error::BlobRead(path.into(), e))?;

    let digest = writer
        .close()
        .await
        .map_err(|e| Error::BlobFinalize(path.into(), e))?;

    if size != expected_size {
        return Err(Error::UnexpectedSize {
            path: path.into(),
            wanted: expected_size,
            got: size,
        });
    }

    Ok(digest)
}