about summary refs log tree commit diff
path: root/tvix/store/src/blobservice/grpc.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/store/src/blobservice/grpc.rs')
-rw-r--r--tvix/store/src/blobservice/grpc.rs248
1 files changed, 123 insertions, 125 deletions
diff --git a/tvix/store/src/blobservice/grpc.rs b/tvix/store/src/blobservice/grpc.rs
index c6d28860f8..cea796033e 100644
--- a/tvix/store/src/blobservice/grpc.rs
+++ b/tvix/store/src/blobservice/grpc.rs
@@ -1,22 +1,26 @@
-use super::{dumb_seeker::DumbSeeker, BlobReader, BlobService, BlobWriter};
+use super::{naive_seeker::NaiveSeeker, BlobReader, BlobService, BlobWriter};
 use crate::{proto, B3Digest};
-use futures::sink::{SinkExt, SinkMapErr};
-use std::{collections::VecDeque, io};
+use futures::sink::SinkExt;
+use futures::TryFutureExt;
+use std::{
+    collections::VecDeque,
+    io::{self},
+    pin::pin,
+    task::Poll,
+};
+use tokio::io::AsyncWriteExt;
 use tokio::{net::UnixStream, task::JoinHandle};
 use tokio_stream::{wrappers::ReceiverStream, StreamExt};
 use tokio_util::{
-    io::{CopyToBytes, SinkWriter, SyncIoBridge},
+    io::{CopyToBytes, SinkWriter},
     sync::{PollSendError, PollSender},
 };
