about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/trailer.rs190
1 files changed, 134 insertions, 56 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
index 1a084d0eeb01..958cead42d8f 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
@@ -1,42 +1,148 @@
 use std::{
+    future::Future,
+    marker::PhantomData,
+    ops::Deref,
     pin::Pin,
     task::{self, ready, Poll},
 };
 
 use tokio::io::{self, AsyncRead, ReadBuf};
 
+/// Trailer represents up to 7 bytes of data read as part of the trailer block(s)
 #[derive(Debug)]
-pub enum TrailerReader<R> {
-    Reading {
-        reader: R,
-        user_len: u8,
-        filled: u8,
-        buf: [u8; 8],
-    },
-    Releasing {
-        off: u8,
-        len: u8,
-        buf: [u8; 8],
-    },
-    Done,
+pub(crate) struct Trailer {
+    data_len: u8,
+    buf: [u8; 7],
 }
 
-impl<R: AsyncRead + Unpin> TrailerReader<R> {
-    pub fn new(reader: R, user_len: u8) -> Self {
-        if user_len == 0 {
-            return Self::Done;
-        }
+impl Deref for Trailer {
+    type Target = [u8];
+
+    fn deref(&self) -> &Self::Target {
+        &self.buf[..self.data_len as usize]
+    }
+}
+
+/// Tag defines a "trailer tag": specific, fixed bytes that must follow wire data.
+pub(crate) trait Tag {
+    /// The expected suffix
+    ///
+    /// The first 7 bytes may be ignored, and it must be an 8-byte aligned size.
+    const PATTERN: &'static [u8];
+
+    /// Suitably sized buffer for reading [Self::PATTERN]
+    ///
+    /// HACK: This is a workaround for const generics limitations.
+    type Buf: AsRef<[u8]> + AsMut<[u8]> + Unpin;
+
+    /// Make an instance of [Self::Buf]
+    fn make_buf() -> Self::Buf;
+}
+
+#[derive(Debug)]
+pub(crate) enum Pad {}
+
+impl Tag for Pad {
+    const PATTERN: &'static [u8] = &[0; 8];
+
+    type Buf = [u8; 8];
+
+    fn make_buf() -> Self::Buf {
+        [0; 8]
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct ReadTrailer<R, T: Tag> {
+    reader: R,
+    data_len: u8,
+    filled: u8,
+    buf: T::Buf,
+    _phantom: PhantomData<*const T>,
+}
+
+/// read_trailer returns a [Future] that reads a trailer with a given [Tag] from `reader`
+pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>(
+    reader: R,
+    data_len: u8,
+) -> ReadTrailer<R, T> {
+    assert!(data_len < 8, "payload in trailer must be less than 8 bytes");
+
+    let buf = T::make_buf();
+    assert_eq!(buf.as_ref().len(), T::PATTERN.len());
+    assert_eq!(T::PATTERN.len() % 8, 0);
+
+    ReadTrailer {
+        reader,
+        data_len,
+        filled: if data_len != 0 { 0 } else { 8 },
+        buf,
+        _phantom: PhantomData,
+    }
+}
+
+impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
+    type Output = io::Result<Trailer>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
+        let this = &mut *self;
 
-        assert!(user_len < 8, "payload in trailer must be less than 8 bytes");
-        Self::Reading {
-            reader,
-            user_len,
-            filled: 0,
-            buf: [0; 8],
+        loop {
+            if this.filled >= this.data_len {
+                let check_range = || this.data_len as usize..this.filled as usize;
+
+                if this.buf.as_ref()[check_range()] != T::PATTERN[check_range()] {
+                    return Err(io::Error::new(
+                        io::ErrorKind::InvalidData,
+                        "invalid trailer",
+                    ))
+                    .into();
+                }
+            }
+
+            if this.filled as usize == T::PATTERN.len() {
+                let mut buf = [0; 7];
+                buf.copy_from_slice(&this.buf.as_ref()[..7]);
+
+                return Ok(Trailer {
+                    data_len: this.data_len,
+                    buf,
+                })
+                .into();
+            }
+
+            let mut buf = ReadBuf::new(this.buf.as_mut());
+            buf.advance(this.filled as usize);
+
+            ready!(Pin::new(&mut this.reader).poll_read(cx, &mut buf))?;
+
+            this.filled = {
+                let prev_filled = this.filled;
+                let filled = buf.filled().len() as u8;
+
+                if filled == prev_filled {
+                    return Err(io::ErrorKind::UnexpectedEof.into()).into();
+                }
+
+                filled
+            };
         }
     }
 }
 
+#[derive(Debug)]
+pub(crate) enum TrailerReader<R> {
+    Reading(ReadTrailer<R, Pad>),
+    Releasing { off: u8, data: Trailer },
+    Done,
+}
+
+impl<R: AsyncRead + Unpin> TrailerReader<R> {
+    pub fn new(reader: R, data_len: u8) -> Self {
+        Self::Reading(read_trailer(reader, data_len))
+    }
+}
+
 impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> {
     fn poll_read(
         mut self: Pin<&mut Self>,
@@ -47,47 +153,19 @@ impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> {
 
         loop {
             match this {
-                &mut Self::Reading {
-                    reader: _,
-                    user_len,
-                    filled: 8,
-                    buf,
-                } => {
+                Self::Reading(fut) => {
                     *this = Self::Releasing {
                         off: 0,
-                        len: user_len,
-                        buf,
+                        data: ready!(Pin::new(fut).poll(cx))?,
                     };
                 }
-                Self::Reading {
-                    reader,
-                    user_len,
-                    filled,
-                    buf,
-                } => {
-                    let mut read_buf = ReadBuf::new(&mut buf[..]);
-                    read_buf.advance(*filled as usize);
-                    ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?;
-
-                    let new_filled = read_buf.filled().len() as u8;
-                    if *filled == new_filled {
-                        return Err(io::ErrorKind::UnexpectedEof.into()).into();
-                    }
-
-                    *filled = new_filled;
-
-                    // ensure the padding is all zeroes
-                    if (u64::from_le_bytes(*buf) >> (*user_len * 8)) != 0 {
-                        return Err(io::ErrorKind::InvalidData.into()).into();
-                    }
-                }
                 Self::Releasing { off: 8, .. } => {
                     *this = Self::Done;
                 }
-                Self::Releasing { off, len, buf } => {
+                Self::Releasing { off, data } => {
                     assert_ne!(user_buf.remaining(), 0);
 
-                    let buf = &buf[*off as usize..*len as usize];
+                    let buf = &data[*off as usize..];
                     let buf = &buf[..usize::min(buf.len(), user_buf.remaining())];
 
                     user_buf.put_slice(buf);