about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/mod.rs202
1 files changed, 65 insertions, 137 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
index b46b0b53396b..ef59e9c160e4 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
@@ -1,12 +1,14 @@
 use std::{
     future::Future,
     io,
-    ops::{Bound, RangeBounds, RangeInclusive},
+    ops::RangeBounds,
     pin::Pin,
     task::{self, ready, Poll},
 };
 use tokio::io::{AsyncRead, ReadBuf};
 
+use crate::wire::read_u64;
+
 use trailer::{read_trailer, ReadTrailer, Trailer};
 mod trailer;
 
@@ -15,32 +17,21 @@ mod trailer;
 /// however this structure provides a [AsyncRead] interface,
 /// allowing to not having to pass around the entire payload in memory.
 ///
-/// After being constructed with the underlying reader and an allowed size,
-/// subsequent requests to poll_read will return payload data until the end
-/// of the packet is reached.
-///
-/// Internally, it will first read over the size packet, filling payload_size,
-/// ensuring it fits allowed_size, then return payload data.
+/// It is constructed by reading a size with [BytesReader::new],
+/// and yields payload data until the end of the packet is reached.
 ///
 /// It will not return the final bytes before all padding has been successfully
 /// consumed as well, but the full length of the reader must be consumed.
 ///
-/// In case of an error due to size constraints, or in case of not reading
-/// all the way to the end (and getting a EOF), the underlying reader is no
-/// longer usable and might return garbage.
+/// If the data is not read all the way to the end, or an error is encountered,
+/// the underlying reader is no longer usable and might return garbage.
+#[derive(Debug)]
 pub struct BytesReader<R> {
     state: State<R>,
 }
 
 #[derive(Debug)]
 enum State<R> {
-    /// The data size is being read.
-    Size {
-        reader: Option<R>,
-        allowed_size: RangeInclusive<u64>,
-        filled: u8,
-        buf: [u8; 8],
-    },
     /// Full 8-byte blocks are being read and released to the caller.
     Body {
         reader: Option<R>,
@@ -60,52 +51,37 @@ where
     R: AsyncRead + Unpin,
 {
     /// Constructs a new BytesReader, using the underlying passed reader.
-    pub fn new<S: RangeBounds<u64>>(reader: R, allowed_size: S) -> Self {
-        let allowed_size = match allowed_size.start_bound() {
-            Bound::Included(&n) => n,
-            Bound::Excluded(&n) => n.saturating_add(1),
-            Bound::Unbounded => 0,
-        }..=match allowed_size.end_bound() {
-            Bound::Included(&n) => n,
-            Bound::Excluded(&n) => n.checked_sub(1).unwrap(),
-            Bound::Unbounded => u64::MAX,
-        };
-
-        Self {
-            state: State::Size {
-                reader: Some(reader),
-                allowed_size,
-                filled: 0,
-                buf: [0; 8],
-            },
+    pub async fn new<S: RangeBounds<u64>>(mut reader: R, allowed_size: S) -> io::Result<Self> {
+        let size = read_u64(&mut reader).await?;
+
+        if !allowed_size.contains(&size) {
+            return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size"));
         }
-    }
 
-    /// Construct a new BytesReader with a known, and already-read size.
-    pub fn with_size(reader: R, size: u64) -> Self {
-        Self {
+        Ok(Self {
             state: State::Body {
                 reader: Some(reader),
                 consumed: 0,
                 user_len: size,
             },
-        }
+        })
+    }
+
+    /// Returns whether there is any remaining data to be read.
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
     }
 
     /// Remaining data length, ie not including data already read.
     ///
     /// If the size has not been read yet, this is [None].
-    #[allow(clippy::len_without_is_empty)] // if size is unknown, we can't answer that
-    pub fn len(&self) -> Option<u64> {
+    pub fn len(&self) -> u64 {
         match self.state {
-            State::Size { .. } => None,
             State::Body {
                 consumed, user_len, ..
-            } => Some(user_len - consumed),
-            State::ReadTrailer(ref fut) => Some(fut.len() as u64),
-            State::ReleaseTrailer { consumed, ref data } => {
-                Some(data.len() as u64 - consumed as u64)
-            }
+            } => user_len - consumed,
+            State::ReadTrailer(ref fut) => fut.len() as u64,
+            State::ReleaseTrailer { consumed, ref data } => data.len() as u64 - consumed as u64,
         }
     }
 }
@@ -120,45 +96,6 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
 
         loop {
             match this {
-                State::Size {
-                    reader,
-                    allowed_size,
-                    filled: 8,
-                    buf,
-                } => {
-                    let reader = reader.take().unwrap();
-
-                    let data_len = u64::from_le_bytes(*buf);
-                    if !allowed_size.contains(&data_len) {
-                        return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size"))
-                            .into();
-                    }
-
-                    *this = State::Body {
-                        reader: Some(reader),
-                        consumed: 0,
-                        user_len: data_len,
-                    };
-                }
-                State::Size {
-                    reader,
-                    filled,
-                    buf,
-                    ..
-                } => {
-                    let reader = reader.as_mut().unwrap();
-
-                    let mut read_buf = ReadBuf::new(&mut buf[..]);
-                    read_buf.advance(*filled as usize);
-                    ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?;
-
-                    let new_filled = read_buf.filled().len() as u8;
-                    if *filled == new_filled {
-                        return Err(io::ErrorKind::UnexpectedEof.into()).into();
-                    }
-
-                    *filled = new_filled;
-                }
                 State::Body {
                     reader,
                     consumed,
@@ -245,7 +182,7 @@ mod tests {
     use lazy_static::lazy_static;
     use rstest::rstest;
     use tokio::io::AsyncReadExt;
-    use tokio_test::{assert_err, io::Builder};
+    use tokio_test::io::Builder;
 
     use super::*;
 
@@ -279,34 +216,9 @@ mod tests {
             .read(&produce_packet_bytes(payload).await)
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64);
-        let mut buf = Vec::new();
-        r.read_to_end(&mut buf).await.expect("must succeed");
-
-        assert_eq!(payload, &buf[..]);
-    }
-
-    /// Read bytes packets of various length, and ensure read_to_end returns the
-    /// expected payload.
-    #[rstest]
-    #[case::empty(&[])] // empty bytes packet
-    #[case::size_1b(&[0xff])] // 1 bytes payload
-    #[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding)
-    #[case::size_9b(&hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding)
-    #[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet
-    #[tokio::test]
-    async fn read_payload_correct_known(#[case] payload: &[u8]) {
-        let packet = produce_packet_bytes(payload).await;
-
-        let size = u64::from_le_bytes({
-            let mut buf = [0; 8];
-            buf.copy_from_slice(&packet[..8]);
-            buf
-        });
-
-        let mut mock = Builder::new().read(&packet[8..]).build();
-
-        let mut r = BytesReader::with_size(&mut mock, size);
+        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64)
+            .await
+            .unwrap();
         let mut buf = Vec::new();
         r.read_to_end(&mut buf).await.expect("must succeed");
 
@@ -321,9 +233,13 @@ mod tests {
             .read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..2048);
-        let mut buf = Vec::new();
-        assert_err!(r.read_to_end(&mut buf).await);
+        assert_eq!(
+            BytesReader::new(&mut mock, ..2048)
+                .await
+                .unwrap_err()
+                .kind(),
+            io::ErrorKind::InvalidData
+        );
     }
 
     /// Fail if the bytes packet is smaller than allowed
@@ -334,9 +250,13 @@ mod tests {
             .read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet
             .build();
 
-        let mut r = BytesReader::new(&mut mock, 1024..2048);
-        let mut buf = Vec::new();
-        assert_err!(r.read_to_end(&mut buf).await);
+        assert_eq!(
+            BytesReader::new(&mut mock, 1024..2048)
+                .await
+                .unwrap_err()
+                .kind(),
+            io::ErrorKind::InvalidData
+        );
     }
 
     /// Fail if the padding is not all zeroes
@@ -348,7 +268,7 @@ mod tests {
         packet_bytes[12] = 0xff;
         let mut mock = Builder::new().read(&packet_bytes).build(); // We stop reading after the faulty bit
 
-        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
         let mut buf = Vec::new();
 
         r.read_to_end(&mut buf).await.expect_err("must fail");
@@ -365,15 +285,13 @@ mod tests {
             .read(&produce_packet_bytes(payload).await[..4])
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
-        let mut buf = [0u8; 1];
-
         assert_eq!(
-            r.read_exact(&mut buf).await.expect_err("must fail").kind(),
-            std::io::ErrorKind::UnexpectedEof
+            BytesReader::new(&mut mock, ..MAX_LEN)
+                .await
+                .expect_err("must fail")
+                .kind(),
+            io::ErrorKind::UnexpectedEof
         );
-
-        assert_eq!(&[0], &buf, "buffer should stay empty");
     }
 
     /// Start a 9 bytes payload packet, but have the underlying reader return
@@ -387,7 +305,7 @@ mod tests {
             .read(&produce_packet_bytes(payload).await[..8 + 4])
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
         let mut buf = [0; 9];
 
         r.read_exact(&mut buf[..4]).await.expect("must succeed");
@@ -414,7 +332,7 @@ mod tests {
             .read(&produce_packet_bytes(payload).await[..offset])
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
 
         // read_exact of the payload *body* will succeed, but a subsequent read will
         // return UnexpectedEof error.
@@ -441,10 +359,18 @@ mod tests {
             .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
-        let mut buf = Vec::new();
+        // Either length reading or data reading can fail, depending on which test case we're in.
+        let err: io::Error = async {
+            let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await?;
+            let mut buf = Vec::new();
+
+            r.read_to_end(&mut buf).await?;
+
+            Ok(())
+        }
+        .await
+        .expect_err("must fail");
 
-        let err = r.read_to_end(&mut buf).await.expect_err("must fail");
         assert_eq!(
             err.kind(),
             std::io::ErrorKind::Other,
@@ -468,7 +394,7 @@ mod tests {
             .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
         let mut buf = Vec::new();
 
         r.read_to_end(&mut buf).await.expect("must succeed");
@@ -492,7 +418,9 @@ mod tests {
             .read(&produce_packet_bytes(payload).await[offset..])
             .build();
 
-        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64);
+        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64)
+            .await
+            .unwrap();
         let mut buf = Vec::new();
         r.read_to_end(&mut buf).await.expect("must succeed");