about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes')
-rw-r--r--tvix/nix-compat/src/wire/bytes/mod.rs254
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader.rs463
-rw-r--r--tvix/nix-compat/src/wire/bytes/writer.rs521
3 files changed, 1238 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs
new file mode 100644
index 000000000000..9487536eb720
--- /dev/null
+++ b/tvix/nix-compat/src/wire/bytes/mod.rs
@@ -0,0 +1,254 @@
+use std::{
+    io::{Error, ErrorKind},
+    ops::RangeBounds,
+};
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+
+mod reader;
+pub use reader::BytesReader;
+mod writer;
+pub use writer::BytesWriter;
+
+use super::primitive;
+
+/// 8 null bytes, used to write out padding.
+const EMPTY_BYTES: &[u8; 8] = &[0u8; 8];
+
+/// The length of the size field, in bytes is always 8.
+const LEN_SIZE: usize = 8;
+
+#[allow(dead_code)]
+/// Read a "bytes wire packet" from the AsyncRead.
+/// Rejects reading more than `allowed_size` bytes of payload.
+///
+/// The packet is made up of three parts:
+/// - a length header, u64, LE-encoded
+/// - the payload itself
+/// - null bytes to the next 8 byte boundary
+///
+/// Ensures the payload size fits into the `allowed_size` passed,
+/// and that the padding is actual null bytes.
+///
+/// On success, the returned `Vec<u8>` only contains the payload itself.
+/// On failure (for example if a too large byte packet was sent), the reader
+/// becomes unusable.
+///
+/// This buffers the entire payload into memory, a streaming version will be
+/// added later.
+pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>>
+where
+    R: AsyncReadExt + Unpin,
+    S: RangeBounds<u64>,
+{
+    // read the length field
+    let len = primitive::read_u64(r).await?;
+
+    if !allowed_size.contains(&len) {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            "signalled package size not in allowed range",
+        ));
+    }
+
+    // calculate the total length, including padding.
+    // byte packets are padded to 8 byte blocks each.
+    let padded_len = padding_len(len) as u64 + (len as u64);
+    let mut limited_reader = r.take(padded_len);
+
+    let mut buf = Vec::new();
+
+    let s = limited_reader.read_to_end(&mut buf).await?;
+
+    // make sure we got exactly the number of bytes, and not less.
+    if s as u64 != padded_len {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            "got less bytes than expected",
+        ));
+    }
+
+    let (_content, padding) = buf.split_at(len as usize);
+
+    // ensure the padding is all zeroes.
+    if !padding.iter().all(|e| *e == b'\0') {
+        return Err(std::io::Error::new(
+            std::io::ErrorKind::InvalidData,
+            "padding is not all zeroes",
+        ));
+    }
+
+    // return the data without the padding
+    buf.truncate(len as usize);
+    Ok(buf)
+}
+
+/// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
+/// Internally uses [read_bytes].
+/// Rejects reading more than `allowed_size` bytes of payload.
+pub async fn read_string<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<String>
+where
+    R: AsyncReadExt + Unpin,
+    S: RangeBounds<u64>,
+{
+    let bytes = read_bytes(r, allowed_size).await?;
+    String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e))
+}
+
+/// Writes a "bytes wire packet" to a (hopefully buffered) [AsyncWriteExt].
+///
+/// Accepts anything implementing AsRef<[u8]> as payload.
+///
+/// See [read_bytes] for a description of the format.
+///
+/// Note: if performance matters to you, make sure your
+/// [AsyncWriteExt] handle is buffered. This function is quite
+/// write-intesive.
+pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
+    w: &mut W,
+    b: B,
+) -> std::io::Result<()> {
+    // write the size packet.
+    primitive::write_u64(w, b.as_ref().len() as u64).await?;
+
+    // write the payload
+    w.write_all(b.as_ref()).await?;
+
+    // write padding if needed
+    let padding_len = padding_len(b.as_ref().len() as u64) as usize;
+    if padding_len != 0 {
+        w.write_all(&EMPTY_BYTES[..padding_len]).await?;
+    }
+    Ok(())
+}
+
+/// Computes the number of bytes we should add to len (a length in
+/// bytes) to be alined on 64 bits (8 bytes).
+fn padding_len(len: u64) -> u8 {
+    let modulo = len % 8;
+    if modulo == 0 {
+        0
+    } else {
+        8 - modulo as u8
+    }
+}
+
+/// Models the position inside a "bytes wire packet" that the reader or writer
+/// is in.
+/// It can be in three different stages, inside size, payload or padding fields.
+/// The number tracks the number of bytes written inside the specific field.
+/// There shall be no ambiguous states, at the end of a stage we immediately
+/// move to the beginning of the next one:
+/// - Size(LEN_SIZE) must be expressed as Payload(0)
+/// - Payload(self.payload_len) must be expressed as Padding(0)
+/// There's one exception - Size(LEN_SIZE) in the reader represents a failure
+/// state we enter in case the allowed size doesn't match the allowed range.
+///
+/// Padding(padding_len) means we're at the end of the bytes wire packet.
+#[derive(Clone, Debug, PartialEq, Eq)]
+enum BytesPacketPosition {
+    Size(usize),
+    Payload(u64),
+    Padding(usize),
+}
+
+#[cfg(test)]
+mod tests {
+    use tokio_test::{assert_ok, io::Builder};
+
+    use super::*;
+    use hex_literal::hex;
+
+    /// The maximum length of bytes packets we're willing to accept in the test
+    /// cases.
+    const MAX_LEN: u64 = 1024;
+
+    #[tokio::test]
+    async fn test_read_8_bytes() {
+        let mut mock = Builder::new()
+            .read(&8u64.to_le_bytes())
+            .read(&12345678u64.to_le_bytes())
+            .build();
+
+        assert_eq!(
+            &12345678u64.to_le_bytes(),
+            read_bytes(&mut mock, 0u64..MAX_LEN)
+                .await
+                .unwrap()
+                .as_slice()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_read_9_bytes() {
+        let mut mock = Builder::new()
+            .read(&9u64.to_le_bytes())
+            .read(&hex!("01020304050607080900000000000000"))
+            .build();
+
+        assert_eq!(
+            hex!("010203040506070809"),
+            read_bytes(&mut mock, 0u64..MAX_LEN)
+                .await
+                .unwrap()
+                .as_slice()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_read_0_bytes() {
+        // A empty byte packet is essentially just the 0 length field.
+        // No data is read, and there's zero padding.
+        let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();
+
+        assert_eq!(
+            hex!(""),
+            read_bytes(&mut mock, 0u64..MAX_LEN)
+                .await
+                .unwrap()
+                .as_slice()
+        );
+    }
+
+    #[tokio::test]
+    /// Ensure we don't read any further than the size field if the length
+    /// doesn't match the range we want to accept.
+    async fn test_read_reject_too_large() {
+        let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
+
+        read_bytes(&mut mock, 10..10)
+            .await
+            .expect_err("expect this to fail");
+    }
+
+    #[tokio::test]
+    async fn test_write_bytes_no_padding() {
+        let input = hex!("6478696f34657661");
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&input)
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
+    #[tokio::test]
+    async fn test_write_bytes_with_padding() {
+        let input = hex!("322e332e3137");
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&hex!("322e332e31370000"))
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
+
+    #[tokio::test]
+    async fn test_write_string() {
+        let input = "Hello, World!";
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&hex!("48656c6c6f2c20576f726c6421000000"))
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
+}
diff --git a/tvix/nix-compat/src/wire/bytes/reader.rs b/tvix/nix-compat/src/wire/bytes/reader.rs
new file mode 100644
index 000000000000..4c450b55db1a
--- /dev/null
+++ b/tvix/nix-compat/src/wire/bytes/reader.rs
@@ -0,0 +1,463 @@
+use pin_project_lite::pin_project;
+use std::{
+    ops::RangeBounds,
+    task::{ready, Poll},
+};
+use tokio::io::AsyncRead;
+
+use super::{padding_len, BytesPacketPosition, LEN_SIZE};
+
+pin_project! {
+    /// Reads a "bytes wire packet" from the underlying reader.
+    /// The format is the same as in [crate::wire::bytes::read_bytes],
+    /// however this structure provides a [AsyncRead] interface,
+    /// allowing to not having to pass around the entire payload in memory.
+    ///
+    /// After being constructed with the underlying reader and an allowed size,
+    /// subsequent requests to poll_read will return payload data until the end
+    /// of the packet is reached.
+    ///
+    /// Internally, it will first read over the size packet, filling payload_size,
+    /// ensuring it fits allowed_size, then return payload data.
+    /// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
+    /// when all padding has been successfully consumed too.
+    ///
+    /// This also means, it's important for a user to always read to the end,
+    /// and not just call read_exact - otherwise it might not skip over the
+    /// padding, and return garbage when reading the next packet.
+    ///
+    /// In case of an error due to size constraints, or in case of not reading
+    /// all the way to the end (and getting a EOF), the underlying reader is no
+    /// longer usable and might return garbage.
+    pub struct BytesReader<R, S>
+    where
+    R: AsyncRead,
+    S: RangeBounds<u64>,
+
+    {
+        #[pin]
+        inner: R,
+
+        allowed_size: S,
+        payload_size: [u8; 8],
+        state: BytesPacketPosition,
+    }
+}
+
+impl<R, S> BytesReader<R, S>
+where
+    R: AsyncRead + Unpin,
+    S: RangeBounds<u64>,
+{
+    /// Constructs a new BytesReader, using the underlying passed reader.
+    pub fn new(r: R, allowed_size: S) -> Self {
+        Self {
+            inner: r,
+            allowed_size,
+            payload_size: [0; 8],
+            state: BytesPacketPosition::Size(0),
+        }
+    }
+}
+/// Returns an error if the passed usize is 0.
+fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error> {
+    if bytes_read == 0 {
+        Err(std::io::Error::new(
+            std::io::ErrorKind::UnexpectedEof,
+            "underlying reader returned EOF",
+        ))
+    } else {
+        Ok(bytes_read)
+    }
+}
+
+impl<R, S> AsyncRead for BytesReader<R, S>
+where
+    R: AsyncRead,
+    S: RangeBounds<u64>,
+{
+    fn poll_read(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let mut this = self.project();
+
+        // Use a loop, so we can deal with (multiple) state transitions.
+        loop {
+            match *this.state {
+                BytesPacketPosition::Size(LEN_SIZE) => {
+                    // used in case an invalid size was signalled.
+                    Err(std::io::Error::new(
+                        std::io::ErrorKind::InvalidData,
+                        "signalled package size not in allowed range",
+                    ))?
+                }
+                BytesPacketPosition::Size(pos) => {
+                    // try to read more of the size field.
+                    // We wrap a BufRead around this.payload_size here, and set_filled.
+                    let mut read_buf = tokio::io::ReadBuf::new(this.payload_size);
+                    read_buf.advance(pos);
+                    ready!(this.inner.as_mut().poll_read(cx, &mut read_buf))?;
+
+                    ensure_nonzero_bytes_read(read_buf.filled().len() - pos)?;
+
+                    let total_size_read = read_buf.filled().len();
+                    if total_size_read == LEN_SIZE {
+                        // If the entire payload size was read, parse it
+                        let payload_size = u64::from_le_bytes(*this.payload_size);
+
+                        if !this.allowed_size.contains(&payload_size) {
+                            // If it's not in the allowed
+                            // range, transition to failure mode
+                            // `BytesPacketPosition::Size(LEN_SIZE)`, where only
+                            // an error is returned.
+                            *this.state = BytesPacketPosition::Size(LEN_SIZE)
+                        } else if payload_size == 0 {
+                            // If the payload size is 0, move on to reading padding directly.
+                            *this.state = BytesPacketPosition::Padding(0)
+                        } else {
+                            // Else, transition to reading the payload.
+                            *this.state = BytesPacketPosition::Payload(0)
+                        }
+                    } else {
+                        // If we still need to read more of payload size, update
+                        // our position in the state.
+                        *this.state = BytesPacketPosition::Size(total_size_read)
+                    }
+                }
+                BytesPacketPosition::Payload(pos) => {
+                    let signalled_size = u64::from_le_bytes(*this.payload_size);
+                    // We don't enter this match arm at all if we're expecting empty payload
+                    debug_assert!(signalled_size > 0, "signalled size must be larger than 0");
+
+                    // Read from the underlying reader into buf
+                    // We cap the ReadBuf to the size of the payload, as we
+                    // don't want to leak padding to the caller.
+                    let bytes_read = ensure_nonzero_bytes_read({
+                        // Reducing these two u64 to usize on 32bits is fine - we
+                        // only care about not reading too much, not too less.
+                        let mut limited_buf = buf.take((signalled_size - pos) as usize);
+                        ready!(this.inner.as_mut().poll_read(cx, &mut limited_buf))?;
+                        limited_buf.filled().len()
+                    })?;
+
+                    // SAFETY: we just did populate this, but through limited_buf.
+                    unsafe { buf.assume_init(bytes_read) }
+                    buf.advance(bytes_read);
+
+                    if pos + bytes_read as u64 == signalled_size {
+                        // If we now read all payload, transition to padding
+                        // state.
+                        *this.state = BytesPacketPosition::Padding(0);
+                    } else {
+                        // if we didn't read everything yet, update our position
+                        // in the state.
+                        *this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
+                    }
+
+                    // We return from poll_read here.
+                    // This is important, as any error (or even Pending) from
+                    // the underlying reader on the next read (be it padding or
+                    // payload) would require us to roll back buf, as generally
+                    // a AsyncRead::poll_read may not advance the buffer in case
+                    // of a nonsuccessful read.
+                    // It can't be misinterpreted as EOF, as we definitely *did*
+                    // write something into buf if we come to here (we pass
+                    // `ensure_nonzero_bytes_read`).
+                    return Ok(()).into();
+                }
+                BytesPacketPosition::Padding(pos) => {
+                    // Consume whatever padding is left, ensuring it's all null
+                    // bytes. Only return `Ready(Ok(()))` once we're past the
+                    // padding (or in cases where polling the inner reader
+                    // returns `Poll::Pending`).
+                    let signalled_size = u64::from_le_bytes(*this.payload_size);
+                    let total_padding_len = padding_len(signalled_size) as usize;
+
+                    let padding_len_remaining = total_padding_len - pos;
+                    if padding_len_remaining != 0 {
+                        // create a buffer only accepting the number of remaining padding bytes.
+                        let mut buf = [0; 8];
+                        let mut padding_buf = tokio::io::ReadBuf::new(&mut buf);
+                        let mut padding_buf = padding_buf.take(padding_len_remaining);
+
+                        // read into padding_buf.
+                        ready!(this.inner.as_mut().poll_read(cx, &mut padding_buf))?;
+                        let bytes_read = ensure_nonzero_bytes_read(padding_buf.filled().len())?;
+
+                        *this.state = BytesPacketPosition::Padding(pos + bytes_read);
+
+                        // ensure the bytes are not null bytes
+                        if !padding_buf.filled().iter().all(|e| *e == b'\0') {
+                            return Err(std::io::Error::new(
+                                std::io::ErrorKind::InvalidData,
+                                "padding is not all zeroes",
+                            ))
+                            .into();
+                        }
+
+                        // if we still have padding to read, run the loop again.
+                        continue;
+                    }
+                    // return EOF
+                    return Ok(()).into();
+                }
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use crate::wire::bytes::write_bytes;
+    use hex_literal::hex;
+    use lazy_static::lazy_static;
+    use rstest::rstest;
+    use tokio::io::AsyncReadExt;
+    use tokio_test::{assert_err, io::Builder};
+
+    use super::*;
+
+    /// The maximum length of bytes packets we're willing to accept in the test
+    /// cases.
+    const MAX_LEN: u64 = 1024;
+
+    lazy_static! {
+        pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024);
+    }
+
+    /// Helper function, calling the (simpler) write_bytes with the payload.
+    /// We use this to create data we want to read from the wire.
+    async fn produce_packet_bytes(payload: &[u8]) -> Vec<u8> {
+        let mut exp = vec![];
+        write_bytes(&mut exp, payload).await.unwrap();
+        exp
+    }
+
+    /// Read bytes packets of various length, and ensure read_to_end returns the
+    /// expected payload.
+    #[rstest]
+    #[case::empty(&[])] // empty bytes packet
+    #[case::size_1b(&[0xff])] // 1 bytes payload
+    #[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding)
+    #[case::size_9b( &hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding)
+    #[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet
+    #[tokio::test]
+    async fn read_payload_correct(#[case] payload: &[u8]) {
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await)
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64);
+        let mut buf = Vec::new();
+        r.read_to_end(&mut buf).await.expect("must succeed");
+
+        assert_eq!(payload, &buf[..]);
+    }
+
+    /// Fail if the bytes packet is larger than allowed
+    #[tokio::test]
+    async fn read_bigger_than_allowed_fail() {
+        let payload = LARGE_PAYLOAD.as_slice();
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..2048);
+        let mut buf = Vec::new();
+        assert_err!(r.read_to_end(&mut buf).await);
+    }
+
+    /// Fail if the bytes packet is smaller than allowed
+    #[tokio::test]
+    async fn read_smaller_than_allowed_fail() {
+        let payload = &[0x00, 0x01, 0x02];
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, 1024..2048);
+        let mut buf = Vec::new();
+        assert_err!(r.read_to_end(&mut buf).await);
+    }
+
+    /// Fail if the padding is not all zeroes
+    #[tokio::test]
+    async fn read_fail_if_nonzero_padding() {
+        let payload = &[0x00, 0x01, 0x02];
+        let mut packet_bytes = produce_packet_bytes(payload).await;
+        // Flip some bits in the padding
+        packet_bytes[12] = 0xff;
+        let mut mock = Builder::new().read(&packet_bytes).build(); // We stop reading after the faulty bit
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut buf = Vec::new();
+
+        r.read_to_end(&mut buf).await.expect_err("must fail");
+    }
+
+    /// Start a 9 bytes payload packet, but have the underlying reader return
+    /// EOF in the middle of the size packet (after 4 bytes).
+    /// We should get an unexpected EOF error, already when trying to read the
+    /// first byte (of payload)
+    #[tokio::test]
+    async fn read_9b_eof_during_size() {
+        let payload = &hex!("FF0102030405060708");
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[..4])
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut buf = [0u8; 1];
+
+        assert_eq!(
+            r.read_exact(&mut buf).await.expect_err("must fail").kind(),
+            std::io::ErrorKind::UnexpectedEof
+        );
+
+        assert_eq!(&[0], &buf, "buffer should stay empty");
+    }
+
+    /// Start a 9 bytes payload packet, but have the underlying reader return
+    /// EOF in the middle of the payload (4 bytes into the payload).
+    /// We should get an unexpected EOF error, after reading the first 4 bytes
+    /// (successfully).
+    #[tokio::test]
+    async fn read_9b_eof_during_payload() {
+        let payload = &hex!("FF0102030405060708");
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[..8 + 4])
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut buf = [0; 9];
+
+        r.read_exact(&mut buf[..4]).await.expect("must succeed");
+
+        assert_eq!(
+            r.read_exact(&mut buf[4..=4])
+                .await
+                .expect_err("must fail")
+                .kind(),
+            std::io::ErrorKind::UnexpectedEof
+        );
+    }
+
+    /// Start a 9 bytes payload packet, but return an error at various stages *after* the actual payload.
+    /// read_exact with a 9 bytes buffer is expected to succeed, but any further
+    /// read, as well as read_to_end are expected to fail.
+    #[rstest]
+    #[case::before_padding(8 + 9)]
+    #[case::during_padding(8 + 9 + 2)]
+    #[case::after_padding(8 + 9 + padding_len(9) as usize)]
+    #[tokio::test]
+    async fn read_9b_eof_after_payload(#[case] offset: usize) {
+        let payload = &hex!("FF0102030405060708");
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[..offset])
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut buf = [0; 9];
+
+        // read_exact of the payload will succeed, but a subsequent read will
+        // return UnexpectedEof error.
+        r.read_exact(&mut buf).await.expect("should succeed");
+        assert_eq!(
+            r.read_exact(&mut buf[4..=4])
+                .await
+                .expect_err("must fail")
+                .kind(),
+            std::io::ErrorKind::UnexpectedEof
+        );
+
+        // read_to_end will fail.
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[..8 + payload.len()])
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut buf = Vec::new();
+        assert_eq!(
+            r.read_to_end(&mut buf).await.expect_err("must fail").kind(),
+            std::io::ErrorKind::UnexpectedEof
+        );
+    }
+
+    /// Start a 9 bytes payload packet, but return an error after a certain position.
+    /// Ensure that error is propagated.
+    #[rstest]
+    #[case::during_size(4)]
+    #[case::before_payload(8)]
+    #[case::during_payload(8 + 4)]
+    #[case::before_padding(8 + 4)]
+    #[case::during_padding(8 + 9 + 2)]
+    #[tokio::test]
+    async fn propagate_error_from_reader(#[case] offset: usize) {
+        let payload = &hex!("FF0102030405060708");
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[..offset])
+            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut buf = Vec::new();
+
+        let err = r.read_to_end(&mut buf).await.expect_err("must fail");
+        assert_eq!(
+            err.kind(),
+            std::io::ErrorKind::Other,
+            "error kind must match"
+        );
+
+        assert_eq!(
+            err.into_inner().unwrap().to_string(),
+            "foo",
+            "error payload must contain foo"
+        );
+    }
+
+    /// If there's an error right after the padding, we don't propagate it, as
+    /// we're done reading. We just return EOF.
+    #[tokio::test]
+    async fn no_error_after_eof() {
+        let payload = &hex!("FF0102030405060708");
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await)
+            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut buf = Vec::new();
+
+        r.read_to_end(&mut buf).await.expect("must succeed");
+        assert_eq!(buf.as_slice(), payload);
+    }
+
+    /// Introduce various stalls in various places of the packet, to ensure we
+    /// handle these cases properly, too.
+    #[rstest]
+    #[case::beginning(0)]
+    #[case::before_payload(8)]
+    #[case::during_payload(8 + 4)]
+    #[case::before_padding(8 + 4)]
+    #[case::during_padding(8 + 9 + 2)]
+    #[tokio::test]
+    async fn read_payload_correct_pending(#[case] offset: usize) {
+        let payload = &hex!("FF0102030405060708");
+        let mut mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[..offset])
+            .wait(Duration::from_nanos(0))
+            .read(&produce_packet_bytes(payload).await[offset..])
+            .build();
+
+        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64);
+        let mut buf = Vec::new();
+        r.read_to_end(&mut buf).await.expect("must succeed");
+
+        assert_eq!(payload, &buf[..]);
+    }
+}
diff --git a/tvix/nix-compat/src/wire/bytes/writer.rs b/tvix/nix-compat/src/wire/bytes/writer.rs
new file mode 100644
index 000000000000..f278b8335f8f
--- /dev/null
+++ b/tvix/nix-compat/src/wire/bytes/writer.rs
@@ -0,0 +1,521 @@
+use pin_project_lite::pin_project;
+use std::task::{ready, Poll};
+
+use tokio::io::AsyncWrite;
+
+use super::{padding_len, BytesPacketPosition, EMPTY_BYTES, LEN_SIZE};
+
+pin_project! {
+    /// Writes a "bytes wire packet" to the underlying writer.
+    /// The format is the same as in [crate::wire::bytes::write_bytes],
+    /// however this structure provides a [AsyncWrite] interface,
+    /// allowing to not having to pass around the entire payload in memory.
+    ///
+    /// It internally takes care of writing (non-payload) framing (size and
+    /// padding).
+    ///
+    /// During construction, the expected payload size needs to be provided.
+    ///
+    /// After writing the payload to it, the user MUST call flush (or shutdown),
+    /// which will validate the written payload size to match, and write the
+    /// necessary padding.
+    ///
+    /// In case flush is not called at the end, invalid data might be sent
+    /// silently.
+    ///
+    /// The underlying writer returning `Ok(0)` is considered an EOF situation,
+    /// which is stronger than the "typically means the underlying object is no
+    /// longer able to accept bytes" interpretation from the docs. If such a
+    /// situation occurs, an error is returned.
+    ///
+    /// The struct holds three fields, the underlying writer, the (expected)
+    /// payload length, and an enum, tracking the state.
+    pub struct BytesWriter<W>
+    where
+        W: AsyncWrite,
+    {
+        #[pin]
+        inner: W,
+        payload_len: u64,
+        state: BytesPacketPosition,
+    }
+}
+
+impl<W> BytesWriter<W>
+where
+    W: AsyncWrite,
+{
+    /// Constructs a new BytesWriter, using the underlying passed writer.
+    pub fn new(w: W, payload_len: u64) -> Self {
+        Self {
+            inner: w,
+            payload_len,
+            state: BytesPacketPosition::Size(0),
+        }
+    }
+}
+
+/// Returns an error if the passed usize is 0.
+fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> {
+    if bytes_written == 0 {
+        Err(std::io::Error::new(
+            std::io::ErrorKind::WriteZero,
+            "underlying writer accepted 0 bytes",
+        ))
+    } else {
+        Ok(bytes_written)
+    }
+}
+
+impl<W> AsyncWrite for BytesWriter<W>
+where
+    W: AsyncWrite,
+{
+    fn poll_write(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, std::io::Error>> {
+        // Use a loop, so we can deal with (multiple) state transitions.
+        let mut this = self.project();
+
+        loop {
+            match *this.state {
+                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
+                BytesPacketPosition::Size(pos) => {
+                    let size_field = &this.payload_len.to_le_bytes();
+
+                    let bytes_written = ensure_nonzero_bytes_written(ready!(this
+                        .inner
+                        .as_mut()
+                        .poll_write(cx, &size_field[pos..]))?)?;
+
+                    let new_pos = pos + bytes_written;
+                    if new_pos == LEN_SIZE {
+                        *this.state = BytesPacketPosition::Payload(0);
+                    } else {
+                        *this.state = BytesPacketPosition::Size(new_pos);
+                    }
+                }
+                BytesPacketPosition::Payload(pos) => {
+                    // Ensure we still have space for more payload
+                    if pos + (buf.len() as u64) > *this.payload_len {
+                        return Poll::Ready(Err(std::io::Error::new(
+                            std::io::ErrorKind::InvalidData,
+                            "tried to write excess bytes",
+                        )));
+                    }
+                    let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?;
+                    ensure_nonzero_bytes_written(bytes_written)?;
+                    let new_pos = pos + (bytes_written as u64);
+                    if new_pos == *this.payload_len {
+                        *this.state = BytesPacketPosition::Padding(0)
+                    } else {
+                        *this.state = BytesPacketPosition::Payload(new_pos)
+                    }
+
+                    return Poll::Ready(Ok(bytes_written));
+                }
+                // If we're already in padding state, there should be no more payload left to write!
+                BytesPacketPosition::Padding(_pos) => {
+                    return Poll::Ready(Err(std::io::Error::new(
+                        std::io::ErrorKind::InvalidData,
+                        "tried to write excess bytes",
+                    )))
+                }
+            }
+        }
+    }
+
+    fn poll_flush(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        let mut this = self.project();
+
+        loop {
+            match *this.state {
+                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
+                BytesPacketPosition::Size(pos) => {
+                    // More bytes to write in the size field
+                    let size_field = &this.payload_len.to_le_bytes()[..];
+                    let bytes_written = ensure_nonzero_bytes_written(ready!(this
+                        .inner
+                        .as_mut()
+                        .poll_write(cx, &size_field[pos..]))?)?;
+                    let new_pos = pos + bytes_written;
+                    if new_pos == LEN_SIZE {
+                        // Size field written, now ready to receive payload
+                        *this.state = BytesPacketPosition::Payload(0);
+                    } else {
+                        *this.state = BytesPacketPosition::Size(new_pos);
+                    }
+                }
+                BytesPacketPosition::Payload(_pos) => {
+                    // If we're at position 0 and want to write 0 bytes of payload
+                    // in total, we can transition to padding.
+                    // Otherwise, break, as we're expecting more payload to
+                    // be written.
+                    if *this.payload_len == 0 {
+                        *this.state = BytesPacketPosition::Padding(0);
+                    } else {
+                        break;
+                    }
+                }
+                BytesPacketPosition::Padding(pos) => {
+                    // Write remaining padding, if there is padding to write.
+                    let total_padding_len = padding_len(*this.payload_len) as usize;
+
+                    if pos != total_padding_len {
+                        let bytes_written = ensure_nonzero_bytes_written(ready!(this
+                            .inner
+                            .as_mut()
+                            .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len]))?)?;
+                        *this.state = BytesPacketPosition::Padding(pos + bytes_written);
+                    } else {
+                        // everything written, break
+                        break;
+                    }
+                }
+            }
+        }
+        // Flush the underlying writer.
+        this.inner.as_mut().poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        // Call flush.
+        ready!(self.as_mut().poll_flush(cx))?;
+
+        let this = self.project();
+
+        // After a flush, being inside the padding state, and at the end of the padding
+        // is the only way to prevent a dirty shutdown.
+        if let BytesPacketPosition::Padding(pos) = *this.state {
+            let padding_len = padding_len(*this.payload_len) as usize;
+            if padding_len == pos {
+                // Shutdown the underlying writer
+                return this.inner.poll_shutdown(cx);
+            }
+        }
+
+        // Shutdown the underlying writer, bubbling up any errors.
+        ready!(this.inner.poll_shutdown(cx))?;
+
+        // return an error about unclean shutdown
+        Poll::Ready(Err(std::io::Error::new(
+            std::io::ErrorKind::BrokenPipe,
+            "unclean shutdown",
+        )))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use crate::wire::bytes::write_bytes;
+    use hex_literal::hex;
+    use lazy_static::lazy_static;
+    use tokio::io::AsyncWriteExt;
+    use tokio_test::{assert_err, assert_ok, io::Builder};
+
+    use super::*;
+
+    lazy_static! {
+        pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024);
+    }
+
+    /// Helper function, calling the (simpler) write_bytes with the payload.
+    /// We use this to create data we want to see on the wire.
+    async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> {
+        let mut exp = vec![];
+        write_bytes(&mut exp, payload).await.unwrap();
+        exp
+    }
+
+    /// Write an empty bytes packet.
+    #[tokio::test]
+    async fn write_empty() {
+        let payload = &[];
+        let mut mock = Builder::new()
+            .write(&produce_exp_bytes(payload).await)
+            .build();
+
+        let mut w = BytesWriter::new(&mut mock, 0);
+        assert_ok!(w.write_all(&[]).await, "write all data");
+        assert_ok!(w.flush().await, "flush");
+    }
+
+    /// Write an empty bytes packet, not calling write.
+    #[tokio::test]
+    async fn write_empty_only_flush() {
+        let payload = &[];
+        let mut mock = Builder::new()
+            .write(&produce_exp_bytes(payload).await)
+            .build();
+
+        let mut w = BytesWriter::new(&mut mock, 0);
+        assert_ok!(w.flush().await, "flush");
+    }
+
+    /// Write an empty bytes packet, not calling write or flush, only shutdown.
+    #[tokio::test]
+    async fn write_empty_only_shutdown() {
+        let payload = &[];
+        let mut mock = Builder::new()
+            .write(&produce_exp_bytes(payload).await)
+            .build();
+
+        let mut w = BytesWriter::new(&mut mock, 0);
+        assert_ok!(w.shutdown().await, "shutdown");
+    }
+
+    /// Write a 1 bytes packet
+    #[tokio::test]
+    async fn write_1b() {
+        let payload = &[0xff];
+
+        let mut mock = Builder::new()
+            .write(&produce_exp_bytes(payload).await)
+            .build();
+
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+        assert_ok!(w.write_all(payload).await);
+        assert_ok!(w.flush().await, "flush");
+    }
+
+    /// Write a 8 bytes payload (no padding)
+    #[tokio::test]
+    async fn write_8b() {
+        let payload = &hex!("0001020304050607");
+
+        let mut mock = Builder::new()
+            .write(&produce_exp_bytes(payload).await)
+            .build();
+
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+        assert_ok!(w.write_all(payload).await);
+        assert_ok!(w.flush().await, "flush");
+    }
+
+    /// Write a 9 bytes payload (7 bytes padding)
+    #[tokio::test]
+    async fn write_9b() {
+        let payload = &hex!("000102030405060708");
+
+        let mut mock = Builder::new()
+            .write(&produce_exp_bytes(payload).await)
+            .build();
+
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+        assert_ok!(w.write_all(payload).await);
+        assert_ok!(w.flush().await, "flush");
+    }
+
+    /// Write a 9 bytes packet very granularly, with a lot of flushing in between,
+    /// and a shutdown at the end.
+    #[tokio::test]
+    async fn write_9b_flush() {
+        let payload = &hex!("000102030405060708");
+        let exp_bytes = produce_exp_bytes(payload).await;
+
+        let mut mock = Builder::new().write(&exp_bytes).build();
+
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+        assert_ok!(w.flush().await);
+
+        assert_ok!(w.write_all(&payload[..4]).await);
+        assert_ok!(w.flush().await);
+
+        // empty write, cause why not
+        assert_ok!(w.write_all(&[]).await);
+        assert_ok!(w.flush().await);
+
+        assert_ok!(w.write_all(&payload[4..]).await);
+        assert_ok!(w.flush().await);
+        assert_ok!(w.shutdown().await);
+    }
+
+    /// Write a 9 bytes packet, but cause the sink to only accept half of the
+    /// padding, ensuring we correctly write (only) the rest of the padding later.
+    /// We write another 2 bytes of "bait", where a faulty implementation (pre
+    /// cl/11384) would put too many null bytes.
+    #[tokio::test]
+    async fn write_9b_write_padding_2steps() {
+        let payload = &hex!("000102030405060708");
+        let exp_bytes = produce_exp_bytes(payload).await;
+
+        let mut mock = Builder::new()
+            .write(&exp_bytes[0..8]) // size
+            .write(&exp_bytes[8..17]) // payload
+            .write(&exp_bytes[17..19]) // padding (2 of 7 bytes)
+            // insert a wait to prevent Mock from merging the two writes into one
+            .wait(Duration::from_nanos(1))
+            .write(&hex!("0000000000ffff")) // padding (5 of 7 bytes, plus 2 bytes of "bait")
+            .build();
+
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+        assert_ok!(w.write_all(&payload[..]).await);
+        assert_ok!(w.flush().await);
+        // Write bait
+        assert_ok!(mock.write_all(&hex!("ffff")).await);
+    }
+
+    /// Write a larger bytes packet
+    #[tokio::test]
+    async fn write_1m() {
+        let payload = LARGE_PAYLOAD.as_slice();
+        let exp_bytes = produce_exp_bytes(payload).await;
+
+        let mut mock = Builder::new().write(&exp_bytes).build();
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+
+        assert_ok!(w.write_all(payload).await);
+        assert_ok!(w.flush().await, "flush");
+    }
+
+    /// Not calling flush at the end, but shutdown is also ok if we wrote all
+    /// bytes we promised to write (as shutdown implies flush)
+    #[tokio::test]
+    async fn write_shutdown_without_flush_end() {
+        let payload = &[0xf0, 0xff];
+        let exp_bytes = produce_exp_bytes(payload).await;
+
+        let mut mock = Builder::new().write(&exp_bytes).build();
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+
+        // call flush to write the size field
+        assert_ok!(w.flush().await);
+
+        // write payload
+        assert_ok!(w.write_all(payload).await);
+
+        // call shutdown
+        assert_ok!(w.shutdown().await);
+    }
+
+    /// Writing more bytes than previously signalled should fail.
+    #[tokio::test]
+    async fn write_more_than_signalled_fail() {
+        let mut buf = Vec::new();
+        let mut w = BytesWriter::new(&mut buf, 2);
+
+        assert_err!(w.write_all(&hex!("000102")).await);
+    }
+    /// Writing more bytes than previously signalled, but in two parts
+    #[tokio::test]
+    async fn write_more_than_signalled_split_fail() {
+        let mut buf = Vec::new();
+        let mut w = BytesWriter::new(&mut buf, 2);
+
+        // write two bytes
+        assert_ok!(w.write_all(&hex!("0001")).await);
+
+        // write the excess byte.
+        assert_err!(w.write_all(&hex!("02")).await);
+    }
+
+    /// Writing more bytes than previously signalled, but flushing after the
+    /// signalled amount should fail.
+    #[tokio::test]
+    async fn write_more_than_signalled_flush_fail() {
+        let mut buf = Vec::new();
+        let mut w = BytesWriter::new(&mut buf, 2);
+
+        // write two bytes, then flush
+        assert_ok!(w.write_all(&hex!("0001")).await);
+        assert_ok!(w.flush().await);
+
+        // write the excess byte.
+        assert_err!(w.write_all(&hex!("02")).await);
+    }
+
+    /// Calling shutdown while not having written all bytes that were promised
+    /// returns an error.
+    /// Note there's still cases of silent corruption if the user doesn't call
+    /// shutdown explicitly (only drops).
+    #[tokio::test]
+    async fn premature_shutdown() {
+        let payload = &[0xf0, 0xff];
+        let mut buf = Vec::new();
+        let mut w = BytesWriter::new(&mut buf, payload.len() as u64);
+
+        // call flush to write the size field
+        assert_ok!(w.flush().await);
+
+        // write half of the payload (!)
+        assert_ok!(w.write_all(&payload[0..1]).await);
+
+        // call shutdown, ensure it fails
+        assert_err!(w.shutdown().await);
+    }
+
+    /// Write to a Writer that fails to write during the size packet (after 4 bytes).
+    /// Ensure this error gets propagated on the first call to write.
+    #[tokio::test]
+    async fn inner_writer_fail_during_size_firstwrite() {
+        let payload = &[0xf0];
+
+        let mut mock = Builder::new()
+            .write(&1u32.to_le_bytes())
+            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐Ÿฟ"))
+            .build();
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+
+        assert_err!(w.write_all(payload).await);
+    }
+
+    /// Write to a Writer that fails to write during the size packet (after 4 bytes).
+    /// Ensure this error gets propagated during an initial flush
+    #[tokio::test]
+    async fn inner_writer_fail_during_size_initial_flush() {
+        let payload = &[0xf0];
+
+        let mut mock = Builder::new()
+            .write(&1u32.to_le_bytes())
+            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐Ÿฟ"))
+            .build();
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+
+        assert_err!(w.flush().await);
+    }
+
+    /// Write to a writer that fails to write during the payload (after 9 bytes).
+    /// Ensure this error gets propagated when we're writing this byte.
+    #[tokio::test]
+    async fn inner_writer_fail_during_write() {
+        let payload = &hex!("f0ff");
+
+        let mut mock = Builder::new()
+            .write(&2u64.to_le_bytes())
+            .write(&hex!("f0"))
+            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐Ÿฟ"))
+            .build();
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+
+        assert_ok!(w.write(&hex!("f0")).await);
+        assert_err!(w.write(&hex!("ff")).await);
+    }
+
+    /// Write to a writer that fails to write during the padding (after 10 bytes).
+    /// Ensure this error gets propagated during a flush.
+    #[tokio::test]
+    async fn inner_writer_fail_during_padding_flush() {
+        let payload = &hex!("f0");
+
+        let mut mock = Builder::new()
+            .write(&1u64.to_le_bytes())
+            .write(&hex!("f0"))
+            .write(&hex!("00"))
+            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐Ÿฟ"))
+            .build();
+        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
+
+        assert_ok!(w.write(&hex!("f0")).await);
+        assert_err!(w.flush().await);
+    }
+}