about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/mod.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/mod.rs169
1 files changed, 99 insertions, 70 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs
index 9ec8b3fa04..2ed071e379 100644
--- a/tvix/nix-compat/src/wire/bytes/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/mod.rs
@@ -1,23 +1,21 @@
 use std::{
     io::{Error, ErrorKind},
-    ops::RangeBounds,
+    mem::MaybeUninit,
+    ops::RangeInclusive,
 };
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::io::{self, AsyncReadExt, AsyncWriteExt, ReadBuf};
 
-mod reader;
+pub(crate) mod reader;
 pub use reader::BytesReader;
 mod writer;
 pub use writer::BytesWriter;
 
-use super::primitive;
-
 /// 8 null bytes, used to write out padding.
 const EMPTY_BYTES: &[u8; 8] = &[0u8; 8];
 
 /// The length of the size field, in bytes is always 8.
 const LEN_SIZE: usize = 8;
 
-#[allow(dead_code)]
 /// Read a "bytes wire packet" from the AsyncRead.
 /// Rejects reading more than `allowed_size` bytes of payload.
 ///
@@ -35,24 +33,29 @@ const LEN_SIZE: usize = 8;
 ///
 /// This buffers the entire payload into memory,
 /// a streaming version is available at [crate::wire::bytes::BytesReader].
-pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>>
+pub async fn read_bytes<R: ?Sized>(
+    r: &mut R,
+    allowed_size: RangeInclusive<usize>,
+) -> io::Result<Vec<u8>>
 where
     R: AsyncReadExt + Unpin,
-    S: RangeBounds<u64>,
 {
     // read the length field
-    let len = primitive::read_u64(r).await?;
-
-    if !allowed_size.contains(&len) {
-        return Err(std::io::Error::new(
-            std::io::ErrorKind::InvalidData,
-            "signalled package size not in allowed range",
-        ));
-    }
+    let len = r.read_u64_le().await?;
+    let len: usize = len
+        .try_into()
+        .ok()
+        .filter(|len| allowed_size.contains(len))
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "signalled package size not in allowed range",
+            )
+        })?;
 
     // calculate the total length, including padding.
     // byte packets are padded to 8 byte blocks each.
-    let padded_len = padding_len(len) as u64 + (len as u64);
+    let padded_len = padding_len(len as u64) as u64 + (len as u64);
     let mut limited_reader = r.take(padded_len);
 
     let mut buf = Vec::new();
@@ -61,34 +64,87 @@ where
 
     // make sure we got exactly the number of bytes, and not less.
     if s as u64 != padded_len {
-        return Err(std::io::Error::new(
-            std::io::ErrorKind::InvalidData,
-            "got less bytes than expected",
-        ));
+        return Err(io::ErrorKind::UnexpectedEof.into());
     }
 
-    let (_content, padding) = buf.split_at(len as usize);
+    let (_content, padding) = buf.split_at(len);
 
     // ensure the padding is all zeroes.
-    if !padding.iter().all(|e| *e == b'\0') {
-        return Err(std::io::Error::new(
-            std::io::ErrorKind::InvalidData,
+    if padding.iter().any(|&b| b != 0) {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
             "padding is not all zeroes",
         ));
     }
 
     // return the data without the padding
-    buf.truncate(len as usize);
+    buf.truncate(len);
     Ok(buf)
 }
 
+pub(crate) async fn read_bytes_buf<'a, const N: usize, R: ?Sized>(
+    reader: &mut R,
+    buf: &'a mut [MaybeUninit<u8>; N],
+    allowed_size: RangeInclusive<usize>,
+) -> io::Result<&'a [u8]>
+where
+    R: AsyncReadExt + Unpin,
+{
+    assert_eq!(N % 8, 0);
+    assert!(*allowed_size.end() <= N);
+
+    let len = reader.read_u64_le().await?;
+    let len: usize = len
+        .try_into()
+        .ok()
+        .filter(|len| allowed_size.contains(len))
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "signalled package size not in allowed range",
+            )
+        })?;
+
+    let buf_len = (len + 7) & !7;
+    let buf = {
+        let mut read_buf = ReadBuf::uninit(&mut buf[..buf_len]);
+
+        while read_buf.filled().len() < buf_len {
+            reader.read_buf(&mut read_buf).await?;
+        }
+
+        // ReadBuf::filled does not pass the underlying buffer's lifetime through,
+        // so we must make a trip to hell.
+        //
+        // SAFETY: `read_buf` is filled up to `buf_len`, and we verify that it is
+        // still pointing at the same underlying buffer.
+        unsafe {
+            assert_eq!(read_buf.filled().as_ptr(), buf.as_ptr() as *const u8);
+            assume_init_bytes(&buf[..buf_len])
+        }
+    };
+
+    if buf[len..buf_len].iter().any(|&b| b != 0) {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "padding is not all zeroes",
+        ));
+    }
+
+    Ok(&buf[..len])
+}
+
+/// SAFETY: The bytes have to actually be initialized.
+unsafe fn assume_init_bytes(slice: &[MaybeUninit<u8>]) -> &[u8] {
+    &*(slice as *const [MaybeUninit<u8>] as *const [u8])
+}
+
 /// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
 /// Internally uses [read_bytes].
 /// Rejects reading more than `allowed_size` bytes of payload.
