about summary refs log tree commit diff
path: root/tvix/store/src/blobservice/sled.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/store/src/blobservice/sled.rs')
-rw-r--r--tvix/store/src/blobservice/sled.rs103
1 files changed, 64 insertions, 39 deletions
diff --git a/tvix/store/src/blobservice/sled.rs b/tvix/store/src/blobservice/sled.rs
index 1ae34ee5fb9c..c229f799de98 100644
--- a/tvix/store/src/blobservice/sled.rs
+++ b/tvix/store/src/blobservice/sled.rs
@@ -1,13 +1,13 @@
-use std::path::PathBuf;
+use std::{
+    io::{self, Cursor},
+    path::PathBuf,
+};
 
+use super::{BlobService, BlobWriter};
+use crate::Error;
 use data_encoding::BASE64;
-use prost::Message;
 use tracing::instrument;
 
-use crate::{proto, Error};
-
-use super::BlobService;
-
 #[derive(Clone)]
 pub struct SledBlobService {
     db: sled::Db,
@@ -30,44 +30,69 @@ impl SledBlobService {
 }
 
 impl BlobService for SledBlobService {
-    #[instrument(name = "SledBlobService::stat", skip(self, req), fields(blob.digest=BASE64.encode(&req.digest)))]
-    fn stat(&self, req: &proto::StatBlobRequest) -> Result<Option<proto::BlobMeta>, Error> {
-        if req.include_bao {
-            todo!("not implemented yet")
+    type BlobReader = Cursor<Vec<u8>>;
+    type BlobWriter = SledBlobWriter;
+
+    #[instrument(name = "SledBlobService::has", skip(self), fields(blob.digest=BASE64.encode(digest)))]
+    fn has(&self, digest: &[u8; 32]) -> Result<bool, Error> {
+        match self.db.contains_key(digest) {
+            Ok(has) => Ok(has),
+            Err(e) => Err(Error::StorageError(e.to_string())),
         }
+    }
 
-        // if include_chunks is also false, the user only wants to know if the
-        // blob is present at all.
-        if !req.include_chunks {
-            match self.db.contains_key(&req.digest) {
-                Ok(false) => Ok(None),
-                Ok(true) => Ok(Some(proto::BlobMeta::default())),
-                Err(e) => Err(Error::StorageError(e.to_string())),
-            }
-        } else {
-            match self.db.get(&req.digest) {
-                Ok(None) => Ok(None),
-                Ok(Some(data)) => match proto::BlobMeta::decode(&*data) {
-                    Ok(blob_meta) => Ok(Some(blob_meta)),
-                    Err(e) => Err(Error::StorageError(format!(
-                        "unable to parse blobmeta message for blob {}: {}",
-                        BASE64.encode(&req.digest),
-                        e
-                    ))),
-                },
-                Err(e) => Err(Error::StorageError(e.to_string())),
-            }
+    #[instrument(name = "SledBlobService::open_read", skip(self), fields(blob.digest=BASE64.encode(digest)))]
+    fn open_read(&self, digest: &[u8; 32]) -> Result<Option<Self::BlobReader>, Error> {
+        match self.db.get(digest) {
+            Ok(None) => Ok(None),
+            Ok(Some(data)) => Ok(Some(Cursor::new(data[..].to_vec()))),
+            Err(e) => Err(Error::StorageError(e.to_string())),
         }
     }
 
-    #[instrument(name = "SledBlobService::put", skip(self, blob_meta, blob_digest), fields(blob.digest = BASE64.encode(blob_digest)))]
-    fn put(&self, blob_digest: &[u8], blob_meta: proto::BlobMeta) -> Result<(), Error> {
-        let result = self.db.insert(blob_digest, blob_meta.encode_to_vec());
-        if let Err(e) = result {
-            return Err(Error::StorageError(e.to_string()));
+    #[instrument(name = "SledBlobService::open_write", skip(self))]
+    fn open_write(&self) -> Result<Self::BlobWriter, Error> {
+        Ok(SledBlobWriter::new(self.db.clone()))
+    }
+}
+
+pub struct SledBlobWriter {
+    db: sled::Db,
+    buf: Vec<u8>,
+    hasher: blake3::Hasher,
+}
+
+impl SledBlobWriter {
+    pub fn new(db: sled::Db) -> Self {
+        Self {
+            buf: Vec::default(),
+            db,
+            hasher: blake3::Hasher::new(),
         }
-        Ok(())
-        // TODO: make sure all callers make sure the chunks exist.
-        // TODO: where should we calculate the bao?
+    }
+}
+
+impl io::Write for SledBlobWriter {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        let bytes_written = self.buf.write(buf)?;
+        self.hasher.write(&buf[..bytes_written])
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        self.buf.flush()
+    }
+}
+
+impl BlobWriter for SledBlobWriter {
+    fn close(self) -> Result<[u8; 32], Error> {
+        let digest = self.hasher.finalize();
+        self.db.insert(digest.as_bytes(), self.buf).map_err(|e| {
+            Error::StorageError(format!("unable to insert blob: {}", e.to_string()))
+        })?;
+
+        Ok(digest
+            .to_owned()
+            .try_into()
+            .map_err(|_| Error::StorageError("invalid digest length in response".to_string()))?)
     }
 }