about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader.rs41
1 files changed, 41 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader.rs b/tvix/nix-compat/src/wire/bytes/reader.rs
index 4239b40fec2d..18a8c6c686f0 100644
--- a/tvix/nix-compat/src/wire/bytes/reader.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader.rs
@@ -61,6 +61,20 @@ where
             state: BytesPacketPosition::Size(0),
         }
     }
+
+    /// Construct a new BytesReader with a known, and already-read size.
+    pub fn with_size(r: R, size: u64) -> Self {
+        Self {
+            inner: r,
+            allowed_size: size..=size,
+            payload_size: u64::to_le_bytes(size),
+            state: if size != 0 {
+                BytesPacketPosition::Payload(0)
+            } else {
+                BytesPacketPosition::Padding(0)
+            },
+        }
+    }
 }
 /// Returns an error if the passed usize is 0.
 #[inline]
@@ -261,6 +275,33 @@ mod tests {
         assert_eq!(payload, &buf[..]);
     }
 
+    /// Read bytes packets of various length, and ensure read_to_end returns the
+    /// expected payload.
+    #[rstest]
+    #[case::empty(&[])] // empty bytes packet
+    #[case::size_1b(&[0xff])] // 1 bytes payload
+    #[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding)
+    #[case::size_9b( &hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding)
+    #[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet
+    #[tokio::test]
+    async fn read_payload_correct_known(#[case] payload: &[u8]) {
+        let packet = produce_packet_bytes(payload).await;
+
+        let size = u64::from_le_bytes({
+            let mut buf = [0; 8];
+            buf.copy_from_slice(&packet[..8]);
+            buf
+        });
+
+        let mut mock = Builder::new().read(&packet[8..]).build();
+
+        let mut r = BytesReader::with_size(&mut mock, size);
+        let mut buf = Vec::new();
+        r.read_to_end(&mut buf).await.expect("must succeed");
+
+        assert_eq!(payload, &buf[..]);
+    }
+
     /// Fail if the bytes packet is larger than allowed
     #[tokio::test]
     async fn read_bigger_than_allowed_fail() {