-pub async fn read_string<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<String>
+pub async fn read_string<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<String>
 where
     R: AsyncReadExt + Unpin,
-    S: RangeBounds<u64>,
 {
     let bytes = read_bytes(r, allowed_size).await?;
     String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e))
@@ -106,9 +162,9 @@ where
 pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
     w: &mut W,
     b: B,
-) -> std::io::Result<()> {
+) -> io::Result<()> {
     // write the size packet.
-    primitive::write_u64(w, b.as_ref().len() as u64).await?;
+    w.write_u64_le(b.as_ref().len() as u64).await?;
 
     // write the payload
     w.write_all(b.as_ref()).await?;
@@ -122,33 +178,10 @@ pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
 }
 
 /// Computes the number of bytes we should add to len (a length in
-/// bytes) to be alined on 64 bits (8 bytes).
+/// bytes) to be aligned on 64 bits (8 bytes).
 fn padding_len(len: u64) -> u8 {
-    let modulo = len % 8;
-    if modulo == 0 {
-        0
-    } else {
-        8 - modulo as u8
-    }
-}
-
-/// Models the position inside a "bytes wire packet" that the reader or writer
-/// is in.
-/// It can be in three different stages, inside size, payload or padding fields.
-/// The number tracks the number of bytes written inside the specific field.
-/// There shall be no ambiguous states, at the end of a stage we immediately
-/// move to the beginning of the next one:
-/// - Size(LEN_SIZE) must be expressed as Payload(0)
-/// - Payload(self.payload_len) must be expressed as Padding(0)
-/// There's one exception - Size(LEN_SIZE) in the reader represents a failure
-/// state we enter in case the allowed size doesn't match the allowed range.
-///
-/// Padding(padding_len) means we're at the end of the bytes wire packet.
-#[derive(Clone, Debug, PartialEq, Eq)]
-enum BytesPacketPosition {
-    Size(usize),
-    Payload(u64),
-    Padding(usize),
+    let aligned = len.wrapping_add(7) & !7;
+    aligned.wrapping_sub(len) as u8
 }
 
 #[cfg(test)]
@@ -160,7 +193,7 @@ mod tests {
 
     /// The maximum length of bytes packets we're willing to accept in the test
     /// cases.
-    const MAX_LEN: u64 = 1024;
+    const MAX_LEN: usize = 1024;
 
     #[tokio::test]
     async fn test_read_8_bytes() {
@@ -171,10 +204,7 @@ mod tests {
 
         assert_eq!(
             &12345678u64.to_le_bytes(),
-            read_bytes(&mut mock, 0u64..MAX_LEN)
-                .await
-                .unwrap()
-                .as_slice()
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
         );
     }
 
@@ -187,10 +217,7 @@ mod tests {
 
         assert_eq!(
             hex!("010203040506070809"),
-            read_bytes(&mut mock, 0u64..MAX_LEN)
-                .await
-                .unwrap()
-                .as_slice()
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
         );
     }
 
@@ -202,10 +229,7 @@ mod tests {
 
         assert_eq!(
             hex!(""),
-            read_bytes(&mut mock, 0u64..MAX_LEN)
-                .await
-                .unwrap()
-                .as_slice()
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
         );
     }
 
@@ -215,7 +239,7 @@ mod tests {
     async fn test_read_reject_too_large() {
         let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
 
-        read_bytes(&mut mock, 10..10)
+        read_bytes(&mut mock, 10..=10)
             .await
             .expect_err("expect this to fail");
     }
@@ -251,4 +275,9 @@ mod tests {
             .build();
         assert_ok!(write_bytes(&mut mock, &input).await)
     }
+
+    #[test]
+    fn padding_len_u64_max() {
+        assert_eq!(padding_len(u64::MAX), 1);
+    }
 }