about summary refs log tree commit diff
path: root/tvix/store/src/directoryservice/grpc.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/store/src/directoryservice/grpc.rs')
-rw-r--r--tvix/store/src/directoryservice/grpc.rs285
1 files changed, 127 insertions, 158 deletions
diff --git a/tvix/store/src/directoryservice/grpc.rs b/tvix/store/src/directoryservice/grpc.rs
index 73d88bb688a3..6257a8e81485 100644
--- a/tvix/store/src/directoryservice/grpc.rs
+++ b/tvix/store/src/directoryservice/grpc.rs
@@ -1,22 +1,24 @@
 use std::collections::HashSet;
+use std::pin::Pin;
 
 use super::{DirectoryPutter, DirectoryService};
 use crate::proto::{self, get_directory_request::ByWhat};
 use crate::{B3Digest, Error};
+use async_stream::try_stream;
+use futures::Stream;
 use tokio::net::UnixStream;
+use tokio::spawn;
 use tokio::sync::mpsc::UnboundedSender;
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tonic::async_trait;
+use tonic::Code;
 use tonic::{transport::Channel, Status};
-use tonic::{Code, Streaming};
 use tracing::{instrument, warn};
 
 /// Connects to a (remote) tvix-store DirectoryService over gRPC.
 #[derive(Clone)]
 pub struct GRPCDirectoryService {
-    /// A handle into the active tokio runtime. Necessary to spawn tasks.
-    tokio_handle: tokio::runtime::Handle,
-
     /// The internal reference to a gRPC client.
     /// Cloning it is cheap, and it internally handles concurrent requests.
     grpc_client: proto::directory_service_client::DirectoryServiceClient<Channel>,
@@ -28,13 +30,11 @@ impl GRPCDirectoryService {
     pub fn from_client(
         grpc_client: proto::directory_service_client::DirectoryServiceClient<Channel>,
     ) -> Self {
-        Self {
-            tokio_handle: tokio::runtime::Handle::current(),
-            grpc_client,
-        }
+        Self { grpc_client }
     }
 }
 
+#[async_trait]
 impl DirectoryService for GRPCDirectoryService {
     /// Constructs a [GRPCDirectoryService] from the passed [url::Url]:
     /// - scheme has to match `grpc+*://`.
@@ -89,11 +89,15 @@ impl DirectoryService for GRPCDirectoryService {
             }
         }
     }
