about summary refs log tree commit diff
path: root/tvix/nix-compat
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat')
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/mod.rs35
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/trailer.rs97
2 files changed, 41 insertions, 91 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
index 9f7013a4e9b1..5f6081b40451 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
@@ -1,4 +1,5 @@
 use std::{
+    future::Future,
     io,
     ops::{Bound, RangeBounds},
     pin::Pin,
@@ -6,7 +7,7 @@ use std::{
 };
 use tokio::io::{AsyncRead, ReadBuf};
 
-use trailer::TrailerReader;
+use trailer::{read_trailer, ReadTrailer, Trailer};
 mod trailer;
 
 /// Reads a "bytes wire packet" from the underlying reader.
@@ -33,6 +34,7 @@ pub struct BytesReader<R> {
 
 #[derive(Debug)]
 enum State<R> {
+    /// The data size is being read.
     Size {
         reader: Option<R>,
         /// Minimum length (inclusive)
@@ -42,12 +44,18 @@ enum State<R> {
         filled: u8,
         buf: [u8; 8],
     },
+    /// Full 8-byte blocks are being read and released to the caller.
     Body {
         reader: Option<R>,
         consumed: u64,
+        /// The total length of all user data contained in both the body and trailer.
         user_len: u64,
     },
-    Trailer(TrailerReader<R>),
+    /// The trailer is in the process of being read.
+    ReadTrailer(ReadTrailer<R>),
+    /// The trailer has been fully read and validated,
+    /// and data can now be released to the caller.
+    ReleaseTrailer { consumed: u8, data: Trailer },
 }
 
 impl<R> BytesReader<R>
@@ -100,7 +108,10 @@ where
             State::Body {
                 consumed, user_len, ..
             } => Some(user_len - consumed),
-            State::Trailer(ref r) => Some(r.len() as u64),
+            State::ReadTrailer(ref fut) => Some(fut.len() as u64),
+            State::ReleaseTrailer { consumed, ref data } => {
+                Some(data.len() as u64 - consumed as u64)
+            }
         }
     }
 }
@@ -166,7 +177,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
                     let reader = if remaining == 0 {
                         let reader = reader.take().unwrap();
                         let user_len = (*user_len & 7) as u8;
-                        *this = State::Trailer(TrailerReader::new(reader, user_len));
+                        *this = State::ReadTrailer(read_trailer(reader, user_len));
                         continue;
                     } else {
                         reader.as_mut().unwrap()
@@ -188,8 +199,20 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
                     }
                     .into();
                 }
-                State::Trailer(reader) => {
-                    return Pin::new(reader).poll_read(cx, buf);
+                State::ReadTrailer(fut) => {
+                    *this = State::ReleaseTrailer {
+                        consumed: 0,
+                        data: ready!(Pin::new(fut).poll(cx))?,
+                    };
+                }
+                State::ReleaseTrailer { consumed, data } => {
+                    let data = &data[*consumed as usize..];
+                    let data = &data[..usize::min(data.len(), buf.remaining())];
+
+                    buf.put_slice(data);
+                    *consumed += data.len() as u8;
+
+                    return Ok(()).into();
                 }
             }
         }
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
index 61a77678080a..9b8bcaa2de4a 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
@@ -53,7 +53,7 @@ impl Tag for Pad {
 }
 
 #[derive(Debug)]
-pub(crate) struct ReadTrailer<R, T: Tag> {
+pub(crate) struct ReadTrailer<R, T: Tag = Pad> {
     reader: R,
     data_len: u8,
     filled: u8,
@@ -90,7 +90,7 @@ impl<R, T: Tag> ReadTrailer<R, T> {
 impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
     type Output = io::Result<Trailer>;
 
-    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
+    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
         let this = &mut *self;
 
         loop {
@@ -136,72 +136,9 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
     }
 }
 
-#[derive(Debug)]
-pub(crate) enum TrailerReader<R> {
-    Reading(ReadTrailer<R, Pad>),
-    Releasing { off: u8, data: Trailer },
-    Done,
-}
-
-impl<R: AsyncRead + Unpin> TrailerReader<R> {
-    pub fn new(reader: R, data_len: u8) -> Self {
-        Self::Reading(read_trailer(reader, data_len))
-    }
-
-    pub fn len(&self) -> u8 {
-        match self {
-            TrailerReader::Reading(fut) => fut.len(),
-            &TrailerReader::Releasing {
-                off,
-                data: Trailer { data_len, .. },
-            } => data_len - off,
-            TrailerReader::Done => 0,
-        }
-    }
-}
-
-impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> {
-    fn poll_read(
-        mut self: Pin<&mut Self>,
-        cx: &mut task::Context,
-        user_buf: &mut ReadBuf,
-    ) -> Poll<io::Result<()>> {
-        let this = &mut *self;
-
-        loop {
-            match this {
-                Self::Reading(fut) => {
-                    *this = Self::Releasing {
-                        off: 0,
-                        data: ready!(Pin::new(fut).poll(cx))?,
-                    };
-                }
-                Self::Releasing { off: 8, .. } => {
-                    *this = Self::Done;
-                }
-                Self::Releasing { off, data } => {
-                    assert_ne!(user_buf.remaining(), 0);
-
-                    let buf = &data[*off as usize..];
-                    let buf = &buf[..usize::min(buf.len(), user_buf.remaining())];
-
-                    user_buf.put_slice(buf);
-                    *off += buf.len() as u8;
-
-                    break;
-                }
-                Self::Done => break,
-            }
-        }
-
-        Ok(()).into()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::time::Duration;
-    use tokio::io::AsyncReadExt;
 
     use super::*;
 
@@ -213,11 +150,8 @@ mod tests {
             .read(&[0xef, 0x00])
             .build();
 
-        let mut reader = TrailerReader::new(reader, 2);
-
-        let mut buf = vec![];
         assert_eq!(
-            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
+            read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
             io::ErrorKind::UnexpectedEof
         );
     }
@@ -231,11 +165,8 @@ mod tests {
             .wait(Duration::ZERO)
             .build();
 
-        let mut reader = TrailerReader::new(reader, 2);
-
-        let mut buf = vec![];
         assert_eq!(
-            reader.read_to_end(&mut buf).await.unwrap_err().kind(),
+            read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
             io::ErrorKind::InvalidData
         );
     }
@@ -250,21 +181,17 @@ mod tests {
             .read(&[0x00, 0x00, 0x00, 0x00, 0x00])
             .build();
 
-        let mut reader = TrailerReader::new(reader, 2);
-
-        let mut buf = vec![];
-        reader.read_to_end(&mut buf).await.unwrap();
-
-        assert_eq!(buf, &[0xed, 0xef]);
+        assert_eq!(
+            &*read_trailer::<_, Pad>(reader, 2).await.unwrap(),
+            &[0xed, 0xef]
+        );
     }
 
     #[tokio::test]
     async fn no_padding() {
-        let reader = tokio_test::io::Builder::new().build();
-        let mut reader = TrailerReader::new(reader, 0);
-
-        let mut buf = vec![];
-        reader.read_to_end(&mut buf).await.unwrap();
-        assert!(buf.is_empty());
+        assert!(read_trailer::<_, Pad>(io::empty(), 0)
+            .await
+            .unwrap()
+            .is_empty());
     }
 }