about summary refs log tree commit diff
path: root/tvix/nix-compat/src
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src')
-rw-r--r--tvix/nix-compat/src/wire/bytes.rs130
-rw-r--r--tvix/nix-compat/src/wire/mod.rs3
2 files changed, 133 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes.rs b/tvix/nix-compat/src/wire/bytes.rs
new file mode 100644
index 000000000000..c720f912eee2
--- /dev/null
+++ b/tvix/nix-compat/src/wire/bytes.rs
@@ -0,0 +1,130 @@
+use std::ops::RangeBounds;
+
+use tokio::io::AsyncReadExt;
+
+use super::primitive;
+
+#[allow(dead_code)]
+/// Read a limited number of bytes from the AsyncRead.
+/// Rejects reading more than `allowed_size` bytes of payload.
+/// Internally takes care of dealing with the padding, so the returned Vec<u8>
+/// only contains the payload.
+/// This always buffers the entire contents into memory, we'll add a streaming
+/// version later.
+pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>>
+where
+    R: AsyncReadExt + Unpin,
+    S: RangeBounds<u64>,
+{
+    // read the length field
+    let len = primitive::read_u64(r).await?;
+
+    if !allowed_size.contains(&len) {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            "signalled package size not in allowed range",
+        ));
+    }
+
+    // calculate the total length, including padding.
+    // byte packets are padded to 8 byte blocks each.
+    let padded_len = if len % 8 == 0 {
+        len
+    } else {
+        len + (8 - len % 8)
+    };
+
+    let mut limited_reader = r.take(padded_len);
+
+    let mut buf = Vec::new();
+
+    let s = limited_reader.read_to_end(&mut buf).await?;
+
+    // make sure we got exactly the number of bytes, and not less.
+    if s as u64 != padded_len {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            "got less bytes than expected",
+        ));
+    }
+
+    let (_content, padding) = buf.split_at(len as usize);
+
+    // ensure the padding is all zeroes.
+    if !padding.iter().all(|e| *e == b'\0') {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            "padding is not all zeroes",
+        ));
+    }
+
+    // return the data without the padding
+    buf.truncate(len as usize);
+    Ok(buf)
+}
+
+#[allow(dead_code)]
+/// Read an unlimited number of bytes from the AsyncRead.
+/// Note this can exhaust memory.
+/// Internally uses [read_bytes], which takes care of dealing with the padding,
+/// so the returned Vec<u8> only contains the payload.
+pub async fn read_bytes_unchecked<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Vec<u8>> {
+    read_bytes(r, 0u64..).await
+}
+
+#[cfg(test)]
+mod tests {
+    use tokio_test::io::Builder;
+
+    use super::*;
+    use hex_literal::hex;
+
+    #[tokio::test]
+    async fn test_read_8_bytes_unchecked() {
+        let mut mock = Builder::new()
+            .read(&8u64.to_le_bytes())
+            .read(&12345678u64.to_le_bytes())
+            .build();
+
+        assert_eq!(
+            &12345678u64.to_le_bytes(),
+            read_bytes_unchecked(&mut mock).await.unwrap().as_slice()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_read_9_bytes_unchecked() {
+        let mut mock = Builder::new()
+            .read(&9u64.to_le_bytes())
+            .read(&hex!("01020304050607080900000000000000"))
+            .build();
+
+        assert_eq!(
+            hex!("010203040506070809"),
+            read_bytes_unchecked(&mut mock).await.unwrap().as_slice()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_read_0_bytes_unchecked() {
+        // A empty byte packet is essentially just the 0 length field.
+        // No data is read, and there's zero padding.
+        let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();
+
+        assert_eq!(
+            hex!(""),
+            read_bytes_unchecked(&mut mock).await.unwrap().as_slice()
+        );
+    }
+
+    #[tokio::test]
+    /// Ensure we don't read any further than the size field if the length
+    /// doesn't match the range we want to accept.
+    async fn test_reject_too_large() {
+        let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
+
+        read_bytes(&mut mock, 10..10)
+            .await
+            .expect_err("expect this to fail");
+    }
+}
diff --git a/tvix/nix-compat/src/wire/mod.rs b/tvix/nix-compat/src/wire/mod.rs
index e0b184c78aec..9444ebbcfe59 100644
--- a/tvix/nix-compat/src/wire/mod.rs
+++ b/tvix/nix-compat/src/wire/mod.rs
@@ -2,4 +2,7 @@
 //! nix-daemon protocol as well as in the NAR format.
 
 #[cfg(feature = "async")]
+pub mod bytes;
+
+#[cfg(feature = "async")]
 pub mod primitive;