about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/reader
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader')
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/mod.rs75
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/trailer.rs12
2 files changed, 72 insertions, 15 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
index cd45f78a0c84..4a8cfd1f6599 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
@@ -1,6 +1,7 @@
 use std::{
     future::Future,
     io,
+    num::NonZeroU64,
     ops::RangeBounds,
     pin::Pin,
     task::{self, ready, Poll},
@@ -33,14 +34,26 @@ pub struct BytesReader<R, T: Tag = Pad> {
     state: State<R, T>,
 }
 
+/// Split the `user_len` into `body_len` and `tail_len`, which are respectively
+/// the non-terminal 8-byte blocks, and the ≤8 bytes of user data contained in
+/// the trailer block.
+#[inline(always)]
+fn split_user_len(user_len: NonZeroU64) -> (u64, u8) {
+    let n = user_len.get() - 1;
+    let body_len = n & !7;
+    let tail_len = (n & 7) as u8 + 1;
+    (body_len, tail_len)
+}
+
 #[derive(Debug)]
 enum State<R, T: Tag> {
     /// Full 8-byte blocks are being read and released to the caller.
+    /// NOTE: The final 8-byte block is *always* part of the trailer.
     Body {
         reader: Option<R>,
         consumed: u64,
         /// The total length of all user data contained in both the body and trailer.
-        user_len: u64,
+        user_len: NonZeroU64,
     },
     /// The trailer is in the process of being read.
     ReadTrailer(ReadTrailer<R, T>),
@@ -76,10 +89,16 @@ where
         }
 
         Ok(Self {
-            state: State::Body {
-                reader: Some(reader),
-                consumed: 0,
-                user_len: size,
+            state: match NonZeroU64::new(size) {
+                Some(size) => State::Body {
+                    reader: Some(reader),
+                    consumed: 0,
+                    user_len: size,
+                },
+                None => State::ReleaseTrailer {
+                    consumed: 0,
+                    data: read_trailer::<R, T>(reader, 0).await?,
+                },
             },
         })
     }
@@ -96,7 +115,7 @@ where
         match self.state {
             State::Body {
                 consumed, user_len, ..
-            } => user_len - consumed,
+            } => user_len.get() - consumed,
             State::ReadTrailer(ref fut) => fut.len() as u64,
             State::ReleaseTrailer { consumed, ref data } => data.len() as u64 - consumed as u64,
         }
@@ -119,13 +138,12 @@ impl<R: AsyncRead + Unpin, T: Tag> AsyncRead for BytesReader<R, T> {
                     consumed,
                     user_len,
                 } => {
-                    let body_len = *user_len & !7;
+                    let (body_len, tail_len) = split_user_len(*user_len);
                     let remaining = body_len - *consumed;
 
                     let reader = if remaining == 0 {
                         let reader = reader.take().unwrap();
-                        let user_len = (*user_len & 7) as u8;
-                        *this = State::ReadTrailer(read_trailer(reader, user_len));
+                        *this = State::ReadTrailer(read_trailer(reader, tail_len));
                         continue;
                     } else {
                         reader.as_mut().unwrap()
@@ -277,6 +295,45 @@ mod tests {
         );
     }
 
+    /// Read the trailer immediately if there is no payload.
+    #[tokio::test]
+    async fn read_trailer_immediately() {
+        use crate::nar::wire::PadPar;
+
+        let mut mock = Builder::new()
+            .read(&[0; 8])
+            .read(&PadPar::PATTERN[8..])
+            .build();
+
+        BytesReader::<_, PadPar>::new_internal(&mut mock, ..)
+            .await
+            .unwrap();
+
+        // The mock reader will panic if dropped without reading all data.
+    }
+
+    /// Read the trailer even if we only read the exact payload size.
+    #[tokio::test]
+    async fn read_exact_trailer() {
+        use crate::nar::wire::PadPar;
+
+        let mut mock = Builder::new()
+            .read(&16u64.to_le_bytes())
+            .read(&[0x55; 16])
+            .read(&PadPar::PATTERN[8..])
+            .build();
+
+        let mut reader = BytesReader::<_, PadPar>::new_internal(&mut mock, ..)
+            .await
+            .unwrap();
+
+        let mut buf = [0; 16];
+        reader.read_exact(&mut buf).await.unwrap();
+        assert_eq!(buf, [0x55; 16]);
+
+        // The mock reader will panic if dropped without reading all data.
+    }
+
     /// Fail if the padding is not all zeroes
     #[tokio::test]
     async fn read_fail_if_nonzero_padding() {
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
index 82aa2a228095..3a5bb75e7103 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
@@ -9,11 +9,11 @@ use std::{
 
 use tokio::io::{self, AsyncRead, ReadBuf};
 
-/// Trailer represents up to 7 bytes of data read as part of the trailer block(s)
+/// Trailer represents up to 8 bytes of data read as part of the trailer block(s)
 #[derive(Debug)]
 pub(crate) struct Trailer {
     data_len: u8,
-    buf: [u8; 7],
+    buf: [u8; 8],
 }
 
 impl Deref for Trailer {
@@ -28,7 +28,7 @@ impl Deref for Trailer {
 pub(crate) trait Tag {
     /// The expected suffix
     ///
-    /// The first 7 bytes may be ignored, and it must be an 8-byte aligned size.
+    /// The first 8 bytes may be ignored, and it must be an 8-byte aligned size.
     const PATTERN: &'static [u8];
 
     /// Suitably sized buffer for reading [Self::PATTERN]
@@ -67,7 +67,7 @@ pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>(
     reader: R,
     data_len: u8,
 ) -> ReadTrailer<R, T> {
-    assert!(data_len < 8, "payload in trailer must be less than 8 bytes");
+    assert!(data_len <= 8, "payload in trailer must be <= 8 bytes");
 
     let buf = T::make_buf();
     assert_eq!(buf.as_ref().len(), T::PATTERN.len());
@@ -108,8 +108,8 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
             }
 
             if this.filled as usize == T::PATTERN.len() {
-                let mut buf = [0; 7];
-                buf.copy_from_slice(&this.buf.as_ref()[..7]);
+                let mut buf = [0; 8];
+                buf.copy_from_slice(&this.buf.as_ref()[..8]);
 
                 return Ok(Trailer {
                     data_len: this.data_len,