about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/mod.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/mod.rs229
1 files changed, 229 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs
new file mode 100644
index 0000000000..740a7ebfd0
--- /dev/null
+++ b/tvix/nix-compat/src/wire/bytes/mod.rs
@@ -0,0 +1,229 @@
+use std::{
+    io::{Error, ErrorKind},
+    ops::RangeInclusive,
+};
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+
+mod reader;
+pub use reader::BytesReader;
+mod writer;
+pub use writer::BytesWriter;
+
+/// 8 null bytes, used to write out padding.
+const EMPTY_BYTES: &[u8; 8] = &[0u8; 8];
+
+/// The length of the size field, in bytes is always 8.
+const LEN_SIZE: usize = 8;
+
+#[allow(dead_code)]
+/// Read a "bytes wire packet" from the AsyncRead.
+/// Rejects reading more than `allowed_size` bytes of payload.
+///
+/// The packet is made up of three parts:
+/// - a length header, u64, LE-encoded
+/// - the payload itself
+/// - null bytes to the next 8 byte boundary
+///
+/// Ensures the payload size fits into the `allowed_size` passed,
+/// and that the padding is actual null bytes.
+///
+/// On success, the returned `Vec<u8>` only contains the payload itself.
+/// On failure (for example if a too large byte packet was sent), the reader
+/// becomes unusable.
+///
+/// This buffers the entire payload into memory,
+/// a streaming version is available at [crate::wire::bytes::BytesReader].
+pub async fn read_bytes<R>(
+    r: &mut R,
+    allowed_size: RangeInclusive<usize>,
+) -> std::io::Result<Vec<u8>>
+where
+    R: AsyncReadExt + Unpin,
+{
+    // read the length field
+    let len = r.read_u64_le().await?;
+    let len: usize = len
+        .try_into()
+        .ok()
+        .filter(|len| allowed_size.contains(len))
+        .ok_or_else(|| {
+            std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                "signalled package size not in allowed range",
+            )
+        })?;
+
+    // calculate the total length, including padding.
+    // byte packets are padded to 8 byte blocks each.
+    let padded_len = padding_len(len as u64) as u64 + (len as u64);
+    let mut limited_reader = r.take(padded_len);
+
+    let mut buf = Vec::new();
+
+    let s = limited_reader.read_to_end(&mut buf).await?;
+
+    // make sure we got exactly the number of bytes, and not less.
+    if s as u64 != padded_len {
+        return Err(std::io::ErrorKind::UnexpectedEof.into());
+    }
+
+    let (_content, padding) = buf.split_at(len);
+
+    // ensure the padding is all zeroes.
+    if !padding.iter().all(|e| *e == b'\0') {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            "padding is not all zeroes",
+        ));
+    }
+
+    // return the data without the padding
+    buf.truncate(len);
+    Ok(buf)
+}
+
+/// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
+/// Internally uses [read_bytes].
+/// Rejects reading more than `allowed_size` bytes of payload.
+pub async fn read_string<R>(
+    r: &mut R,
+    allowed_size: RangeInclusive<usize>,
+) -> std::io::Result<String>
+where
+    R: AsyncReadExt + Unpin,
+{
+    let bytes = read_bytes(r, allowed_size).await?;
+    String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e))
+}
+
+/// Writes a "bytes wire packet" to a (hopefully buffered) [AsyncWriteExt].
+///
+/// Accepts anything implementing AsRef<[u8]> as payload.
+///
+/// See [read_bytes] for a description of the format.
+///
+/// Note: if performance matters to you, make sure your
+/// [AsyncWriteExt] handle is buffered. This function is quite
+/// write-intesive.
+pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
+    w: &mut W,
+    b: B,
+) -> std::io::Result<()> {
+    // write the size packet.
+    w.write_u64_le(b.as_ref().len() as u64).await?;
+
+    // write the payload
+    w.write_all(b.as_ref()).await?;
+
+    // write padding if needed
+    let padding_len = padding_len(b.as_ref().len() as u64) as usize;
+    if padding_len != 0 {
+        w.write_all(&EMPTY_BYTES[..padding_len]).await?;
+    }
+    Ok(())
+}
+
+/// Computes the number of bytes we should add to len (a length in
+/// bytes) to be aligned on 64 bits (8 bytes).
+fn padding_len(len: u64) -> u8 {
+    let aligned = len.wrapping_add(7) & !7;
+    aligned.wrapping_sub(len) as u8
+}
+
+#[cfg(test)]
+mod tests {
+    use tokio_test::{assert_ok, io::Builder};
+
+    use super::*;
+    use hex_literal::hex;
+
+    /// The maximum length of bytes packets we're willing to accept in the test
+    /// cases.
+    const MAX_LEN: usize = 1024;
+
+    #[tokio::test]
+    async fn test_read_8_bytes() {
+        let mut mock = Builder::new()
+            .read(&8u64.to_le_bytes())
+            .read(&12345678u64.to_le_bytes())
+            .build();
+
+        assert_eq!(
+            &12345678u64.to_le_bytes(),
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_read_9_bytes() {
+        let mut mock = Builder::new()
+            .read(&9u64.to_le_bytes())
+            .read(&hex!("01020304050607080900000000000000"))
+            .build();
+
+        assert_eq!(
+            hex!("010203040506070809"),
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_read_0_bytes() {
+        // A empty byte packet is essentially just the 0 length field.
+        // No data is read, and there's zero padding.
+        let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();
+
+        assert_eq!(
+            hex!(""),
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
+        );
+    }
+
+    #[tokio::test]
+    /// Ensure we don't read any further than the size field if the length
+    /// doesn't match the range we want to accept.
+    async fn test_read_reject_too_large() {
+        let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
+
+        read_bytes(&mut mock, 10..=10)
+            .await
+            .expect_err("expect this to fail");
+    }
+
+    #[tokio::test]
+    async fn test_write_bytes_no_padding() {
+        let input = hex!("6478696f34657661");
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&input)
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
+    #[tokio::test]
+    async fn test_write_bytes_with_padding() {
+        let input = hex!("322e332e3137");
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&hex!("322e332e31370000"))
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
+
+    #[tokio::test]
+    async fn test_write_string() {
+        let input = "Hello, World!";
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&hex!("48656c6c6f2c20576f726c6421000000"))
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
+
+    #[test]
+    fn padding_len_u64_max() {
+        assert_eq!(padding_len(u64::MAX), 1);
+    }
+}