diff options
Diffstat (limited to 'tvix/store/src/blobservice/sled.rs')
-rw-r--r-- | tvix/store/src/blobservice/sled.rs | 103 |
1 files changed, 64 insertions, 39 deletions
diff --git a/tvix/store/src/blobservice/sled.rs b/tvix/store/src/blobservice/sled.rs index 1ae34ee5fb9c..c229f799de98 100644 --- a/tvix/store/src/blobservice/sled.rs +++ b/tvix/store/src/blobservice/sled.rs @@ -1,13 +1,13 @@ -use std::path::PathBuf; +use std::{ + io::{self, Cursor}, + path::PathBuf, +}; +use super::{BlobService, BlobWriter}; +use crate::Error; use data_encoding::BASE64; -use prost::Message; use tracing::instrument; -use crate::{proto, Error}; - -use super::BlobService; - #[derive(Clone)] pub struct SledBlobService { db: sled::Db, @@ -30,44 +30,69 @@ impl SledBlobService { } impl BlobService for SledBlobService { - #[instrument(name = "SledBlobService::stat", skip(self, req), fields(blob.digest=BASE64.encode(&req.digest)))] - fn stat(&self, req: &proto::StatBlobRequest) -> Result<Option<proto::BlobMeta>, Error> { - if req.include_bao { - todo!("not implemented yet") + type BlobReader = Cursor<Vec<u8>>; + type BlobWriter = SledBlobWriter; + + #[instrument(name = "SledBlobService::has", skip(self), fields(blob.digest=BASE64.encode(digest)))] + fn has(&self, digest: &[u8; 32]) -> Result<bool, Error> { + match self.db.contains_key(digest) { + Ok(has) => Ok(has), + Err(e) => Err(Error::StorageError(e.to_string())), } + } - // if include_chunks is also false, the user only wants to know if the - // blob is present at all. - if !req.include_chunks { - match self.db.contains_key(&req.digest) { - Ok(false) => Ok(None), - Ok(true) => Ok(Some(proto::BlobMeta::default())), - Err(e) => Err(Error::StorageError(e.to_string())), - } - } else { - match self.db.get(&req.digest) { - Ok(None) => Ok(None), - Ok(Some(data)) => match proto::BlobMeta::decode(&*data) { - Ok(blob_meta) => Ok(Some(blob_meta)), - Err(e) => Err(Error::StorageError(format!( - "unable to parse blobmeta message for blob {}: {}", - BASE64.encode(&req.digest), - e - ))), - }, - Err(e) => Err(Error::StorageError(e.to_string())), - } + #[instrument(name = "SledBlobService::open_read", skip(self), fields(blob.digest=BASE64.encode(digest)))] + fn open_read(&self, digest: &[u8; 32]) -> Result<Option<Self::BlobReader>, Error> { + match self.db.get(digest) { + Ok(None) => Ok(None), + Ok(Some(data)) => Ok(Some(Cursor::new(data[..].to_vec()))), + Err(e) => Err(Error::StorageError(e.to_string())), } } - #[instrument(name = "SledBlobService::put", skip(self, blob_meta, blob_digest), fields(blob.digest = BASE64.encode(blob_digest)))] - fn put(&self, blob_digest: &[u8], blob_meta: proto::BlobMeta) -> Result<(), Error> { - let result = self.db.insert(blob_digest, blob_meta.encode_to_vec()); - if let Err(e) = result { - return Err(Error::StorageError(e.to_string())); + #[instrument(name = "SledBlobService::open_write", skip(self))] + fn open_write(&self) -> Result<Self::BlobWriter, Error> { + Ok(SledBlobWriter::new(self.db.clone())) + } +} + +pub struct SledBlobWriter { + db: sled::Db, + buf: Vec<u8>, + hasher: blake3::Hasher, +} + +impl SledBlobWriter { + pub fn new(db: sled::Db) -> Self { + Self { + buf: Vec::default(), + db, + hasher: blake3::Hasher::new(), } - Ok(()) - // TODO: make sure all callers make sure the chunks exist. - // TODO: where should we calculate the bao? + } +} + +impl io::Write for SledBlobWriter { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let bytes_written = self.buf.write(buf)?; + self.hasher.write(&buf[..bytes_written]) + } + + fn flush(&mut self) -> io::Result<()> { + self.buf.flush() + } +} + +impl BlobWriter for SledBlobWriter { + fn close(self) -> Result<[u8; 32], Error> { + let digest = self.hasher.finalize(); + self.db.insert(digest.as_bytes(), self.buf).map_err(|e| { + Error::StorageError(format!("unable to insert blob: {}", e.to_string())) + })?; + + Ok(digest + .to_owned() + .try_into() + .map_err(|_| Error::StorageError("invalid digest length in response".to_string()))?) } } |