use std::{
    io::{Cursor, Write},
    sync::Arc,
};
use tokio::{
    io::{AsyncRead, AsyncReadExt},
    sync::Semaphore,
    task::{JoinError, JoinSet},
};
use tokio_util::io::InspectReader;

use crate::{blobservice::BlobService, B3Digest, Path, PathBuf};
/// Size limit, in bytes, below which a blob is buffered in memory and handed to the
/// [BlobService] by a background task rather than being streamed to it directly.
///
/// Declared as `u32` on purpose: every buffered blob takes a weighted permit equal to its
/// length via [Semaphore::acquire_many_owned], which only accepts a `u32`. Keeping the
/// threshold in `u32` guarantees any blob below it can be converted without overflow.
const CONCURRENT_BLOB_UPLOAD_THRESHOLD: u32 = 1 << 20; // 1 MiB
/// Upper bound, in bytes, on the memory that may be held in buffers for blobs that are
/// waiting to be uploaded in the background.
const MAX_BUFFER_SIZE: usize = 128 << 20; // 128 MiB
/// Errors produced while uploading blobs through [ConcurrentBlobUploader].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Copying the blob contents failed.
    ///
    /// This also covers errors raised while writing into the blob writer, because
    /// `tokio::io::copy` reports read-side and write-side failures through the same error.
    #[error("unable to read blob contents for {0}: {1}")]
    BlobRead(PathBuf, std::io::Error),

    // FUTUREWORK: proper error for blob finalize
    /// Closing the blob writer obtained from the blob service failed.
    #[error("unable to finalize blob {0}: {1}")]
    BlobFinalize(PathBuf, std::io::Error),

    /// The number of bytes read for a blob differed from the size the caller announced.
    #[error("unexpected size for {path} wanted: {wanted} got: {got}")]
    UnexpectedSize {
        /// Path of the blob being uploaded.
        path: PathBuf,
        /// Size announced by the caller.
        wanted: u64,
        /// Number of bytes actually read.
        got: u64,
    },

    /// A background upload task did not run to completion (e.g. it panicked or was
    /// cancelled).
    #[error("blob upload join error: {0}")]
    JoinError(#[from] JoinError),
}
/// The concurrent blob uploader provides a mechanism for concurrently uploading small blobs.
/// This is useful when ingesting from sources like tarballs and archives, in which each blob
/// entry must be read sequentially. Ingesting many small blobs sequentially becomes slow due to
/// the round-trip time to the blob service. The concurrent blob uploader buffers small
/// blobs in memory and uploads them to the blob service in the background.
///
/// Once all blobs have been uploaded, make sure to call [ConcurrentBlobUploader::join] to wait
/// for all background jobs to complete and check for any errors.
pub struct ConcurrentBlobUploader<BS> {
    /// Blob service every blob is written to (cloned into each background task).
    blob_service: BS,
    /// In-flight background uploads of buffered small blobs.
    upload_tasks: JoinSet<Result<(), Error>>,
    /// Limits the bytes buffered for background uploads; one permit corresponds to one byte.
    upload_semaphore: Arc<Semaphore>,
}
impl<BS> ConcurrentBlobUploader<BS>
where
    BS: BlobService + Clone + 'static,
{
    /// Creates a new concurrent blob uploader which uploads blobs to the provided
    /// blob service.
    ///
    /// At most [MAX_BUFFER_SIZE] bytes will be held in memory for pending background uploads.
    pub fn new(blob_service: BS) -> Self {
        Self {
            blob_service,
            upload_tasks: JoinSet::new(),
            // One permit per byte of buffered blob data.
            upload_semaphore: Arc::new(Semaphore::new(MAX_BUFFER_SIZE)),
        }
    }

    /// Uploads a blob to the blob service and returns its digest.
    ///
    /// If `expected_size` is below [CONCURRENT_BLOB_UPLOAD_THRESHOLD], the blob is read into
    /// an in-memory buffer, its digest is computed locally, and the actual upload runs in the
    /// background. Errors from such background uploads are only reported by
    /// [ConcurrentBlobUploader::join]. Larger blobs are streamed to the blob service directly.
    ///
    /// This will read the entirety of the provided reader unless an error occurs, even if blobs
    /// are uploaded in the background.
    ///
    /// # Errors
    ///
    /// - [Error::BlobRead] if copying the blob contents fails.
    /// - [Error::UnexpectedSize] if the reader yields a number of bytes other than
    ///   `expected_size`.
    /// - [Error::BlobFinalize] if closing the blob writer fails (direct uploads only).
    pub async fn upload<R>(
        &mut self,
        path: &Path,
        expected_size: u64,
        mut r: R,
    ) -> Result<B3Digest, Error>
    where
        R: AsyncRead + Unpin,
    {
        if expected_size >= CONCURRENT_BLOB_UPLOAD_THRESHOLD as u64 {
            // Too large to buffer: stream it straight to the blob service.
            return upload_blob(&self.blob_service, path, expected_size, r).await;
        }

        // Reserve the memory before allocating/buffering, so pending background uploads
        // never hold more than MAX_BUFFER_SIZE bytes in total.
        let permit = self
            .upload_semaphore
            .clone()
            // This cast cannot truncate: the check above ensures expected_size is below
            // CONCURRENT_BLOB_UPLOAD_THRESHOLD, which is a u32.
            .acquire_many_owned(expected_size as u32)
            .await
            .expect("upload semaphore is never closed");

        let mut buffer = Vec::with_capacity(expected_size as usize);
        let mut hasher = blake3::Hasher::new();
        let mut reader = InspectReader::new(&mut r, |bytes| {
            hasher
                .write_all(bytes)
                .expect("writing to a blake3::Hasher is infallible");
        });

        // Buffer at most `expected_size` bytes: a reader yielding more data than announced
        // must not grow the buffer beyond the memory covered by `permit`.
        let buffered = tokio::io::copy(&mut (&mut reader).take(expected_size), &mut buffer)
            .await
            .map_err(|e| Error::BlobRead(path.into(), e))?;
        // Still consume (and count) any excess bytes, so the whole reader is read and the
        // size check below reports the real length.
        let excess = tokio::io::copy(&mut reader, &mut tokio::io::sink())
            .await
            .map_err(|e| Error::BlobRead(path.into(), e))?;
        let size = buffered + excess;
        let digest: B3Digest = hasher.finalize().as_bytes().into();

        if size != expected_size {
            return Err(Error::UnexpectedSize {
                path: path.into(),
                wanted: expected_size,
                got: size,
            });
        }

        let blob_service = self.blob_service.clone();
        let expected_digest = digest.clone();
        let task_path = path.to_owned();
        self.upload_tasks.spawn(async move {
            let uploaded_digest =
                upload_blob(&blob_service, &task_path, expected_size, Cursor::new(buffer)).await?;
            assert_eq!(uploaded_digest, expected_digest, "Tvix bug: blob digest mismatch");
            // Make sure we hold the permit until we finish writing the blob
            // to the [BlobService].
            drop(permit);
            Ok(())
        });

        Ok(digest)
    }

    /// Waits for all background upload jobs to complete, returning any upload errors.
    ///
    /// Returns the first error observed (in task completion order). Any tasks still running
    /// at that point are aborted, because dropping a [JoinSet] aborts the tasks it owns.
    pub async fn join(mut self) -> Result<(), Error> {
        while let Some(result) = self.upload_tasks.join_next().await {
            // Outer `?`: the task panicked or was cancelled ([Error::JoinError]).
            // Inner `?`: the upload itself failed.
            result??;
        }
        Ok(())
    }
}
async fn upload_blob<BS, R>(
blob_service: &BS,
path: &Path,
expected_size: u64,
mut r: R,
) -> Result<B3Digest, Error>
where
BS: BlobService,
R: AsyncRead + Unpin,
{
let mut writer = blob_service.open_write().await;
let size = tokio::io::copy(&mut r, &mut writer)
.await
.map_err(|e| Error::BlobRead(path.into(), e))?;
let digest = writer
.close()
.await
.map_err(|e| Error::BlobFinalize(path.into(), e))?;
if size != expected_size {
return Err(Error::UnexpectedSize {
path: path.into(),
wanted: expected_size,
got: size,
});
}
Ok(digest)
}