diff options
Diffstat (limited to 'tvix/store/src/blobservice/grpc.rs')
-rw-r--r-- | tvix/store/src/blobservice/grpc.rs | 248 |
1 files changed, 123 insertions, 125 deletions
diff --git a/tvix/store/src/blobservice/grpc.rs b/tvix/store/src/blobservice/grpc.rs index c6d28860f8aa..cea796033ea9 100644 --- a/tvix/store/src/blobservice/grpc.rs +++ b/tvix/store/src/blobservice/grpc.rs @@ -1,22 +1,26 @@ -use super::{dumb_seeker::DumbSeeker, BlobReader, BlobService, BlobWriter}; +use super::{naive_seeker::NaiveSeeker, BlobReader, BlobService, BlobWriter}; use crate::{proto, B3Digest}; -use futures::sink::{SinkExt, SinkMapErr}; -use std::{collections::VecDeque, io}; +use futures::sink::SinkExt; +use futures::TryFutureExt; +use std::{ + collections::VecDeque, + io::{self}, + pin::pin, + task::Poll, +}; +use tokio::io::AsyncWriteExt; use tokio::{net::UnixStream, task::JoinHandle}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tokio_util::{ - io::{CopyToBytes, SinkWriter, SyncIoBridge}, + io::{CopyToBytes, SinkWriter}, sync::{PollSendError, PollSender}, }; -use tonic::{transport::Channel, Code, Status, Streaming}; +use tonic::{async_trait, transport::Channel, Code, Status}; use tracing::instrument; /// Connects to a (remote) tvix-store BlobService over gRPC. #[derive(Clone)] pub struct GRPCBlobService { - /// A handle into the active tokio runtime. Necessary to spawn tasks. - tokio_handle: tokio::runtime::Handle, - /// The internal reference to a gRPC client. /// Cloning it is cheap, and it internally handles concurrent requests. grpc_client: proto::blob_service_client::BlobServiceClient<Channel>, @@ -28,13 +32,11 @@ impl GRPCBlobService { pub fn from_client( grpc_client: proto::blob_service_client::BlobServiceClient<Channel>, ) -> Self { - Self { - tokio_handle: tokio::runtime::Handle::current(), - grpc_client, - } + Self { grpc_client } } } +#[async_trait] impl BlobService for GRPCBlobService { /// Constructs a [GRPCBlobService] from the passed [url::Url]: /// - scheme has to match `grpc+*://`. @@ -89,22 +91,16 @@ impl BlobService for GRPCBlobService { } #[instrument(skip(self, digest), fields(blob.digest=%digest))] - fn has(&self, digest: &B3Digest) -> Result<bool, crate::Error> { - // Get a new handle to the gRPC client, and copy the digest. + async fn has(&self, digest: &B3Digest) -> Result<bool, crate::Error> { let mut grpc_client = self.grpc_client.clone(); - let digest = digest.clone(); - - let task: JoinHandle<Result<_, Status>> = self.tokio_handle.spawn(async move { - Ok(grpc_client - .stat(proto::StatBlobRequest { - digest: digest.into(), - ..Default::default() - }) - .await? - .into_inner()) - }); - - match self.tokio_handle.block_on(task)? { + let resp = grpc_client + .stat(proto::StatBlobRequest { + digest: digest.clone().into(), + ..Default::default() + }) + .await; + + match resp { Ok(_blob_meta) => Ok(true), Err(e) if e.code() == Code::NotFound => Ok(false), Err(e) => Err(crate::Error::StorageError(e.to_string())), @@ -113,35 +109,30 @@ impl BlobService for GRPCBlobService { // On success, this returns a Ok(Some(io::Read)), which can be used to read // the contents of the Blob, identified by the digest. - fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, crate::Error> { + async fn open_read( + &self, + digest: &B3Digest, + ) -> Result<Option<Box<dyn BlobReader>>, crate::Error> { // Get a new handle to the gRPC client, and copy the digest. let mut grpc_client = self.grpc_client.clone(); - let digest = digest.clone(); - - // Construct the task that'll send out the request and return the stream - // the gRPC client should use to send [proto::BlobChunk], or an error if - // the blob doesn't exist. - let task: JoinHandle<Result<Streaming<proto::BlobChunk>, Status>> = - self.tokio_handle.spawn(async move { - let stream = grpc_client - .read(proto::ReadBlobRequest { - digest: digest.into(), - }) - .await? - .into_inner(); - - Ok(stream) - }); + + // Get a stream of [proto::BlobChunk], or return an error if the blob + // doesn't exist. + let resp = grpc_client + .read(proto::ReadBlobRequest { + digest: digest.clone().into(), + }) + .await; // This runs the task to completion, which on success will return a stream. // On reading from it, we receive individual [proto::BlobChunk], so we // massage this to a stream of bytes, // then create an [AsyncRead], which we'll turn into a [io::Read], // that's returned from the function. - match self.tokio_handle.block_on(task)? { + match resp { Ok(stream) => { // map the stream of proto::BlobChunk to bytes. - let data_stream = stream.map(|x| { + let data_stream = stream.into_inner().map(|x| { x.map(|x| VecDeque::from(x.data.to_vec())) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e)) }); @@ -149,9 +140,7 @@ impl BlobService for GRPCBlobService { // Use StreamReader::new to convert to an AsyncRead. let data_reader = tokio_util::io::StreamReader::new(data_stream); - // Use SyncIoBridge to turn it into a sync Read. - let sync_reader = tokio_util::io::SyncIoBridge::new(data_reader); - Ok(Some(Box::new(DumbSeeker::new(sync_reader)))) + Ok(Some(Box::new(NaiveSeeker::new(data_reader)))) } Err(e) if e.code() == Code::NotFound => Ok(None), Err(e) => Err(crate::Error::StorageError(e.to_string())), @@ -160,7 +149,7 @@ impl BlobService for GRPCBlobService { /// Returns a BlobWriter, that'll internally wrap each write in a // [proto::BlobChunk], which is send to the gRPC server. - fn open_write(&self) -> Box<dyn BlobWriter> { + async fn open_write(&self) -> Box<dyn BlobWriter> { let mut grpc_client = self.grpc_client.clone(); // set up an mpsc channel passing around Bytes. @@ -171,9 +160,8 @@ impl BlobService for GRPCBlobService { let blobchunk_stream = ReceiverStream::new(rx).map(|x| proto::BlobChunk { data: x }); // That receiver stream is used as a stream in the gRPC BlobService.put rpc call. - let task: JoinHandle<Result<_, Status>> = self - .tokio_handle - .spawn(async move { Ok(grpc_client.put(blobchunk_stream).await?.into_inner()) }); + let task: JoinHandle<Result<_, Status>> = + tokio::spawn(async move { Ok(grpc_client.put(blobchunk_stream).await?.into_inner()) }); // The tx part of the channel is converted to a sink of byte chunks. @@ -187,43 +175,26 @@ impl BlobService for GRPCBlobService { // We need to explicitly cast here, otherwise rustc does error with "expected fn pointer, found fn item" // … which is turned into an [tokio::io::AsyncWrite]. - let async_writer = SinkWriter::new(CopyToBytes::new(sink)); - // … which is then turned into a [io::Write]. - let writer = SyncIoBridge::new(async_writer); + let writer = SinkWriter::new(CopyToBytes::new(sink)); Box::new(GRPCBlobWriter { - tokio_handle: self.tokio_handle.clone(), task_and_writer: Some((task, writer)), digest: None, }) } } -type BridgedWriter = SyncIoBridge< - SinkWriter< - CopyToBytes< - SinkMapErr<PollSender<bytes::Bytes>, fn(PollSendError<bytes::Bytes>) -> io::Error>, - >, - >, ->; - -pub struct GRPCBlobWriter { - /// A handle into the active tokio runtime. Necessary to block on the task - /// containing the put request. - tokio_handle: tokio::runtime::Handle, - +pub struct GRPCBlobWriter<W: tokio::io::AsyncWrite> { /// The task containing the put request, and the inner writer, if we're still writing. - task_and_writer: Option<( - JoinHandle<Result<proto::PutBlobResponse, Status>>, - BridgedWriter, - )>, + task_and_writer: Option<(JoinHandle<Result<proto::PutBlobResponse, Status>>, W)>, /// The digest that has been returned, if we successfully closed. digest: Option<B3Digest>, } -impl BlobWriter for GRPCBlobWriter { - fn close(&mut self) -> Result<B3Digest, crate::Error> { +#[async_trait] +impl<W: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static> BlobWriter for GRPCBlobWriter<W> { + async fn close(&mut self) -> Result<B3Digest, crate::Error> { if self.task_and_writer.is_none() { // if we're already closed, return the b3 digest, which must exist. // If it doesn't, we already closed and failed once, and didn't handle the error. @@ -240,12 +211,14 @@ impl BlobWriter for GRPCBlobWriter { // the channel. writer .shutdown() - .map_err(|e| crate::Error::StorageError(e.to_string()))?; + .map_err(|e| crate::Error::StorageError(e.to_string())) + .await?; // block on the RPC call to return. // This ensures all chunks are sent out, and have been received by the // backend. - match self.tokio_handle.block_on(task)? { + + match task.await? { Ok(resp) => { // return the digest from the response, and store it in self.digest for subsequent closes. let digest: B3Digest = resp.digest.try_into().map_err(|_| { @@ -262,26 +235,48 @@ impl BlobWriter for GRPCBlobWriter { } } -impl io::Write for GRPCBlobWriter { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { +impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for GRPCBlobWriter<W> { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll<Result<usize, io::Error>> { match &mut self.task_and_writer { - None => Err(io::Error::new( + None => Poll::Ready(Err(io::Error::new( io::ErrorKind::NotConnected, "already closed", - )), - Some((_, ref mut writer)) => writer.write(buf), + ))), + Some((_, ref mut writer)) => { + let pinned_writer = pin!(writer); + pinned_writer.poll_write(cx, buf) + } } } - fn flush(&mut self) -> io::Result<()> { + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), io::Error>> { match &mut self.task_and_writer { - None => Err(io::Error::new( + None => Poll::Ready(Err(io::Error::new( io::ErrorKind::NotConnected, "already closed", - )), - Some((_, ref mut writer)) => writer.flush(), + ))), + Some((_, ref mut writer)) => { + let pinned_writer = pin!(writer); + pinned_writer.poll_flush(cx) + } } } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), io::Error>> { + // TODO(raitobezarius): this might not be a graceful shutdown of the + // channel inside the gRPC connection. + Poll::Ready(Ok(())) + } } #[cfg(test)] @@ -291,7 +286,6 @@ mod tests { use tempfile::TempDir; use tokio::net::UnixListener; - use tokio::task; use tokio::time; use tokio_stream::wrappers::UnixListenerStream; @@ -358,32 +352,23 @@ mod tests { } /// This uses the correct scheme for a unix socket, and provides a server on the other side. - #[tokio::test] - async fn test_valid_unix_path_ping_pong() { + /// This is not a tokio::test, because spawn two separate tokio runtimes and + // want to have explicit control. + #[test] + fn test_valid_unix_path_ping_pong() { let tmpdir = TempDir::new().unwrap(); let path = tmpdir.path().join("daemon"); - // let mut join_set = JoinSet::new(); - - // prepare a client - let client = { - let mut url = url::Url::parse("grpc+unix:///path/to/somewhere").expect("must parse"); - url.set_path(path.to_str().unwrap()); - GRPCBlobService::from_url(&url).expect("must succeed") - }; - - let path_copy = path.clone(); + let path_clone = path.clone(); // Spin up a server, in a thread far away, which spawns its own tokio runtime, // and blocks on the task. thread::spawn(move || { // Create the runtime let rt = tokio::runtime::Runtime::new().unwrap(); - // Get a handle from this runtime - let handle = rt.handle(); - let task = handle.spawn(async { - let uds = UnixListener::bind(path_copy).unwrap(); + let task = rt.spawn(async { + let uds = UnixListener::bind(path_clone).unwrap(); let uds_stream = UnixListenerStream::new(uds); // spin up a new server @@ -397,33 +382,46 @@ mod tests { router.serve_with_incoming(uds_stream).await }); - handle.block_on(task) + rt.block_on(task).unwrap().unwrap(); }); - // wait for the socket to be created - { - let mut socket_created = false; - for _try in 1..20 { - if path.exists() { - socket_created = true; - break; + // Now create another tokio runtime which we'll use in the main test code. + let rt = tokio::runtime::Runtime::new().unwrap(); + + let task = rt.spawn(async move { + // wait for the socket to be created + { + let mut socket_created = false; + // TODO: exponential backoff urgently + for _try in 1..20 { + if path.exists() { + socket_created = true; + break; + } + tokio::time::sleep(time::Duration::from_millis(20)).await; } - tokio::time::sleep(time::Duration::from_millis(20)).await; + + assert!( + socket_created, + "expected socket path to eventually get created, but never happened" + ); } - assert!( - socket_created, - "expected socket path to eventually get created, but never happened" - ); - } + // prepare a client + let client = { + let mut url = + url::Url::parse("grpc+unix:///path/to/somewhere").expect("must parse"); + url.set_path(path.to_str().unwrap()); + GRPCBlobService::from_url(&url).expect("must succeed") + }; - let has = task::spawn_blocking(move || { - client + let has = client .has(&fixtures::BLOB_A_DIGEST) - .expect("must not be err") - }) - .await - .expect("must not be err"); - assert!(!has); + .await + .expect("must not be err"); + + assert!(!has); + }); + rt.block_on(task).unwrap() } } |