about summary refs log tree commit diff
path: root/tvix/store/src/pathinfoservice/nix_http.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/store/src/pathinfoservice/nix_http.rs')
-rw-r--r--tvix/store/src/pathinfoservice/nix_http.rs180
1 files changed, 72 insertions, 108 deletions
diff --git a/tvix/store/src/pathinfoservice/nix_http.rs b/tvix/store/src/pathinfoservice/nix_http.rs
index bdb0e2c3cba7..234340592e89 100644
--- a/tvix/store/src/pathinfoservice/nix_http.rs
+++ b/tvix/store/src/pathinfoservice/nix_http.rs
@@ -1,5 +1,3 @@
-use std::io::{self, BufRead, Read, Write};
-
 use data_encoding::BASE64;
 use futures::{stream::BoxStream, TryStreamExt};
 use nix_compat::{
@@ -8,7 +6,10 @@ use nix_compat::{
     nixhash::NixHash,
 };
 use reqwest::StatusCode;
-use sha2::{digest::FixedOutput, Digest, Sha256};
+use sha2::Digest;
+use std::io::{self, Write};
+use tokio::io::{AsyncRead, BufReader};
+use tokio_util::io::InspectReader;
 use tonic::async_trait;
 use tracing::{debug, instrument, warn};
 use tvix_castore::{
@@ -171,85 +172,83 @@ where
             )));
         }
 
