about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/mod.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/mod.rs61
1 files changed, 60 insertions, 1 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs
index db794b810f35..f48b5000a51d 100644
--- a/tvix/nix-compat/src/wire/bytes/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/mod.rs
@@ -1,8 +1,9 @@
 use std::{
     io::{Error, ErrorKind},
+    mem::MaybeUninit,
     ops::RangeInclusive,
 };
-use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
+use tokio::io::{self, AsyncReadExt, AsyncWriteExt, ReadBuf};
 
 pub(crate) mod reader;
 pub use reader::BytesReader;
@@ -81,6 +82,64 @@ where
     Ok(buf)
 }
 
+#[allow(dead_code)]
+pub(crate) async fn read_bytes_buf<'a, const N: usize, R: ?Sized>(
+    reader: &mut R,
+    buf: &'a mut [MaybeUninit<u8>; N],
+    allowed_size: RangeInclusive<usize>,
+) -> io::Result<&'a [u8]>
+where
+    R: AsyncReadExt + Unpin,
+{
+    assert_eq!(N % 8, 0);
+    assert!(*allowed_size.end() <= N);
+
+    let len = reader.read_u64_le().await?;
+    let len: usize = len
+        .try_into()
+        .ok()
+        .filter(|len| allowed_size.contains(len))
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "signalled package size not in allowed range",
+            )
+        })?;
+
+    let buf_len = (len + 7) & !7;
+    let buf = {
+        let mut read_buf = ReadBuf::uninit(&mut buf[..buf_len]);
+
+        while read_buf.filled().len() < buf_len {
+            reader.read_buf(&mut read_buf).await?;
+        }
+
+        // ReadBuf::filled does not pass the underlying buffer's lifetime through,
+        // so we must make a trip to hell.
+        //
+        // SAFETY: `read_buf` is filled up to `buf_len`, and we verify that it is
+        // still pointing at the same underlying buffer.
+        unsafe {
+            assert_eq!(read_buf.filled().as_ptr(), buf.as_ptr() as *const u8);
+            assume_init_bytes(&buf[..buf_len])
+        }
+    };
+
+    if buf[len..buf_len].iter().any(|&b| b != 0) {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "padding is not all zeroes",
+        ));
+    }
+
+    Ok(&buf[..len])
+}
+
+/// SAFETY: The bytes have to actually be initialized.
+unsafe fn assume_init_bytes(slice: &[MaybeUninit<u8>]) -> &[u8] {
+    &*(slice as *const [MaybeUninit<u8>] as *const [u8])
+}
+
 /// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
 /// Internally uses [read_bytes].
 /// Rejects reading more than `allowed_size` bytes of payload.