diff options
Diffstat (limited to 'tvix')
-rw-r--r-- | tvix/castore/src/directoryservice/grpc.rs | 6 | ||||
-rw-r--r-- | tvix/castore/src/proto/grpc_directoryservice_wrapper.rs | 94 | ||||
-rw-r--r-- | tvix/castore/src/utils.rs | 2 | ||||
-rw-r--r-- | tvix/store/src/bin/tvix-store.rs | 2 |
4 files changed, 49 insertions, 55 deletions
diff --git a/tvix/castore/src/directoryservice/grpc.rs b/tvix/castore/src/directoryservice/grpc.rs index 1d6ad2c13b86..c98708608e56 100644 --- a/tvix/castore/src/directoryservice/grpc.rs +++ b/tvix/castore/src/directoryservice/grpc.rs @@ -288,7 +288,7 @@ impl DirectoryPutter for GRPCPutter { mod tests { use core::time; use futures::StreamExt; - use std::{any::Any, sync::Arc, time::Duration}; + use std::{any::Any, time::Duration}; use tempfile::TempDir; use tokio::net::UnixListener; use tokio_retry::{strategy::ExponentialBackoff, Retry}; @@ -460,8 +460,8 @@ mod tests { let mut server = tonic::transport::Server::builder(); let router = server.add_service( crate::proto::directory_service_server::DirectoryServiceServer::new( - GRPCDirectoryServiceWrapper::from( - Arc::new(MemoryDirectoryService::default()) as Arc<dyn DirectoryService> + GRPCDirectoryServiceWrapper::new( + Box::<MemoryDirectoryService>::default() as Box<dyn DirectoryService> ), ), ); diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs index 097958050e40..b83048045861 100644 --- a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs +++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs @@ -2,26 +2,27 @@ use crate::proto; use crate::{directoryservice::DirectoryService, B3Digest}; use futures::StreamExt; use std::collections::HashMap; -use std::sync::Arc; -use tokio::{sync::mpsc::channel, task}; +use std::ops::Deref; +use tokio::sync::mpsc::channel; use tokio_stream::wrappers::ReceiverStream; use tonic::{async_trait, Request, Response, Status, Streaming}; use tracing::{debug, instrument, warn}; -pub struct GRPCDirectoryServiceWrapper { - directory_service: Arc<dyn DirectoryService>, +pub struct GRPCDirectoryServiceWrapper<T> { + directory_service: T, } -impl From<Arc<dyn DirectoryService>> for GRPCDirectoryServiceWrapper { - fn from(value: Arc<dyn DirectoryService>) -> Self { - Self { - directory_service: value, - } +impl<T> GRPCDirectoryServiceWrapper<T> { + pub fn new(directory_service: T) -> Self { + Self { directory_service } } } #[async_trait] -impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper { +impl<T> proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper<T> +where + T: Deref<Target = dyn DirectoryService> + Send + Sync + 'static, +{ type GetStream = ReceiverStream<tonic::Result<proto::Directory, Status>>; #[instrument(skip(self))] @@ -33,50 +34,43 @@ impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceW let req_inner = request.into_inner(); - let directory_service = self.directory_service.clone(); - - let _task = { - // look at the digest in the request and put it in the top of the queue. - match &req_inner.by_what { - None => return Err(Status::invalid_argument("by_what needs to be specified")), - Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => { - let digest: B3Digest = digest - .clone() - .try_into() - .map_err(|_e| Status::invalid_argument("invalid digest length"))?; - - task::spawn(async move { - if !req_inner.recursive { - let e: Result<proto::Directory, Status> = - match directory_service.get(&digest).await { - Ok(Some(directory)) => Ok(directory), - Ok(None) => Err(Status::not_found(format!( - "directory {} not found", - digest - ))), - Err(e) => Err(e.into()), - }; - - if tx.send(e).await.is_err() { - debug!("receiver dropped"); - } - } else { - // If recursive was requested, traverse via get_recursive. - let mut directories_it = directory_service.get_recursive(&digest); - - while let Some(e) = directories_it.next().await { - // map err in res from Error to Status - let res = e.map_err(|e| Status::internal(e.to_string())); - if tx.send(res).await.is_err() { - debug!("receiver dropped"); - break; - } + // look at the digest in the request and put it in the top of the queue. + match &req_inner.by_what { + None => return Err(Status::invalid_argument("by_what needs to be specified")), + Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => { + let digest: B3Digest = digest + .clone() + .try_into() + .map_err(|_e| Status::invalid_argument("invalid digest length"))?; + + if !req_inner.recursive { + let e: Result<proto::Directory, Status> = + match self.directory_service.get(&digest).await { + Ok(Some(directory)) => Ok(directory), + Ok(None) => { + Err(Status::not_found(format!("directory {} not found", digest))) } + Err(e) => Err(e.into()), + }; + + if tx.send(e).await.is_err() { + debug!("receiver dropped"); + } + } else { + // If recursive was requested, traverse via get_recursive. + let mut directories_it = self.directory_service.get_recursive(&digest); + + while let Some(e) = directories_it.next().await { + // map err in res from Error to Status + let res = e.map_err(|e| Status::internal(e.to_string())); + if tx.send(res).await.is_err() { + debug!("receiver dropped"); + break; } - }); + } } } - }; + } let receiver_stream = ReceiverStream::new(rx); Ok(Response::new(receiver_stream)) diff --git a/tvix/castore/src/utils.rs b/tvix/castore/src/utils.rs index 1b0d4c674218..b24627ed9bef 100644 --- a/tvix/castore/src/utils.rs +++ b/tvix/castore/src/utils.rs @@ -35,7 +35,7 @@ pub(crate) async fn gen_directorysvc_grpc_client() -> DirectoryServiceClient<Cha // spin up a new DirectoryService let mut server = Server::builder(); let router = server.add_service(DirectoryServiceServer::new( - GRPCDirectoryServiceWrapper::from(gen_directory_service()), + GRPCDirectoryServiceWrapper::new(gen_directory_service()), )); router diff --git a/tvix/store/src/bin/tvix-store.rs b/tvix/store/src/bin/tvix-store.rs index 9136c34ff885..59c39f7b4e76 100644 --- a/tvix/store/src/bin/tvix-store.rs +++ b/tvix/store/src/bin/tvix-store.rs @@ -201,7 +201,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { blob_service, ))) .add_service(DirectoryServiceServer::new( - GRPCDirectoryServiceWrapper::from(directory_service), + GRPCDirectoryServiceWrapper::new(directory_service), )) .add_service(PathInfoServiceServer::new(GRPCPathInfoServiceWrapper::new( Arc::from(path_info_service), |