about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader/trailer.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/trailer.rs175
1 files changed, 175 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
new file mode 100644
index 000000000000..d2b867c2c338
--- /dev/null
+++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
@@ -0,0 +1,175 @@
+use std::{
+    pin::Pin,
+    task::{self, ready, Poll},
+};
+
+use tokio::io::{self, AsyncRead, ReadBuf};
+
+#[derive(Debug)]
+pub enum TrailerReader<R> {
+    Reading {
+        reader: R,
+        user_len: u8,
+        filled: u8,
+        buf: [u8; 8],
+    },
+    Releasing {
+        off: u8,
+        len: u8,
+        buf: [u8; 8],
+    },
+    Done,
+}
+
+impl<R: AsyncRead + Unpin> TrailerReader<R> {
+    pub fn new(reader: R, user_len: u8) -> Self {
+        if user_len == 0 {
+            return Self::Done;
+        }
+
+        assert!(user_len < 8, "payload in trailer must be less than 8 bytes");
+        Self::Reading {
+            reader,
+            user_len,
+            filled: 0,
+            buf: [0; 8],
+        }
+    }
+}
+
+impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut task::Context,
+        user_buf: &mut ReadBuf,
+    ) -> Poll<io::Result<()>> {
+        let this = self.get_mut();
+
+        loop {
+            match this {
+                &mut Self::Reading {
+                    reader: _,
+                    user_len,
+                    filled: 8,
+                    buf,
+                } => {
+                    *this = Self::Releasing {
+                        off: 0,
+                        len: user_len,
+                        buf,
+                    };
+                }
+                Self::Reading {
+                    reader,
+                    user_len,
+                    filled,
+                    buf,
+                } => {
+                    let mut read_buf = ReadBuf::new(&mut buf[..]);
+                    read_buf.advance(*filled as usize);
+                    ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?;
+
+                    let new_filled = read_buf.filled().len() as u8;
+                    if *filled == new_filled {
+                        return Err(io::ErrorKind::UnexpectedEof.into()).into();
+                    }
+
+                    *filled = new_filled;
+
+                    // ensure the padding is all zeroes
+                    if (u64::from_le_bytes(*buf) >> (*user_len * 8)) != 0 {
+                        return Err(io::ErrorKind::InvalidData.into()).into();
+                    }
+                }
+                Self::Releasing { off: 8, .. } => {
+                    *this = Self::Done;
+                }
+                Self::Releasing { off, len, buf } => {
+                    assert_ne!(user_buf.remaining(), 0);
+
+                    let buf = &buf[*off as usize..*len as usize];
+                    let buf = &buf[..usize::min(buf.len(), user_buf.remaining())];
+
+                    user_buf.put_slice(buf);
+                    *off += buf.len() as u8;
+
+                    break;
+                }
+                Self::Done => break,
+            }
+        }
+
+        Ok(()).into()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+    use tokio::io::AsyncReadExt;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn unexpected_eof() {
+        let reader = tokio_test::io::Builder::new()
+            .read(&[0xed])
+            .wait(Duration::ZERO)
+            .read(&[0xef, 0x00])
+            .build();
+
+        let mut reader = TrailerReader::new(reader, 2);
+
+        let mut buf = vec![];
+        assert_eq!(
+            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
+            io::ErrorKind::UnexpectedEof
+        );
+    }
+
+    #[tokio::test]
+    async fn invalid_padding() {
+        let reader = tokio_test::io::Builder::new()
+            .read(&[0xed])
+            .wait(Duration::ZERO)
+            .read(&[0xef, 0x01, 0x00])
+            .wait(Duration::ZERO)
+            .build();
+
+        let mut reader = TrailerReader::new(reader, 2);
+
+        let mut buf = vec![];
+        assert_eq!(
+            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
+            io::ErrorKind::InvalidData
+        );
+    }
+
+    #[tokio::test]
+    async fn success() {
+        let reader = tokio_test::io::Builder::new()
+            .read(&[0xed])
+            .wait(Duration::ZERO)
+            .read(&[0xef, 0x00])
+            .wait(Duration::ZERO)
+            .read(&[0x00, 0x00, 0x00, 0x00, 0x00])
+            .build();
+
+        let mut reader = TrailerReader::new(reader, 2);
+
+        let mut buf = vec![];
+        reader.read_to_end(&mut buf).await.unwrap();
+
+        assert_eq!(buf, &[0xed, 0xef]);
+    }
+
+    #[tokio::test]
+    async fn no_padding() {
+        let reader = tokio_test::io::Builder::new().build();
+        let mut reader = TrailerReader::new(reader, 0);
+
+        let mut buf = vec![];
+        reader.read_to_end(&mut buf).await.unwrap();
+        assert!(buf.is_empty());
+    }
+}