about summary refs log tree commit diff
path: root/tvix
diff options
context:
space:
mode:
Diffstat (limited to 'tvix')
-rw-r--r--tvix/castore/src/directoryservice/grpc.rs6
-rw-r--r--tvix/castore/src/proto/grpc_directoryservice_wrapper.rs94
-rw-r--r--tvix/castore/src/utils.rs2
-rw-r--r--tvix/store/src/bin/tvix-store.rs2
4 files changed, 49 insertions, 55 deletions
diff --git a/tvix/castore/src/directoryservice/grpc.rs b/tvix/castore/src/directoryservice/grpc.rs
index 1d6ad2c13b86..c98708608e56 100644
--- a/tvix/castore/src/directoryservice/grpc.rs
+++ b/tvix/castore/src/directoryservice/grpc.rs
@@ -288,7 +288,7 @@ impl DirectoryPutter for GRPCPutter {
 mod tests {
     use core::time;
     use futures::StreamExt;
-    use std::{any::Any, sync::Arc, time::Duration};
+    use std::{any::Any, time::Duration};
     use tempfile::TempDir;
     use tokio::net::UnixListener;
     use tokio_retry::{strategy::ExponentialBackoff, Retry};
@@ -460,8 +460,8 @@ mod tests {
             let mut server = tonic::transport::Server::builder();
             let router = server.add_service(
                 crate::proto::directory_service_server::DirectoryServiceServer::new(
-                    GRPCDirectoryServiceWrapper::from(
-                        Arc::new(MemoryDirectoryService::default()) as Arc<dyn DirectoryService>
+                    GRPCDirectoryServiceWrapper::new(
+                        Box::<MemoryDirectoryService>::default() as Box<dyn DirectoryService>
                     ),
                 ),
             );
diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
index 097958050e40..b83048045861 100644
--- a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
+++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs
@@ -2,26 +2,27 @@ use crate::proto;
 use crate::{directoryservice::DirectoryService, B3Digest};
 use futures::StreamExt;
 use std::collections::HashMap;
-use std::sync::Arc;
-use tokio::{sync::mpsc::channel, task};
+use std::ops::Deref;
+use tokio::sync::mpsc::channel;
 use tokio_stream::wrappers::ReceiverStream;
 use tonic::{async_trait, Request, Response, Status, Streaming};
 use tracing::{debug, instrument, warn};
 
-pub struct GRPCDirectoryServiceWrapper {
-    directory_service: Arc<dyn DirectoryService>,
+pub struct GRPCDirectoryServiceWrapper<T> {
+    directory_service: T,
 }
 
-impl From<Arc<dyn DirectoryService>> for GRPCDirectoryServiceWrapper {
-    fn from(value: Arc<dyn DirectoryService>) -> Self {
-        Self {
-            directory_service: value,
-        }
+impl<T> GRPCDirectoryServiceWrapper<T> {
+    pub fn new(directory_service: T) -> Self {
+        Self { directory_service }
     }
 }
 
 #[async_trait]
-impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper {
+impl<T> proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper<T>
+where
+    T: Deref<Target = dyn DirectoryService> + Send + Sync + 'static,
+{
     type GetStream = ReceiverStream<tonic::Result<proto::Directory, Status>>;
 
     #[instrument(skip(self))]
@@ -33,50 +34,43 @@ impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceW
 
         let req_inner = request.into_inner();
 
-        let directory_service = self.directory_service.clone();
-
-        let _task = {
-            // look at the digest in the request and put it in the top of the queue.
-            match &req_inner.by_what {
-                None => return Err(Status::invalid_argument("by_what needs to be specified")),
-                Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => {
-                    let digest: B3Digest = digest
-                        .clone()
-                        .try_into()
-                        .map_err(|_e| Status::invalid_argument("invalid digest length"))?;
-
-                    task::spawn(async move {
-                        if !req_inner.recursive {
-                            let e: Result<proto::Directory, Status> =
-                                match directory_service.get(&digest).await {
-                                    Ok(Some(directory)) => Ok(directory),
-                                    Ok(None) => Err(Status::not_found(format!(
-                                        "directory {} not found",
-                                        digest
-                                    ))),
-                                    Err(e) => Err(e.into()),
-                                };
-
-                            if tx.send(e).await.is_err() {
-                                debug!("receiver dropped");
-                            }
-                        } else {
-                            // If recursive was requested, traverse via get_recursive.
-                            let mut directories_it = directory_service.get_recursive(&digest);
-
-                            while let Some(e) = directories_it.next().await {
-                                // map err in res from Error to Status
-                                let res = e.map_err(|e| Status::internal(e.to_string()));
-                                if tx.send(res).await.is_err() {
-                                    debug!("receiver dropped");
-                                    break;
-                                }
+        // look at the digest in the request and put it in the top of the queue.
+        match &req_inner.by_what {
+            None => return Err(Status::invalid_argument("by_what needs to be specified")),
+            Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => {
+                let digest: B3Digest = digest
+                    .clone()
+                    .try_into()
+                    .map_err(|_e| Status::invalid_argument("invalid digest length"))?;
+
+                if !req_inner.recursive {
+                    let e: Result<proto::Directory, Status> =
+                        match self.directory_service.get(&digest).await {
+                            Ok(Some(directory)) => Ok(directory),
+                            Ok(None) => {
+                                Err(Status::not_found(format!("directory {} not found", digest)))
                             }
+                            Err(e) => Err(e.into()),
+                        };
+
+                    if tx.send(e).await.is_err() {
+                        debug!("receiver dropped");
+                    }
+                } else {
+                    // If recursive was requested, traverse via get_recursive.
+                    let mut directories_it = self.directory_service.get_recursive(&digest);
+
+                    while let Some(e) = directories_it.next().await {
+                        // map err in res from Error to Status
+                        let res = e.map_err(|e| Status::internal(e.to_string()));
+                        if tx.send(res).await.is_err() {
+                            debug!("receiver dropped");
+                            break;
                         }
-                    });
+                    }
                 }
             }
-        };
+        }
 
         let receiver_stream = ReceiverStream::new(rx);
         Ok(Response::new(receiver_stream))
diff --git a/tvix/castore/src/utils.rs b/tvix/castore/src/utils.rs
index 1b0d4c674218..b24627ed9bef 100644
--- a/tvix/castore/src/utils.rs
+++ b/tvix/castore/src/utils.rs
@@ -35,7 +35,7 @@ pub(crate) async fn gen_directorysvc_grpc_client() -> DirectoryServiceClient<Cha
         // spin up a new DirectoryService
         let mut server = Server::builder();
         let router = server.add_service(DirectoryServiceServer::new(
-            GRPCDirectoryServiceWrapper::from(gen_directory_service()),
+            GRPCDirectoryServiceWrapper::new(gen_directory_service()),
         ));
 
         router
diff --git a/tvix/store/src/bin/tvix-store.rs b/tvix/store/src/bin/tvix-store.rs
index 9136c34ff885..59c39f7b4e76 100644
--- a/tvix/store/src/bin/tvix-store.rs
+++ b/tvix/store/src/bin/tvix-store.rs
@@ -201,7 +201,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                     blob_service,
                 )))
                 .add_service(DirectoryServiceServer::new(
-                    GRPCDirectoryServiceWrapper::from(directory_service),
+                    GRPCDirectoryServiceWrapper::new(directory_service),
                 ))
                 .add_service(PathInfoServiceServer::new(GRPCPathInfoServiceWrapper::new(
                     Arc::from(path_info_service),