about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/reader/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader/mod.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/mod.rs35
1 files changed, 29 insertions, 6 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
index 9f7013a4e9b1..5f6081b40451 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
@@ -1,4 +1,5 @@
 use std::{
+    future::Future,
     io,
     ops::{Bound, RangeBounds},
     pin::Pin,
@@ -6,7 +7,7 @@ use std::{
 };
 use tokio::io::{AsyncRead, ReadBuf};
 
-use trailer::TrailerReader;
+use trailer::{read_trailer, ReadTrailer, Trailer};
 mod trailer;
 
 /// Reads a "bytes wire packet" from the underlying reader.
@@ -33,6 +34,7 @@ pub struct BytesReader<R> {
 
 #[derive(Debug)]
 enum State<R> {
+    /// The data size is being read.
     Size {
         reader: Option<R>,
         /// Minimum length (inclusive)
@@ -42,12 +44,18 @@ enum State<R> {
         filled: u8,
         buf: [u8; 8],
     },
+    /// Full 8-byte blocks are being read and released to the caller.
     Body {
         reader: Option<R>,
         consumed: u64,
+        /// The total length of all user data contained in both the body and trailer.
         user_len: u64,
     },
-    Trailer(TrailerReader<R>),
+    /// The trailer is in the process of being read.
+    ReadTrailer(ReadTrailer<R>),
+    /// The trailer has been fully read and validated,
+    /// and data can now be released to the caller.
+    ReleaseTrailer { consumed: u8, data: Trailer },
 }
 
 impl<R> BytesReader<R>
@@ -100,7 +108,10 @@ where
             State::Body {
                 consumed, user_len, ..
             } => Some(user_len - consumed),
-            State::Trailer(ref r) => Some(r.len() as u64),
+            State::ReadTrailer(ref fut) => Some(fut.len() as u64),
+            State::ReleaseTrailer { consumed, ref data } => {
+                Some(data.len() as u64 - consumed as u64)
+            }
         }
     }
 }
@@ -166,7 +177,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
                     let reader = if remaining == 0 {
                         let reader = reader.take().unwrap();
                         let user_len = (*user_len & 7) as u8;
-                        *this = State::Trailer(TrailerReader::new(reader, user_len));
+                        *this = State::ReadTrailer(read_trailer(reader, user_len));
                         continue;
                     } else {
                         reader.as_mut().unwrap()
@@ -188,8 +199,20 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
                     }
                     .into();
                 }
-                State::Trailer(reader) => {
-                    return Pin::new(reader).poll_read(cx, buf);
+                State::ReadTrailer(fut) => {
+                    *this = State::ReleaseTrailer {
+                        consumed: 0,
+                        data: ready!(Pin::new(fut).poll(cx))?,
+                    };
+                }
+                State::ReleaseTrailer { consumed, data } => {
+                    let data = &data[*consumed as usize..];
+                    let data = &data[..usize::min(data.len(), buf.remaining())];
+
+                    buf.put_slice(data);
+                    *consumed += data.len() as u8;
+
+                    return Ok(()).into();
                 }
             }
         }