diff options
Diffstat (limited to 'tvix/castore/src/proto/grpc_directoryservice_wrapper.rs')
-rw-r--r-- | tvix/castore/src/proto/grpc_directoryservice_wrapper.rs | 94 |
1 files changed, 44 insertions, 50 deletions
diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs index 097958050e40..b83048045861 100644 --- a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs +++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs @@ -2,26 +2,27 @@ use crate::proto; use crate::{directoryservice::DirectoryService, B3Digest}; use futures::StreamExt; use std::collections::HashMap; -use std::sync::Arc; -use tokio::{sync::mpsc::channel, task}; +use std::ops::Deref; +use tokio::sync::mpsc::channel; use tokio_stream::wrappers::ReceiverStream; use tonic::{async_trait, Request, Response, Status, Streaming}; use tracing::{debug, instrument, warn}; -pub struct GRPCDirectoryServiceWrapper { - directory_service: Arc<dyn DirectoryService>, +pub struct GRPCDirectoryServiceWrapper<T> { + directory_service: T, } -impl From<Arc<dyn DirectoryService>> for GRPCDirectoryServiceWrapper { - fn from(value: Arc<dyn DirectoryService>) -> Self { - Self { - directory_service: value, - } +impl<T> GRPCDirectoryServiceWrapper<T> { + pub fn new(directory_service: T) -> Self { + Self { directory_service } } } #[async_trait] -impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper { +impl<T> proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper<T> +where + T: Deref<Target = dyn DirectoryService> + Send + Sync + 'static, +{ type GetStream = ReceiverStream<tonic::Result<proto::Directory, Status>>; #[instrument(skip(self))] @@ -33,50 +34,43 @@ impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceW let req_inner = request.into_inner(); - let directory_service = self.directory_service.clone(); - - let _task = { - // look at the digest in the request and put it in the top of the queue. - match &req_inner.by_what { - None => return Err(Status::invalid_argument("by_what needs to be specified")), - Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => { - let digest: B3Digest = digest - .clone() - .try_into() - .map_err(|_e| Status::invalid_argument("invalid digest length"))?; - - task::spawn(async move { - if !req_inner.recursive { - let e: Result<proto::Directory, Status> = - match directory_service.get(&digest).await { - Ok(Some(directory)) => Ok(directory), - Ok(None) => Err(Status::not_found(format!( - "directory {} not found", - digest - ))), - Err(e) => Err(e.into()), - }; - - if tx.send(e).await.is_err() { - debug!("receiver dropped"); - } - } else { - // If recursive was requested, traverse via get_recursive. - let mut directories_it = directory_service.get_recursive(&digest); - - while let Some(e) = directories_it.next().await { - // map err in res from Error to Status - let res = e.map_err(|e| Status::internal(e.to_string())); - if tx.send(res).await.is_err() { - debug!("receiver dropped"); - break; - } + // look at the digest in the request and put it in the top of the queue. + match &req_inner.by_what { + None => return Err(Status::invalid_argument("by_what needs to be specified")), + Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => { + let digest: B3Digest = digest + .clone() + .try_into() + .map_err(|_e| Status::invalid_argument("invalid digest length"))?; + + if !req_inner.recursive { + let e: Result<proto::Directory, Status> = + match self.directory_service.get(&digest).await { + Ok(Some(directory)) => Ok(directory), + Ok(None) => { + Err(Status::not_found(format!("directory {} not found", digest))) } + Err(e) => Err(e.into()), + }; + + if tx.send(e).await.is_err() { + debug!("receiver dropped"); + } + } else { + // If recursive was requested, traverse via get_recursive. + let mut directories_it = self.directory_service.get_recursive(&digest); + + while let Some(e) = directories_it.next().await { + // map err in res from Error to Status + let res = e.map_err(|e| Status::internal(e.to_string())); + if tx.send(res).await.is_err() { + debug!("receiver dropped"); + break; } - }); + } } } - }; + } let receiver_stream = ReceiverStream::new(rx); Ok(Response::new(receiver_stream)) |