about summary refs log tree commit diff
path: root/tvix/castore
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore')
-rw-r--r--tvix/castore/src/proto/grpc_directoryservice_wrapper.rs83
1 files changed, 36 insertions, 47 deletions
diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
index 7d741a3f07bb..5c1428690cb4 100644
--- a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
+++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
@@ -1,12 +1,12 @@
 use crate::directoryservice::ClosureValidator;
 use crate::proto;
 use crate::{directoryservice::DirectoryService, B3Digest};
-use futures::StreamExt;
+use futures::stream::BoxStream;
+use futures::TryStreamExt;
 use std::ops::Deref;
-use tokio::sync::mpsc::channel;
-use tokio_stream::wrappers::ReceiverStream;
+use tokio_stream::once;
 use tonic::{async_trait, Request, Response, Status, Streaming};
-use tracing::{debug, instrument, warn};
+use tracing::{instrument, warn};
 
 pub struct GRPCDirectoryServiceWrapper<T> {
     directory_service: T,
@@ -23,63 +23,52 @@ impl<T> proto::directory_service_server::DirectoryService for GRPCDirectoryServi
 where
     T: Deref<Target = dyn DirectoryService> + Send + Sync + 'static,
 {
-    type GetStream = ReceiverStream<tonic::Result<proto::Directory, Status>>;
+    type GetStream = BoxStream<'static, tonic::Result<proto::Directory, Status>>;
 
     #[instrument(skip_all)]
-    async fn get(
-        &self,
+    async fn get<'a>(
+        &'a self,
         request: Request<proto::GetDirectoryRequest>,
     ) -> Result<Response<Self::GetStream>, Status> {
-        let (tx, rx) = channel(5);
-
         let req_inner = request.into_inner();
 
-        // look at the digest in the request and put it in the top of the queue.
-        match &req_inner.by_what {
-            None => return Err(Status::invalid_argument("by_what needs to be specified")),
-            Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => {
+        let by_what = &req_inner
+            .by_what
+            .ok_or_else(|| Status::invalid_argument("invalid by_what"))?;
+
+        match by_what {
+            proto::get_directory_request::ByWhat::Digest(ref digest) => {
                 let digest: B3Digest = digest
                     .clone()
                     .try_into()
                     .map_err(|_e| Status::invalid_argument("invalid digest length"))?;
 
-                if !req_inner.recursive {
-                    let e: Result<proto::Directory, Status> = match self
-                        .directory_service
-                        .get(&digest)
-                        .await
-                    {
-                        Ok(Some(directory)) => Ok(directory),
-                        Ok(None) => {
-                            Err(Status::not_found(format!("directory {} not found", digest)))
-                        }
-                        Err(e) => {
-                            warn!(err = %e, directory.digest=%digest, "failed to get directory");
-                            Err(e.into())
-                        }
-                    };
-
-                    if tx.send(e).await.is_err() {
-                        debug!("receiver dropped");
+                Ok(tonic::Response::new({
+                    if !req_inner.recursive {
+                        let directory = self
+                            .directory_service
+                            .get(&digest)
+                            .await
+                            .map_err(|e| {
+                                warn!(err = %e, directory.digest=%digest, "failed to get directory");
+                                tonic::Status::new(tonic::Code::Internal, e.to_string())
+                            })?
+                            .ok_or_else(|| {
+                                Status::not_found(format!("directory {} not found", digest))
+                            })?;
+
+                        Box::pin(once(Ok(directory)))
+                    } else {
+                        // If recursive was requested, traverse via get_recursive.
+                        Box::pin(
+                            self.directory_service.get_recursive(&digest).map_err(|e| {
+                                tonic::Status::new(tonic::Code::Internal, e.to_string())
+                            }),
+                        )
                     }
-                } else {
-                    // If recursive was requested, traverse via get_recursive.
-                    let mut directories_it = self.directory_service.get_recursive(&digest);
-
-                    while let Some(e) = directories_it.next().await {
-                        // map err in res from Error to Status
-                        let res = e.map_err(|e| Status::internal(e.to_string()));
-                        if tx.send(res).await.is_err() {
-                            debug!("receiver dropped");
-                            break;
-                        }
-                    }
-                }
+                }))
             }
         }
-
-        let receiver_stream = ReceiverStream::new(rx);
-        Ok(Response::new(receiver_stream))
     }
 
     #[instrument(skip_all)]