about summary refs log tree commit diff
path: root/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/proto/grpc_directoryservice_wrapper.rs')
-rw-r--r--tvix/castore/src/proto/grpc_directoryservice_wrapper.rs94
1 files changed, 44 insertions, 50 deletions
diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
index 097958050e40..b83048045861 100644
--- a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
+++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
@@ -2,26 +2,27 @@ use crate::proto;
 use crate::{directoryservice::DirectoryService, B3Digest};
 use futures::StreamExt;
 use std::collections::HashMap;
-use std::sync::Arc;
-use tokio::{sync::mpsc::channel, task};
+use std::ops::Deref;
+use tokio::sync::mpsc::channel;
 use tokio_stream::wrappers::ReceiverStream;
 use tonic::{async_trait, Request, Response, Status, Streaming};
 use tracing::{debug, instrument, warn};
 
-pub struct GRPCDirectoryServiceWrapper {
-    directory_service: Arc<dyn DirectoryService>,
+pub struct GRPCDirectoryServiceWrapper<T> {
+    directory_service: T,
 }
 
-impl From<Arc<dyn DirectoryService>> for GRPCDirectoryServiceWrapper {
-    fn from(value: Arc<dyn DirectoryService>) -> Self {
-        Self {
-            directory_service: value,
-        }
+impl<T> GRPCDirectoryServiceWrapper<T> {
+    pub fn new(directory_service: T) -> Self {
+        Self { directory_service }
     }
 }
 
 #[async_trait]
-impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper {
+impl<T> proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper<T>
+where
+    T: Deref<Target = dyn DirectoryService> + Send + Sync + 'static,
+{
     type GetStream = ReceiverStream<tonic::Result<proto::Directory, Status>>;
 
     #[instrument(skip(self))]
@@ -33,50 +34,43 @@ impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceW
 
         let req_inner = request.into_inner();
 
-        let directory_service = self.directory_service.clone();
-
-        let _task = {
-            // look at the digest in the request and put it in the top of the queue.
-            match &req_inner.by_what {
-                None => return Err(Status::invalid_argument("by_what needs to be specified")),
-                Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => {
-                    let digest: B3Digest = digest
-                        .clone()
-                        .try_into()
-                        .map_err(|_e| Status::invalid_argument("invalid digest length"))?;
-
-                    task::spawn(async move {
-                        if !req_inner.recursive {
-                            let e: Result<proto::Directory, Status> =
-                                match directory_service.get(&digest).await {
-                                    Ok(Some(directory)) => Ok(directory),
-                                    Ok(None) => Err(Status::not_found(format!(
-                                        "directory {} not found",
-                                        digest
-                                    ))),
-                                    Err(e) => Err(e.into()),
-                                };
-
-                            if tx.send(e).await.is_err() {
-                                debug!("receiver dropped");
-                            }
-                        } else {
-                            // If recursive was requested, traverse via get_recursive.
-                            let mut directories_it = directory_service.get_recursive(&digest);
-
-                            while let Some(e) = directories_it.next().await {
-                                // map err in res from Error to Status
-                                let res = e.map_err(|e| Status::internal(e.to_string()));
-                                if tx.send(res).await.is_err() {
-                                    debug!("receiver dropped");
-                                    break;
-                                }
+        // look at the digest in the request and put it in the top of the queue.
+        match &req_inner.by_what {
+            None => return Err(Status::invalid_argument("by_what needs to be specified")),
+            Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => {
+                let digest: B3Digest = digest
+                    .clone()
+                    .try_into()
+                    .map_err(|_e| Status::invalid_argument("invalid digest length"))?;
+
+                if !req_inner.recursive {
+                    let e: Result<proto::Directory, Status> =
+                        match self.directory_service.get(&digest).await {
+                            Ok(Some(directory)) => Ok(directory),
+                            Ok(None) => {
+                                Err(Status::not_found(format!("directory {} not found", digest)))
                             }
+                            Err(e) => Err(e.into()),
+                        };
+
+                    if tx.send(e).await.is_err() {
+                        debug!("receiver dropped");
+                    }
+                } else {
+                    // If recursive was requested, traverse via get_recursive.
+                    let mut directories_it = self.directory_service.get_recursive(&digest);
+
+                    while let Some(e) = directories_it.next().await {
+                        // map err in res from Error to Status
+                        let res = e.map_err(|e| Status::internal(e.to_string()));
+                        if tx.send(res).await.is_err() {
+                            debug!("receiver dropped");
+                            break;
                         }
-                    });
+                    }
                 }
             }
-        };
+        }
 
         let receiver_stream = ReceiverStream::new(rx);
         Ok(Response::new(receiver_stream))