about summary refs log tree commit diff
path: root/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/proto/grpc_directoryservice_wrapper.rs')
-rw-r--r--tvix/castore/src/proto/grpc_directoryservice_wrapper.rs107
1 files changed, 53 insertions, 54 deletions
diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
index 7d741a3f07bb..62fdb34a25a0 100644
--- a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
+++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
@@ -1,12 +1,11 @@
-use crate::directoryservice::ClosureValidator;
-use crate::proto;
-use crate::{directoryservice::DirectoryService, B3Digest};
-use futures::StreamExt;
+use crate::directoryservice::{DirectoryGraph, DirectoryService, LeavesToRootValidator};
+use crate::{proto, B3Digest, DirectoryError};
+use futures::stream::BoxStream;
+use futures::TryStreamExt;
 use std::ops::Deref;
-use tokio::sync::mpsc::channel;
-use tokio_stream::wrappers::ReceiverStream;
+use tokio_stream::once;
 use tonic::{async_trait, Request, Response, Status, Streaming};
-use tracing::{debug, instrument, warn};
+use tracing::{instrument, warn};
 
 pub struct GRPCDirectoryServiceWrapper<T> {
     directory_service: T,
@@ -23,63 +22,55 @@ impl<T> proto::directory_service_server::DirectoryService for GRPCDirectoryServi
 where
     T: Deref<Target = dyn DirectoryService> + Send + Sync + 'static,
 {
-    type GetStream = ReceiverStream<tonic::Result<proto::Directory, Status>>;
+    type GetStream = BoxStream<'static, tonic::Result<proto::Directory, Status>>;
 
     #[instrument(skip_all)]
-    async fn get(
-        &self,
+    async fn get<'a>(
+        &'a self,
         request: Request<proto::GetDirectoryRequest>,
     ) -> Result<Response<Self::GetStream>, Status> {
-        let (tx, rx) = channel(5);
-
         let req_inner = request.into_inner();
 
-        // look at the digest in the request and put it in the top of the queue.
-        match &req_inner.by_what {
-            None => return Err(Status::invalid_argument("by_what needs to be specified")),
-            Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => {
+        let by_what = &req_inner
+            .by_what
+            .ok_or_else(|| Status::invalid_argument("invalid by_what"))?;
+
+        match by_what {
+            proto::get_directory_request::ByWhat::Digest(ref digest) => {
                 let digest: B3Digest = digest
                     .clone()
                     .try_into()
                     .map_err(|_e| Status::invalid_argument("invalid digest length"))?;
 
-                if !req_inner.recursive {
-                    let e: Result<proto::Directory, Status> = match self
-                        .directory_service
-                        .get(&digest)
-                        .await
-                    {
-                        Ok(Some(directory)) => Ok(directory),
-                        Ok(None) => {
-                            Err(Status::not_found(format!("directory {} not found", digest)))
-                        }
-                        Err(e) => {
-                            warn!(err = %e, directory.digest=%digest, "failed to get directory");
-                            Err(e.into())
-                        }
-                    };
-
-                    if tx.send(e).await.is_err() {
-                        debug!("receiver dropped");
+                Ok(tonic::Response::new({
+                    if !req_inner.recursive {
+                        let directory = self
+                            .directory_service
+                            .get(&digest)
+                            .await
+                            .map_err(|e| {
+                                warn!(err = %e, directory.digest=%digest, "failed to get directory");
+                                tonic::Status::new(tonic::Code::Internal, e.to_string())
+                            })?
+                            .ok_or_else(|| {
+                                Status::not_found(format!("directory {} not found", digest))
+                            })?;
+
+                        Box::pin(once(Ok(directory.into())))
+                    } else {
+                        // If recursive was requested, traverse via get_recursive.
+                        Box::pin(
+                            self.directory_service
+                                .get_recursive(&digest)
+                                .map_ok(proto::Directory::from)
+                                .map_err(|e| {
+                                    tonic::Status::new(tonic::Code::Internal, e.to_string())
+                                }),
+                        )
                     }
-                } else {
-                    // If recursive was requested, traverse via get_recursive.
-                    let mut directories_it = self.directory_service.get_recursive(&digest);
-
-                    while let Some(e) = directories_it.next().await {
-                        // map err in res from Error to Status
-                        let res = e.map_err(|e| Status::internal(e.to_string()));
-                        if tx.send(res).await.is_err() {
-                            debug!("receiver dropped");
-                            break;
-                        }
-                    }
-                }
+                }))
             }
         }
-
-        let receiver_stream = ReceiverStream::new(rx);
-        Ok(Response::new(receiver_stream))
     }
 
     #[instrument(skip_all)]
@@ -89,14 +80,22 @@ where
     ) -> Result<Response<proto::PutDirectoryResponse>, Status> {
         let mut req_inner = request.into_inner();
 
-        // We put all Directory messages we receive into ClosureValidator first.
-        let mut validator = ClosureValidator::default();
+        // We put all Directory messages we receive into DirectoryGraph.
+        let mut validator = DirectoryGraph::<LeavesToRootValidator>::default();
         while let Some(directory) = req_inner.message().await? {
-            validator.add(directory)?;
+            validator
+                .add(directory.try_into().map_err(|e: DirectoryError| {
+                    tonic::Status::new(tonic::Code::Internal, e.to_string())
+                })?)
+                .map_err(|e| tonic::Status::new(tonic::Code::Internal, e.to_string()))?;
         }
 
         // drain, which validates connectivity too.
-        let directories = validator.finalize()?;
+        let directories = validator
+            .validate()
+            .map_err(|e| tonic::Status::new(tonic::Code::Internal, e.to_string()))?
+            .drain_leaves_to_root()
+            .collect::<Vec<_>>();
 
         let mut directory_putter = self.directory_service.put_multiple_start();
         for directory in directories {