-    fn get(&self, digest: &B3Digest) -> Result<Option<crate::proto::Directory>, crate::Error> {
+
+    async fn get(
+        &self,
+        digest: &B3Digest,
+    ) -> Result<Option<crate::proto::Directory>, crate::Error> {
         // Get a new handle to the gRPC client, and copy the digest.
         let mut grpc_client = self.grpc_client.clone();
         let digest_cpy = digest.clone();
-        let task = self.tokio_handle.spawn(async move {
+        let message = async move {
             let mut s = grpc_client
                 .get(proto::GetDirectoryRequest {
                     recursive: false,
@@ -104,10 +108,10 @@ impl DirectoryService for GRPCDirectoryService {
 
             // Retrieve the first message only, then close the stream (we set recursive to false)
             s.message().await
-        });
+        };
 
         let digest = digest.clone();
-        match self.tokio_handle.block_on(task)? {
+        match message.await {
             Ok(Some(directory)) => {
                 // Validate the retrieved Directory indeed has the
                 // digest we expect it to have, to detect corruptions.
@@ -134,14 +138,12 @@ impl DirectoryService for GRPCDirectoryService {
         }
     }
 
-    fn put(&self, directory: crate::proto::Directory) -> Result<B3Digest, crate::Error> {
+    async fn put(&self, directory: crate::proto::Directory) -> Result<B3Digest, crate::Error> {
         let mut grpc_client = self.grpc_client.clone();
 
-        let task = self
-            .tokio_handle
-            .spawn(async move { grpc_client.put(tokio_stream::iter(vec![directory])).await });
+        let resp = grpc_client.put(tokio_stream::iter(vec![directory])).await;
 
-        match self.tokio_handle.block_on(task)? {
+        match resp {
             Ok(put_directory_resp) => Ok(put_directory_resp
                 .into_inner()
                 .root_digest
@@ -157,32 +159,82 @@ impl DirectoryService for GRPCDirectoryService {
     fn get_recursive(
         &self,
         root_directory_digest: &B3Digest,
-    ) -> Box<dyn Iterator<Item = Result<proto::Directory, Error>> + Send> {
+    ) -> Pin<Box<dyn Stream<Item = Result<proto::Directory, Error>> + Send>> {
         let mut grpc_client = self.grpc_client.clone();
+        let root_directory_digest = root_directory_digest.clone();
 
-        // clone so we can move it
-        let root_directory_digest_cpy = root_directory_digest.clone();
-
-        let task: JoinHandle<Result<Streaming<proto::Directory>, Status>> =
-            self.tokio_handle.spawn(async move {
-                let s = grpc_client
-                    .get(proto::GetDirectoryRequest {
-                        recursive: true,
-                        by_what: Some(ByWhat::Digest(root_directory_digest_cpy.into())),
-                    })
-                    .await?
-                    .into_inner();
-
-                Ok(s)
-            });
+        let stream = try_stream! {
+            let mut stream = grpc_client
+                .get(proto::GetDirectoryRequest {
+                    recursive: true,
+                    by_what: Some(ByWhat::Digest(root_directory_digest.clone().into())),
+                })
+                .await
+                .map_err(|e| crate::Error::StorageError(e.to_string()))?
+                .into_inner();
 
-        let stream = self.tokio_handle.block_on(task).unwrap().unwrap();
+            // The Directory digests we received so far
+            let mut received_directory_digests: HashSet<B3Digest> = HashSet::new();
+            // The Directory digests we're still expecting to get sent.
+            let mut expected_directory_digests: HashSet<B3Digest> = HashSet::from([root_directory_digest]);
+
+            loop {
+                match stream.message().await {
+                    Ok(Some(directory)) => {
+                        // validate the directory itself.
+                        if let Err(e) = directory.validate() {
+                            Err(crate::Error::StorageError(format!(
+                                "directory {} failed validation: {}",
+                                directory.digest(),
+                                e,
+                            )))?;
+                        }
+                        // validate we actually expected that directory, and move it from expected to received.
+                        let directory_digest = directory.digest();
+                        let was_expected = expected_directory_digests.remove(&directory_digest);
+                        if !was_expected {
+                            // FUTUREWORK: dumb clients might send the same stuff twice.
+                            // as a fallback, we might want to tolerate receiving
+                            // it if it's in received_directory_digests (as that
+                            // means it once was in expected_directory_digests)
+                            Err(crate::Error::StorageError(format!(
+                                "received unexpected directory {}",
+                                directory_digest
+                            )))?;
+                        }
+                        received_directory_digests.insert(directory_digest);
+
+                        // register all children in expected_directory_digests.
+                        for child_directory in &directory.directories {
+                            // We ran validate() above, so we know these digests must be correct.
+                            let child_directory_digest =
+                                child_directory.digest.clone().try_into().unwrap();
+
+                            expected_directory_digests
+                                .insert(child_directory_digest);
+                        }
+
+                        yield directory;
+                    },
+                    Ok(None) => {
+                        // If we were still expecting something, that's an error.
+                        if !expected_directory_digests.is_empty() {
+                            Err(crate::Error::StorageError(format!(
+                                "still expected {} directories, but got premature end of stream",
+                                expected_directory_digests.len(),
+                            )))?
+                        } else {
+                            return
+                        }
+                    },
+                    Err(e) => {
+                        Err(crate::Error::StorageError(e.to_string()))?;
+                    },
+                }
+            }
+        };
 
-        Box::new(StreamIterator::new(
-            self.tokio_handle.clone(),
-            root_directory_digest.clone(),
-            stream,
-        ))
+        Box::pin(stream)
     }
 
     #[instrument(skip_all)]
@@ -194,110 +246,21 @@ impl DirectoryService for GRPCDirectoryService {
 
         let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
 
-        let task: JoinHandle<Result<proto::PutDirectoryResponse, Status>> =
-            self.tokio_handle.spawn(async move {
-                let s = grpc_client
-                    .put(UnboundedReceiverStream::new(rx))
-                    .await?
-                    .into_inner();
-
-                Ok(s)
-            });
-
-        Box::new(GRPCPutter::new(self.tokio_handle.clone(), tx, task))
-    }
-}
-
-pub struct StreamIterator {
-    /// A handle into the active tokio runtime. Necessary to run futures to completion.
-    tokio_handle: tokio::runtime::Handle,
-    // A stream of [proto::Directory]
-    stream: Streaming<proto::Directory>,
-    // The Directory digests we received so far
-    received_directory_digests: HashSet<B3Digest>,
-    // The Directory digests we're still expecting to get sent.
-    expected_directory_digests: HashSet<B3Digest>,
-}
-
-impl StreamIterator {
-    pub fn new(
-        tokio_handle: tokio::runtime::Handle,
-        root_digest: B3Digest,
-        stream: Streaming<proto::Directory>,
-    ) -> Self {
-        Self {
-            tokio_handle,
-            stream,
-            received_directory_digests: HashSet::new(),
-            expected_directory_digests: HashSet::from([root_digest]),
-        }
-    }
-}
-
-impl Iterator for StreamIterator {
-    type Item = Result<proto::Directory, crate::Error>;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        match self.tokio_handle.block_on(self.stream.message()) {
-            Ok(ok) => match ok {
-                Some(directory) => {
-                    // validate the directory itself.
-                    if let Err(e) = directory.validate() {
-                        return Some(Err(crate::Error::StorageError(format!(
-                            "directory {} failed validation: {}",
-                            directory.digest(),
-                            e,
-                        ))));
-                    }
-                    // validate we actually expected that directory, and move it from expected to received.
-                    let directory_digest = directory.digest();
-                    let was_expected = self.expected_directory_digests.remove(&directory_digest);
-                    if !was_expected {
-                        // FUTUREWORK: dumb clients might send the same stuff twice.
-                        // as a fallback, we might want to tolerate receiving
-                        // it if it's in received_directory_digests (as that
-                        // means it once was in expected_directory_digests)
-                        return Some(Err(crate::Error::StorageError(format!(
-                            "received unexpected directory {}",
-                            directory_digest
-                        ))));
-                    }
-                    self.received_directory_digests.insert(directory_digest);
-
-                    // register all children in expected_directory_digests.
-                    for child_directory in &directory.directories {
-                        // We ran validate() above, so we know these digests must be correct.
-                        let child_directory_digest =
-                            child_directory.digest.clone().try_into().unwrap();
+        let task: JoinHandle<Result<proto::PutDirectoryResponse, Status>> = spawn(async move {
+            let s = grpc_client
+                .put(UnboundedReceiverStream::new(rx))
+                .await?
+                .into_inner();
 
-                        self.expected_directory_digests
-                            .insert(child_directory_digest);
-                    }
+            Ok(s)
+        });
 
-                    Some(Ok(directory))
-                }
-                None => {
-                    // If we were still expecting something, that's an error.
-                    if !self.expected_directory_digests.is_empty() {
-                        Some(Err(crate::Error::StorageError(format!(
-                            "still expected {} directories, but got premature end of stream",
-                            self.expected_directory_digests.len(),
-                        ))))
-                    } else {
-                        None
-                    }
-                }
-            },
-            Err(e) => Some(Err(crate::Error::StorageError(e.to_string()))),
-        }
+        Box::new(GRPCPutter::new(tx, task))
     }
 }
 
 /// Allows uploading multiple Directory messages in the same gRPC stream.
 pub struct GRPCPutter {
-    /// A handle into the active tokio runtime. Necessary to spawn tasks.
-    tokio_handle: tokio::runtime::Handle,
-
     /// Data about the current request - a handle to the task, and the tx part
     /// of the channel.
     /// The tx part of the pipe is used to send [proto::Directory] to the ongoing request.
@@ -311,19 +274,18 @@ pub struct GRPCPutter {
 
 impl GRPCPutter {
     pub fn new(
-        tokio_handle: tokio::runtime::Handle,
         directory_sender: UnboundedSender<proto::Directory>,
         task: JoinHandle<Result<proto::PutDirectoryResponse, Status>>,
     ) -> Self {
         Self {
-            tokio_handle,
             rq: Some((task, directory_sender)),
         }
     }
 }
 
+#[async_trait]
 impl DirectoryPutter for GRPCPutter {
-    fn put(&mut self, directory: proto::Directory) -> Result<(), crate::Error> {
+    async fn put(&mut self, directory: proto::Directory) -> Result<(), crate::Error> {
         match self.rq {
             // If we're not already closed, send the directory to directory_sender.
             Some((_, ref directory_sender)) => {
@@ -331,7 +293,7 @@ impl DirectoryPutter for GRPCPutter {
                     // If the channel has been prematurely closed, invoke close (so we can peek at the error code)
                     // That error code is much more helpful, because it
                     // contains the error message from the server.
-                    self.close()?;
+                    self.close().await?;
                 }
                 Ok(())
             }
@@ -343,7 +305,7 @@ impl DirectoryPutter for GRPCPutter {
     }
 
     /// Closes the stream for sending, and returns the value
-    fn close(&mut self) -> Result<B3Digest, crate::Error> {
+    async fn close(&mut self) -> Result<B3Digest, crate::Error> {
         // get self.rq, and replace it with None.
         // This ensures we can only close it once.
         match std::mem::take(&mut self.rq) {
@@ -352,9 +314,8 @@ impl DirectoryPutter for GRPCPutter {
                 // close directory_sender, so blocking on task will finish.
                 drop(directory_sender);
 
-                let root_digest = self
-                    .tokio_handle
-                    .block_on(task)?
+                let root_digest = task
+                    .await?
                     .map_err(|e| Error::StorageError(e.to_string()))?
                     .root_digest;
 
@@ -379,6 +340,7 @@ mod tests {
     use core::time;
     use std::thread;
 
+    use futures::StreamExt;
     use tempfile::TempDir;
     use tokio::net::{UnixListener, UnixStream};
     use tokio_stream::wrappers::UnixListenerStream;
@@ -446,7 +408,7 @@ mod tests {
             );
         }
 
-        let task = tester_runtime.spawn_blocking(move || {
+        tester_runtime.block_on(async move {
             // Create a channel, connecting to the uds at socket_path.
             // The URI is unused.
             let channel = Endpoint::try_from("http://[::]:50051")
@@ -465,6 +427,7 @@ mod tests {
                 None,
                 directory_service
                     .get(&DIRECTORY_A.digest())
+                    .await
                     .expect("must not fail")
             );
 
@@ -473,6 +436,7 @@ mod tests {
                 DIRECTORY_A.digest(),
                 directory_service
                     .put(DIRECTORY_A.clone())
+                    .await
                     .expect("must succeed")
             );
 
@@ -481,6 +445,7 @@ mod tests {
                 DIRECTORY_A.clone(),
                 directory_service
                     .get(&DIRECTORY_A.digest())
+                    .await
                     .expect("must succeed")
                     .expect("must be some")
             );
@@ -488,21 +453,22 @@ mod tests {
             // Putting DIRECTORY_B alone should fail, because it refers to DIRECTORY_A.
             directory_service
                 .put(DIRECTORY_B.clone())
+                .await
                 .expect_err("must fail");
 
             // Putting DIRECTORY_B in a put_multiple will succeed, but the close
             // will always fail.
             {
                 let mut handle = directory_service.put_multiple_start();
-                handle.put(DIRECTORY_B.clone()).expect("must succeed");
-                handle.close().expect_err("must fail");
+                handle.put(DIRECTORY_B.clone()).await.expect("must succeed");
+                handle.close().await.expect_err("must fail");
             }
 
             // Uploading A and then B should succeed, and closing should return the digest of B.
             let mut handle = directory_service.put_multiple_start();
-            handle.put(DIRECTORY_A.clone()).expect("must succeed");
-            handle.put(DIRECTORY_B.clone()).expect("must succeed");
-            let digest = handle.close().expect("must succeed");
+            handle.put(DIRECTORY_A.clone()).await.expect("must succeed");
+            handle.put(DIRECTORY_B.clone()).await.expect("must succeed");
+            let digest = handle.close().await.expect("must succeed");
             assert_eq!(DIRECTORY_B.digest(), digest);
 
             // Now try to retrieve the closure of DIRECTORY_B, which should return B and then A.
@@ -511,6 +477,7 @@ mod tests {
                 DIRECTORY_B.clone(),
                 directories_it
                     .next()
+                    .await
                     .expect("must be some")
                     .expect("must succeed")
             );
@@ -518,6 +485,7 @@ mod tests {
                 DIRECTORY_A.clone(),
                 directories_it
                     .next()
+                    .await
                     .expect("must be some")
                     .expect("must succeed")
             );
@@ -529,15 +497,15 @@ mod tests {
             {
                 let mut handle = directory_service.put_multiple_start();
                 // sending out B will always be fine
-                handle.put(DIRECTORY_B.clone()).expect("must succeed");
+                handle.put(DIRECTORY_B.clone()).await.expect("must succeed");
 
                 // whether we will be able to put A as well depends on whether we
                 // already received the error about B.
-                if handle.put(DIRECTORY_A.clone()).is_ok() {
+                if handle.put(DIRECTORY_A.clone()).await.is_ok() {
                     // If we didn't, and this was Ok(_), …
                     // a subsequent close MUST fail (because it waits for the
                     // server)
-                    handle.close().expect_err("must fail");
+                    handle.close().await.expect_err("must fail");
                 }
             }
 
@@ -547,7 +515,7 @@ mod tests {
             // and then assert that uploading anything else via the handle will fail.
             {
                 let mut handle = directory_service.put_multiple_start();
-                handle.put(DIRECTORY_B.clone()).expect("must succeed");
+                handle.put(DIRECTORY_B.clone()).await.expect("must succeed");
 
                 let mut is_closed = false;
                 for _try in 1..1000 {
@@ -555,7 +523,7 @@ mod tests {
                         is_closed = true;
                         break;
                     }
-                    std::thread::sleep(time::Duration::from_millis(10))
+                    tokio::time::sleep(time::Duration::from_millis(10)).await;
                 }
 
                 assert!(
@@ -563,12 +531,13 @@ mod tests {
                     "expected channel to eventually close, but never happened"
                 );
 
-                handle.put(DIRECTORY_A.clone()).expect_err("must fail");
+                handle
+                    .put(DIRECTORY_A.clone())
+                    .await
+                    .expect_err("must fail");
             }
         });
 
-        tester_runtime.block_on(task)?;
-
         Ok(())
     }
 }