diff options
Diffstat (limited to 'tvix/store/src/nar')
-rw-r--r-- | tvix/store/src/nar/renderer.rs | 134 |
1 files changed, 85 insertions, 49 deletions
diff --git a/tvix/store/src/nar/renderer.rs b/tvix/store/src/nar/renderer.rs index 55dce911ee1a..6e98d2902df6 100644 --- a/tvix/store/src/nar/renderer.rs +++ b/tvix/store/src/nar/renderer.rs @@ -1,9 +1,15 @@ use super::RenderError; +use async_recursion::async_recursion; use count_write::CountWrite; -use nix_compat::nar; +use nix_compat::nar::writer::r#async as nar_writer; use sha2::{Digest, Sha256}; -use std::{io, sync::Arc}; -use tokio::{io::BufReader, task::spawn_blocking}; +use std::{ + pin::Pin, + sync::Arc, + task::{self, Poll}, +}; +use tokio::io::{self, AsyncWrite, BufReader}; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use tracing::warn; use tvix_castore::{ blobservice::BlobService, @@ -19,57 +25,79 @@ pub async fn calculate_size_and_sha256( blob_service: Arc<dyn BlobService>, directory_service: Arc<dyn DirectoryService>, ) -> Result<(u64, [u8; 32]), RenderError> { - let h = Sha256::new(); - let cw = CountWrite::from(h); + let mut h = Sha256::new(); + let mut cw = CountWrite::from(&mut h); + + write_nar( + // The hasher doesn't speak async. It doesn't + // actually do any I/O, so it's fine to wrap. + AsyncIoBridge(&mut cw), + root_node, + blob_service, + directory_service, + ) + .await?; + + Ok((cw.count(), h.finalize().into())) +} + +/// The inverse of [tokio_util::io::SyncIoBridge]. +/// Don't use this with anything that actually does blocking I/O. +struct AsyncIoBridge<T>(T); + +impl<W: std::io::Write + Unpin> AsyncWrite for AsyncIoBridge<W> { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + Poll::Ready(self.get_mut().0.write(buf)) + } - let cw = write_nar(cw, root_node, blob_service, directory_service).await?; + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(self.get_mut().0.flush()) + } - Ok((cw.count(), cw.into_inner().finalize().into())) + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } } /// Accepts a [castorepb::node::Node] pointing to the root of a (store) path, /// and uses the passed blob_service and directory_service to perform the /// necessary lookups as it traverses the structure. -/// The contents in NAR serialization are writen to the passed [std::io::Write]. -/// -/// The writer is passed back in the return value. This is done because async Rust -/// lacks scoped blocking tasks, so we need to transfer ownership of the writer -/// internally. -/// -/// # Panics -/// This will panic if called outside the context of a Tokio runtime. -pub async fn write_nar<W: std::io::Write + Send + 'static>( - mut w: W, +/// The contents in NAR serialization are writen to the passed [AsyncWrite]. +pub async fn write_nar<W: AsyncWrite + Unpin + Send>( + w: W, proto_root_node: &castorepb::node::Node, blob_service: Arc<dyn BlobService>, directory_service: Arc<dyn DirectoryService>, -) -> Result<W, RenderError> { - let tokio_handle = tokio::runtime::Handle::current(); - let proto_root_node = proto_root_node.clone(); - - spawn_blocking(move || { - // Initialize NAR writer - let nar_root_node = nar::writer::open(&mut w).map_err(RenderError::NARWriterError)?; - - walk_node( - tokio_handle, - nar_root_node, - &proto_root_node, - blob_service, - directory_service, - )?; - - Ok(w) - }) - .await - .unwrap() +) -> Result<(), RenderError> { + // Initialize NAR writer + let mut w = w.compat_write(); + let nar_root_node = nar_writer::open(&mut w) + .await + .map_err(RenderError::NARWriterError)?; + + walk_node( + nar_root_node, + &proto_root_node, + blob_service, + directory_service, + ) + .await?; + + Ok(()) } /// Process an intermediate node in the structure. /// This consumes the node. -fn walk_node( - tokio_handle: tokio::runtime::Handle, - nar_node: nar::writer::Node, +#[async_recursion] +async fn walk_node( + nar_node: nar_writer::Node<'async_recursion, '_>, proto_node: &castorepb::node::Node, blob_service: Arc<dyn BlobService>, directory_service: Arc<dyn DirectoryService>, @@ -78,6 +106,7 @@ fn walk_node( castorepb::node::Node::Symlink(proto_symlink_node) => { nar_node .symlink(&proto_symlink_node.target) + .await .map_err(RenderError::NARWriterError)?; } castorepb::node::Node::File(proto_file_node) => { @@ -92,8 +121,9 @@ fn walk_node( )) })?; - let blob_reader = match tokio_handle - .block_on(async { blob_service.open_read(&digest).await }) + let blob_reader = match blob_service + .open_read(&digest) + .await .map_err(RenderError::StoreError)? { Some(blob_reader) => Ok(BufReader::new(blob_reader)), @@ -107,8 +137,9 @@ fn walk_node( .file( proto_file_node.executable, proto_file_node.size.into(), - &mut tokio_util::io::SyncIoBridge::new(blob_reader), + &mut blob_reader.compat(), ) + .await .map_err(RenderError::NARWriterError)?; } castorepb::node::Node::Directory(proto_directory_node) => { @@ -123,8 +154,9 @@ fn walk_node( })?; // look it up with the directory service - match tokio_handle - .block_on(async { directory_service.get(&digest).await }) + match directory_service + .get(&digest) + .await .map_err(RenderError::StoreError)? { // if it's None, that's an error! @@ -136,27 +168,31 @@ fn walk_node( } Some(proto_directory) => { // start a directory node - let mut nar_node_directory = - nar_node.directory().map_err(RenderError::NARWriterError)?; + let mut nar_node_directory = nar_node + .directory() + .await + .map_err(RenderError::NARWriterError)?; // for each node in the directory, create a new entry with its name, // and then invoke walk_node on that entry. for proto_node in proto_directory.nodes() { let child_node = nar_node_directory .entry(proto_node.get_name()) + .await .map_err(RenderError::NARWriterError)?; walk_node( - tokio_handle.clone(), child_node, &proto_node, blob_service.clone(), directory_service.clone(), - )?; + ) + .await?; } // close the directory nar_node_directory .close() + .await .map_err(RenderError::NARWriterError)?; } } |