-        // get an AsyncRead of the response body.
-        let async_r = tokio_util::io::StreamReader::new(resp.bytes_stream().map_err(|e| {
+        // get a reader of the response body.
+        let r = tokio_util::io::StreamReader::new(resp.bytes_stream().map_err(|e| {
             let e = e.without_url();
             warn!(e=%e, "failed to get response body");
             io::Error::new(io::ErrorKind::BrokenPipe, e.to_string())
         }));
-        let sync_r = tokio_util::io::SyncIoBridge::new(async_r);
 
-        // handle decompression, by wrapping the reader.
-        let sync_r: Box<dyn BufRead + Send> = match narinfo.compression {
-            Some("none") => Box::new(sync_r),
-            Some("xz") => Box::new(io::BufReader::new(xz2::read::XzDecoder::new(sync_r))),
-            Some(comp) => {
-                return Err(Error::InvalidRequest(
-                    format!("unsupported compression: {}", comp).to_string(),
-                ))
-            }
-            None => {
-                return Err(Error::InvalidRequest(
-                    "unsupported compression: bzip2".to_string(),
-                ))
+        // handle decompression, depending on the compression field.
+        let r: Box<dyn AsyncRead + Send + Unpin> = match narinfo.compression {
+            Some("none") => Box::new(r) as Box<dyn AsyncRead + Send + Unpin>,
+            Some("bzip2") | None => Box::new(async_compression::tokio::bufread::BzDecoder::new(r))
+                as Box<dyn AsyncRead + Send + Unpin>,
+            Some("gzip") => Box::new(async_compression::tokio::bufread::GzipDecoder::new(r))
+                as Box<dyn AsyncRead + Send + Unpin>,
+            Some("xz") => Box::new(async_compression::tokio::bufread::XzDecoder::new(r))
+                as Box<dyn AsyncRead + Send + Unpin>,
+            Some("zstd") => Box::new(async_compression::tokio::bufread::ZstdDecoder::new(r))
+                as Box<dyn AsyncRead + Send + Unpin>,
+            Some(comp_str) => {
+                return Err(Error::StorageError(format!(
+                    "unsupported compression: {comp_str}"
+                )));
             }
         };
-
-        let res = tokio::task::spawn_blocking({
-            let blob_service = self.blob_service.clone();
-            let directory_service = self.directory_service.clone();
-            move || -> io::Result<_> {
-                // Wrap the reader once more, so we can calculate NarSize and NarHash
-                let mut sync_r = io::BufReader::new(NarReader::from(sync_r));
-                let root_node = crate::nar::read_nar(&mut sync_r, blob_service, directory_service)?;
-
-                let (_, nar_hash, nar_size) = sync_r.into_inner().into_inner();
-
-                Ok((root_node, nar_hash, nar_size))
-            }
-        })
+        let mut nar_hash = sha2::Sha256::new();
+        let mut nar_size = 0;
+
+        // Assemble NarHash and NarSize as we read bytes.
+        let r = InspectReader::new(r, |b| {
+            nar_size += b.len() as u64;
+            nar_hash.write_all(b).unwrap();
+        });
+
+        // HACK: InspectReader doesn't implement AsyncBufRead, but neither do our decompressors.
+        let mut r = BufReader::new(r);
+
+        let root_node = crate::nar::ingest_nar(
+            self.blob_service.clone(),
+            self.directory_service.clone(),
+            &mut r,
+        )
         .await
-        .unwrap();
-
-        match res {
-            Ok((root_node, nar_hash, nar_size)) => {
-                // ensure the ingested narhash and narsize do actually match.
-                if narinfo.nar_size != nar_size {
-                    warn!(
-                        narinfo.nar_size = narinfo.nar_size,
-                        http.nar_size = nar_size,
-                        "NARSize mismatch"
-                    );
-                    Err(io::Error::new(
-                        io::ErrorKind::InvalidData,
-                        "NarSize mismatch".to_string(),
-                    ))?;
-                }
-                if narinfo.nar_hash != nar_hash {
-                    warn!(
-                        narinfo.nar_hash = %NixHash::Sha256(narinfo.nar_hash),
-                        http.nar_hash = %NixHash::Sha256(nar_hash),
-                        "NarHash mismatch"
-                    );
-                    Err(io::Error::new(
-                        io::ErrorKind::InvalidData,
-                        "NarHash mismatch".to_string(),
-                    ))?;
-                }
-
-                Ok(Some(PathInfo {
-                    node: Some(castorepb::Node {
-                        // set the name of the root node to the digest-name of the store path.
-                        node: Some(
-                            root_node.rename(narinfo.store_path.to_string().to_owned().into()),
-                        ),
-                    }),
-                    references: pathinfo.references,
-                    narinfo: pathinfo.narinfo,
-                }))
-            }
-            Err(e) => Err(e.into()),
+        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
+
+        // ensure the ingested narhash and narsize do actually match.
+        if narinfo.nar_size != nar_size {
+            warn!(
+                narinfo.nar_size = narinfo.nar_size,
+                http.nar_size = nar_size,
+                "NARSize mismatch"
+            );
+            Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "NarSize mismatch".to_string(),
+            ))?;
         }
+        let nar_hash: [u8; 32] = nar_hash.finalize().into();
+        if narinfo.nar_hash != nar_hash {
+            warn!(
+                narinfo.nar_hash = %NixHash::Sha256(narinfo.nar_hash),
+                http.nar_hash = %NixHash::Sha256(nar_hash),
+                "NarHash mismatch"
+            );
+            Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "NarHash mismatch".to_string(),
+            ))?;
+        }
+
+        Ok(Some(PathInfo {
+            node: Some(castorepb::Node {
+                // set the name of the root node to the digest-name of the store path.
+                node: Some(root_node.rename(narinfo.store_path.to_string().to_owned().into())),
+            }),
+            references: pathinfo.references,
+            narinfo: pathinfo.narinfo,
+        }))
     }
 
     #[instrument(skip_all, fields(path_info=?_path_info))]
@@ -277,38 +276,3 @@ where
         }))
     }
 }
-
-/// Small helper reader implementing [std::io::Read].
-/// It can be used to wrap another reader, counts the number of bytes read
-/// and the sha256 digest of the contents.
-struct NarReader<R: Read> {
-    r: R,
-
-    sha256: sha2::Sha256,
-    bytes_read: u64,
-}
-
-impl<R: Read> NarReader<R> {
-    pub fn from(inner: R) -> Self {
-        Self {
-            r: inner,
-            sha256: Sha256::new(),
-            bytes_read: 0,
-        }
-    }
-
-    /// Returns the (remaining) inner reader, the sha256 digest and the number of bytes read.
-    pub fn into_inner(self) -> (R, [u8; 32], u64) {
-        (self.r, self.sha256.finalize_fixed().into(), self.bytes_read)
-    }
-}
-
-impl<R: Read> Read for NarReader<R> {
-    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
-        self.r.read(buf).map(|n| {
-            self.bytes_read += n as u64;
-            self.sha256.write_all(&buf[..n]).unwrap();
-            n
-        })
-    }
-}