-use tonic::{transport::Channel, Code, Status, Streaming};
+use tonic::{async_trait, transport::Channel, Code, Status};
 use tracing::instrument;
 
 /// Connects to a (remote) tvix-store BlobService over gRPC.
 #[derive(Clone)]
 pub struct GRPCBlobService {
-    /// A handle into the active tokio runtime. Necessary to spawn tasks.
-    tokio_handle: tokio::runtime::Handle,
-
     /// The internal reference to a gRPC client.
     /// Cloning it is cheap, and it internally handles concurrent requests.
     grpc_client: proto::blob_service_client::BlobServiceClient<Channel>,
@@ -28,13 +32,11 @@ impl GRPCBlobService {
     pub fn from_client(
         grpc_client: proto::blob_service_client::BlobServiceClient<Channel>,
     ) -> Self {
-        Self {
-            tokio_handle: tokio::runtime::Handle::current(),
-            grpc_client,
-        }
+        Self { grpc_client }
     }
 }
 
+#[async_trait]
 impl BlobService for GRPCBlobService {
     /// Constructs a [GRPCBlobService] from the passed [url::Url]:
     /// - scheme has to match `grpc+*://`.
@@ -89,22 +91,16 @@ impl BlobService for GRPCBlobService {
     }
 
     #[instrument(skip(self, digest), fields(blob.digest=%digest))]
-    fn has(&self, digest: &B3Digest) -> Result<bool, crate::Error> {
-        // Get a new handle to the gRPC client, and copy the digest.
+    async fn has(&self, digest: &B3Digest) -> Result<bool, crate::Error> {
         let mut grpc_client = self.grpc_client.clone();
-        let digest = digest.clone();
-
-        let task: JoinHandle<Result<_, Status>> = self.tokio_handle.spawn(async move {
-            Ok(grpc_client
-                .stat(proto::StatBlobRequest {
-                    digest: digest.into(),
-                    ..Default::default()
-                })
-                .await?
-                .into_inner())
-        });
-
-        match self.tokio_handle.block_on(task)? {
+        let resp = grpc_client
+            .stat(proto::StatBlobRequest {
+                digest: digest.clone().into(),
+                ..Default::default()
+            })
+            .await;
+
+        match resp {
             Ok(_blob_meta) => Ok(true),
             Err(e) if e.code() == Code::NotFound => Ok(false),
             Err(e) => Err(crate::Error::StorageError(e.to_string())),
@@ -113,35 +109,30 @@ impl BlobService for GRPCBlobService {
 
     // On success, this returns a Ok(Some(io::Read)), which can be used to read
     // the contents of the Blob, identified by the digest.
-    fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, crate::Error> {
+    async fn open_read(
+        &self,
+        digest: &B3Digest,
+    ) -> Result<Option<Box<dyn BlobReader>>, crate::Error> {
         // Get a new handle to the gRPC client, and copy the digest.
         let mut grpc_client = self.grpc_client.clone();
-        let digest = digest.clone();
-
-        // Construct the task that'll send out the request and return the stream
-        // the gRPC client should use to send [proto::BlobChunk], or an error if
-        // the blob doesn't exist.
-        let task: JoinHandle<Result<Streaming<proto::BlobChunk>, Status>> =
-            self.tokio_handle.spawn(async move {
-                let stream = grpc_client
-                    .read(proto::ReadBlobRequest {
-                        digest: digest.into(),
-                    })
-                    .await?
-                    .into_inner();
-
-                Ok(stream)
-            });
+
+        // Get a stream of [proto::BlobChunk], or return an error if the blob
+        // doesn't exist.
+        let resp = grpc_client
+            .read(proto::ReadBlobRequest {
+                digest: digest.clone().into(),
+            })
+            .await;
 
         // This runs the task to completion, which on success will return a stream.
         // On reading from it, we receive individual [proto::BlobChunk], so we
         // massage this to a stream of bytes,
         // then create an [AsyncRead], which we'll turn into a [io::Read],
         // that's returned from the function.
-        match self.tokio_handle.block_on(task)? {
+        match resp {
             Ok(stream) => {
                 // map the stream of proto::BlobChunk to bytes.
-                let data_stream = stream.map(|x| {
+                let data_stream = stream.into_inner().map(|x| {
                     x.map(|x| VecDeque::from(x.data.to_vec()))
                         .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
                 });
@@ -149,9 +140,7 @@ impl BlobService for GRPCBlobService {
                 // Use StreamReader::new to convert to an AsyncRead.
                 let data_reader = tokio_util::io::StreamReader::new(data_stream);
 
-                // Use SyncIoBridge to turn it into a sync Read.
-                let sync_reader = tokio_util::io::SyncIoBridge::new(data_reader);
-                Ok(Some(Box::new(DumbSeeker::new(sync_reader))))
+                Ok(Some(Box::new(NaiveSeeker::new(data_reader))))
             }
             Err(e) if e.code() == Code::NotFound => Ok(None),
             Err(e) => Err(crate::Error::StorageError(e.to_string())),
@@ -160,7 +149,7 @@ impl BlobService for GRPCBlobService {
 
     /// Returns a BlobWriter, that'll internally wrap each write in a
     // [proto::BlobChunk], which is send to the gRPC server.
-    fn open_write(&self) -> Box<dyn BlobWriter> {
+    async fn open_write(&self) -> Box<dyn BlobWriter> {
         let mut grpc_client = self.grpc_client.clone();
 
         // set up an mpsc channel passing around Bytes.
@@ -171,9 +160,8 @@ impl BlobService for GRPCBlobService {
         let blobchunk_stream = ReceiverStream::new(rx).map(|x| proto::BlobChunk { data: x });
 
         // That receiver stream is used as a stream in the gRPC BlobService.put rpc call.
-        let task: JoinHandle<Result<_, Status>> = self
-            .tokio_handle
-            .spawn(async move { Ok(grpc_client.put(blobchunk_stream).await?.into_inner()) });
+        let task: JoinHandle<Result<_, Status>> =
+            tokio::spawn(async move { Ok(grpc_client.put(blobchunk_stream).await?.into_inner()) });
 
         // The tx part of the channel is converted to a sink of byte chunks.
 
@@ -187,43 +175,26 @@ impl BlobService for GRPCBlobService {
         // We need to explicitly cast here, otherwise rustc does error with "expected fn pointer, found fn item"
 
         // … which is turned into an [tokio::io::AsyncWrite].
-        let async_writer = SinkWriter::new(CopyToBytes::new(sink));
-        // … which is then turned into a [io::Write].
-        let writer = SyncIoBridge::new(async_writer);
+        let writer = SinkWriter::new(CopyToBytes::new(sink));
 
         Box::new(GRPCBlobWriter {
-            tokio_handle: self.tokio_handle.clone(),
             task_and_writer: Some((task, writer)),
             digest: None,
         })
     }
 }
 
-type BridgedWriter = SyncIoBridge<
-    SinkWriter<
-        CopyToBytes<
-            SinkMapErr<PollSender<bytes::Bytes>, fn(PollSendError<bytes::Bytes>) -> io::Error>,
-        >,
-    >,
->;
-
-pub struct GRPCBlobWriter {
-    /// A handle into the active tokio runtime. Necessary to block on the task
-    /// containing the put request.
-    tokio_handle: tokio::runtime::Handle,
-
+pub struct GRPCBlobWriter<W: tokio::io::AsyncWrite> {
     /// The task containing the put request, and the inner writer, if we're still writing.
-    task_and_writer: Option<(
-        JoinHandle<Result<proto::PutBlobResponse, Status>>,
-        BridgedWriter,
-    )>,
+    task_and_writer: Option<(JoinHandle<Result<proto::PutBlobResponse, Status>>, W)>,
 
     /// The digest that has been returned, if we successfully closed.
     digest: Option<B3Digest>,
 }
 
-impl BlobWriter for GRPCBlobWriter {
-    fn close(&mut self) -> Result<B3Digest, crate::Error> {
+#[async_trait]
+impl<W: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static> BlobWriter for GRPCBlobWriter<W> {
+    async fn close(&mut self) -> Result<B3Digest, crate::Error> {
         if self.task_and_writer.is_none() {
             // if we're already closed, return the b3 digest, which must exist.
             // If it doesn't, we already closed and failed once, and didn't handle the error.
@@ -240,12 +211,14 @@ impl BlobWriter for GRPCBlobWriter {
             // the channel.
             writer
                 .shutdown()
-                .map_err(|e| crate::Error::StorageError(e.to_string()))?;
+                .map_err(|e| crate::Error::StorageError(e.to_string()))
+                .await?;
 
             // block on the RPC call to return.
             // This ensures all chunks are sent out, and have been received by the
             // backend.
-            match self.tokio_handle.block_on(task)? {
+
+            match task.await? {
                 Ok(resp) => {
                     // return the digest from the response, and store it in self.digest for subsequent closes.
                     let digest: B3Digest = resp.digest.try_into().map_err(|_| {
@@ -262,26 +235,48 @@ impl BlobWriter for GRPCBlobWriter {
     }
 }
 
-impl io::Write for GRPCBlobWriter {
-    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for GRPCBlobWriter<W> {
+    fn poll_write(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<Result<usize, io::Error>> {
         match &mut self.task_and_writer {
-            None => Err(io::Error::new(
+            None => Poll::Ready(Err(io::Error::new(
                 io::ErrorKind::NotConnected,
                 "already closed",
-            )),
-            Some((_, ref mut writer)) => writer.write(buf),
+            ))),
+            Some((_, ref mut writer)) => {
+                let pinned_writer = pin!(writer);
+                pinned_writer.poll_write(cx, buf)
+            }
         }
     }
 
-    fn flush(&mut self) -> io::Result<()> {
+    fn poll_flush(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), io::Error>> {
         match &mut self.task_and_writer {
-            None => Err(io::Error::new(
+            None => Poll::Ready(Err(io::Error::new(
                 io::ErrorKind::NotConnected,
                 "already closed",
-            )),
-            Some((_, ref mut writer)) => writer.flush(),
+            ))),
+            Some((_, ref mut writer)) => {
+                let pinned_writer = pin!(writer);
+                pinned_writer.poll_flush(cx)
+            }
         }
     }
+
+    fn poll_shutdown(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), io::Error>> {
+        // TODO(raitobezarius): this might not be a graceful shutdown of the
+        // channel inside the gRPC connection.
+        Poll::Ready(Ok(()))
+    }
 }
 
 #[cfg(test)]
@@ -291,7 +286,6 @@ mod tests {
 
     use tempfile::TempDir;
     use tokio::net::UnixListener;
-    use tokio::task;
     use tokio::time;
     use tokio_stream::wrappers::UnixListenerStream;
 
@@ -358,32 +352,23 @@ mod tests {
     }
 
     /// This uses the correct scheme for a unix socket, and provides a server on the other side.
-    #[tokio::test]
-    async fn test_valid_unix_path_ping_pong() {
+    /// This is not a tokio::test, because spawn two separate tokio runtimes and
+    // want to have explicit control.
+    #[test]
+    fn test_valid_unix_path_ping_pong() {
         let tmpdir = TempDir::new().unwrap();
         let path = tmpdir.path().join("daemon");
 
-        // let mut join_set = JoinSet::new();
-
-        // prepare a client
-        let client = {
-            let mut url = url::Url::parse("grpc+unix:///path/to/somewhere").expect("must parse");
-            url.set_path(path.to_str().unwrap());
-            GRPCBlobService::from_url(&url).expect("must succeed")
-        };
-
-        let path_copy = path.clone();
+        let path_clone = path.clone();
 
         // Spin up a server, in a thread far away, which spawns its own tokio runtime,
         // and blocks on the task.
         thread::spawn(move || {
             // Create the runtime
             let rt = tokio::runtime::Runtime::new().unwrap();
-            // Get a handle from this runtime
-            let handle = rt.handle();
 
-            let task = handle.spawn(async {
-                let uds = UnixListener::bind(path_copy).unwrap();
+            let task = rt.spawn(async {
+                let uds = UnixListener::bind(path_clone).unwrap();
                 let uds_stream = UnixListenerStream::new(uds);
 
                 // spin up a new server
@@ -397,33 +382,46 @@ mod tests {
                 router.serve_with_incoming(uds_stream).await
             });
 
-            handle.block_on(task)
+            rt.block_on(task).unwrap().unwrap();
         });
 
-        // wait for the socket to be created
-        {
-            let mut socket_created = false;
-            for _try in 1..20 {
-                if path.exists() {
-                    socket_created = true;
-                    break;
+        // Now create another tokio runtime which we'll use in the main test code.
+        let rt = tokio::runtime::Runtime::new().unwrap();
+
+        let task = rt.spawn(async move {
+            // wait for the socket to be created
+            {
+                let mut socket_created = false;
+                // TODO: exponential backoff urgently
+                for _try in 1..20 {
+                    if path.exists() {
+                        socket_created = true;
+                        break;
+                    }
+                    tokio::time::sleep(time::Duration::from_millis(20)).await;
                 }
-                tokio::time::sleep(time::Duration::from_millis(20)).await;
+
+                assert!(
+                    socket_created,
+                    "expected socket path to eventually get created, but never happened"
+                );
             }
 
-            assert!(
-                socket_created,
-                "expected socket path to eventually get created, but never happened"
-            );
-        }
+            // prepare a client
+            let client = {
+                let mut url =
+                    url::Url::parse("grpc+unix:///path/to/somewhere").expect("must parse");
+                url.set_path(path.to_str().unwrap());
+                GRPCBlobService::from_url(&url).expect("must succeed")
+            };
 
-        let has = task::spawn_blocking(move || {
-            client
+            let has = client
                 .has(&fixtures::BLOB_A_DIGEST)
-                .expect("must not be err")
-        })
-        .await
-        .expect("must not be err");
-        assert!(!has);
+                .await
+                .expect("must not be err");
+
+            assert!(!has);
+        });
+        rt.block_on(task).unwrap()
     }
 }