diff options
Diffstat (limited to 'tvix/castore/src/proto/grpc_directoryservice_wrapper.rs')
-rw-r--r-- | tvix/castore/src/proto/grpc_directoryservice_wrapper.rs | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs new file mode 100644 index 0000000000..7d741a3f07 --- /dev/null +++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs @@ -0,0 +1,114 @@ +use crate::directoryservice::ClosureValidator; +use crate::proto; +use crate::{directoryservice::DirectoryService, B3Digest}; +use futures::StreamExt; +use std::ops::Deref; +use tokio::sync::mpsc::channel; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{async_trait, Request, Response, Status, Streaming}; +use tracing::{debug, instrument, warn}; + +pub struct GRPCDirectoryServiceWrapper<T> { + directory_service: T, +} + +impl<T> GRPCDirectoryServiceWrapper<T> { + pub fn new(directory_service: T) -> Self { + Self { directory_service } + } +} + +#[async_trait] +impl<T> proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper<T> +where + T: Deref<Target = dyn DirectoryService> + Send + Sync + 'static, +{ + type GetStream = ReceiverStream<tonic::Result<proto::Directory, Status>>; + + #[instrument(skip_all)] + async fn get( + &self, + request: Request<proto::GetDirectoryRequest>, + ) -> Result<Response<Self::GetStream>, Status> { + let (tx, rx) = channel(5); + + let req_inner = request.into_inner(); + + // look at the digest in the request and put it in the top of the queue. + match &req_inner.by_what { + None => return Err(Status::invalid_argument("by_what needs to be specified")), + Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => { + let digest: B3Digest = digest + .clone() + .try_into() + .map_err(|_e| Status::invalid_argument("invalid digest length"))?; + + if !req_inner.recursive { + let e: Result<proto::Directory, Status> = match self + .directory_service + .get(&digest) + .await + { + Ok(Some(directory)) => Ok(directory), + Ok(None) => { + Err(Status::not_found(format!("directory {} not found", digest))) + } + Err(e) => { + warn!(err = %e, directory.digest=%digest, "failed to get directory"); + Err(e.into()) + } + }; + + if tx.send(e).await.is_err() { + debug!("receiver dropped"); + } + } else { + // If recursive was requested, traverse via get_recursive. + let mut directories_it = self.directory_service.get_recursive(&digest); + + while let Some(e) = directories_it.next().await { + // map err in res from Error to Status + let res = e.map_err(|e| Status::internal(e.to_string())); + if tx.send(res).await.is_err() { + debug!("receiver dropped"); + break; + } + } + } + } + } + + let receiver_stream = ReceiverStream::new(rx); + Ok(Response::new(receiver_stream)) + } + + #[instrument(skip_all)] + async fn put( + &self, + request: Request<Streaming<proto::Directory>>, + ) -> Result<Response<proto::PutDirectoryResponse>, Status> { + let mut req_inner = request.into_inner(); + + // We put all Directory messages we receive into ClosureValidator first. + let mut validator = ClosureValidator::default(); + while let Some(directory) = req_inner.message().await? { + validator.add(directory)?; + } + + // drain, which validates connectivity too. + let directories = validator.finalize()?; + + let mut directory_putter = self.directory_service.put_multiple_start(); + for directory in directories { + directory_putter.put(directory).await?; + } + + // Properly close the directory putter. Peek at last_directory_digest + // and return it, or propagate errors. + let last_directory_dgst = directory_putter.close().await?; + + Ok(Response::new(proto::PutDirectoryResponse { + root_digest: last_directory_dgst.into(), + })) + } +} |