about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader.rs120
1 files changed, 57 insertions, 63 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader.rs b/tvix/nix-compat/src/wire/bytes/reader.rs
index 9e8bd8704bbe..4239b40fec2d 100644
--- a/tvix/nix-compat/src/wire/bytes/reader.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader.rs
@@ -1,45 +1,39 @@
-use pin_project_lite::pin_project;
 use std::{
+    io,
     ops::{Bound, RangeBounds, RangeInclusive},
-    task::{ready, Poll},
+    pin::Pin,
+    task::{self, ready, Poll},
 };
 use tokio::io::AsyncRead;
 
 use super::{padding_len, BytesPacketPosition, LEN_SIZE};
 
-pin_project! {
-    /// Reads a "bytes wire packet" from the underlying reader.
-    /// The format is the same as in [crate::wire::bytes::read_bytes],
-    /// however this structure provides a [AsyncRead] interface,
-    /// allowing to not having to pass around the entire payload in memory.
-    ///
-    /// After being constructed with the underlying reader and an allowed size,
-    /// subsequent requests to poll_read will return payload data until the end
-    /// of the packet is reached.
-    ///
-    /// Internally, it will first read over the size packet, filling payload_size,
-    /// ensuring it fits allowed_size, then return payload data.
-    /// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
-    /// when all padding has been successfully consumed too.
-    ///
-    /// This also means, it's important for a user to always read to the end,
-    /// and not just call read_exact - otherwise it might not skip over the
-    /// padding, and return garbage when reading the next packet.
-    ///
-    /// In case of an error due to size constraints, or in case of not reading
-    /// all the way to the end (and getting a EOF), the underlying reader is no
-    /// longer usable and might return garbage.
-    pub struct BytesReader<R>
-    where
-    R: AsyncRead
-    {
-        #[pin]
-        inner: R,
-
-        allowed_size: RangeInclusive<u64>,
-        payload_size: [u8; 8],
-        state: BytesPacketPosition,
-    }
+/// Reads a "bytes wire packet" from the underlying reader.
+/// The format is the same as in [crate::wire::bytes::read_bytes],
+/// however this structure provides a [AsyncRead] interface,
+/// allowing to not having to pass around the entire payload in memory.
+///
+/// After being constructed with the underlying reader and an allowed size,
+/// subsequent requests to poll_read will return payload data until the end
+/// of the packet is reached.
+///
+/// Internally, it will first read over the size packet, filling payload_size,
+/// ensuring it fits allowed_size, then return payload data.
+/// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
+/// when all padding has been successfully consumed too.
+///
+/// This also means, it's important for a user to always read to the end,
+/// and not just call read_exact - otherwise it might not skip over the
+/// padding, and return garbage when reading the next packet.
+///
+/// In case of an error due to size constraints, or in case of not reading
+/// all the way to the end (and getting a EOF), the underlying reader is no
+/// longer usable and might return garbage.
+pub struct BytesReader<R> {
+    inner: R,
+    allowed_size: RangeInclusive<u64>,
+    payload_size: [u8; 8],
+    state: BytesPacketPosition,
 }
 
 impl<R> BytesReader<R>
@@ -70,10 +64,10 @@ where
 }
 /// Returns an error if the passed usize is 0.
 #[inline]
-fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error> {
+fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, io::Error> {
     if bytes_read == 0 {
-        Err(std::io::Error::new(
-            std::io::ErrorKind::UnexpectedEof,
+        Err(io::Error::new(
+            io::ErrorKind::UnexpectedEof,
             "underlying reader returned EOF",
         ))
     } else {
@@ -83,60 +77,60 @@ fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error>
 
 impl<R> AsyncRead for BytesReader<R>
 where
-    R: AsyncRead,
+    R: AsyncRead + Unpin,
 {
     fn poll_read(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-        buf: &mut tokio::io::ReadBuf<'_>,
-    ) -> Poll<std::io::Result<()>> {
-        let mut this = self.project();
+        self: Pin<&mut Self>,
+        cx: &mut task::Context,
+        buf: &mut tokio::io::ReadBuf,
+    ) -> Poll<io::Result<()>> {
+        let this = self.get_mut();
 
         // Use a loop, so we can deal with (multiple) state transitions.
         loop {
-            match *this.state {
+            match this.state {
                 BytesPacketPosition::Size(LEN_SIZE) => {
                     // used in case an invalid size was signalled.
-                    Err(std::io::Error::new(
-                        std::io::ErrorKind::InvalidData,
+                    Err(io::Error::new(
+                        io::ErrorKind::InvalidData,
                         "signalled package size not in allowed range",
                     ))?
                 }
                 BytesPacketPosition::Size(pos) => {
                     // try to read more of the size field.
                     // We wrap a ReadBuf around this.payload_size here, and set_filled.
-                    let mut read_buf = tokio::io::ReadBuf::new(this.payload_size);
+                    let mut read_buf = tokio::io::ReadBuf::new(&mut this.payload_size);
                     read_buf.advance(pos);
-                    ready!(this.inner.as_mut().poll_read(cx, &mut read_buf))?;
+                    ready!(Pin::new(&mut this.inner).poll_read(cx, &mut read_buf))?;
 
                     ensure_nonzero_bytes_read(read_buf.filled().len() - pos)?;
 
                     let total_size_read = read_buf.filled().len();
                     if total_size_read == LEN_SIZE {
                         // If the entire payload size was read, parse it
-                        let payload_size = u64::from_le_bytes(*this.payload_size);
+                        let payload_size = u64::from_le_bytes(this.payload_size);
 
                         if !this.allowed_size.contains(&payload_size) {
                             // If it's not in the allowed
                             // range, transition to failure mode
                             // `BytesPacketPosition::Size(LEN_SIZE)`, where only
                             // an error is returned.
-                            *this.state = BytesPacketPosition::Size(LEN_SIZE)
+                            this.state = BytesPacketPosition::Size(LEN_SIZE)
                         } else if payload_size == 0 {
                             // If the payload size is 0, move on to reading padding directly.
-                            *this.state = BytesPacketPosition::Padding(0)
+                            this.state = BytesPacketPosition::Padding(0)
                         } else {
                             // Else, transition to reading the payload.
-                            *this.state = BytesPacketPosition::Payload(0)
+                            this.state = BytesPacketPosition::Payload(0)
                         }
                     } else {
                         // If we still need to read more of payload size, update
                         // our position in the state.
-                        *this.state = BytesPacketPosition::Size(total_size_read)
+                        this.state = BytesPacketPosition::Size(total_size_read)
                     }
                 }
                 BytesPacketPosition::Payload(pos) => {
-                    let signalled_size = u64::from_le_bytes(*this.payload_size);
+                    let signalled_size = u64::from_le_bytes(this.payload_size);
                     // We don't enter this match arm at all if we're expecting empty payload
                     debug_assert!(signalled_size > 0, "signalled size must be larger than 0");
 
@@ -147,7 +141,7 @@ where
                         // Reducing these two u64 to usize on 32bits is fine - we
                         // only care about not reading too much, not too less.
                         let mut limited_buf = buf.take((signalled_size - pos) as usize);
-                        ready!(this.inner.as_mut().poll_read(cx, &mut limited_buf))?;
+                        ready!(Pin::new(&mut this.inner).poll_read(cx, &mut limited_buf))?;
                         limited_buf.filled().len()
                     })?;
 
@@ -158,11 +152,11 @@ where
                     if pos + bytes_read as u64 == signalled_size {
                         // If we now read all payload, transition to padding
                         // state.
-                        *this.state = BytesPacketPosition::Padding(0);
+                        this.state = BytesPacketPosition::Padding(0);
                     } else {
                         // if we didn't read everything yet, update our position
                         // in the state.
-                        *this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
+                        this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
                     }
 
                     // We return from poll_read here.
@@ -181,7 +175,7 @@ where
                     // bytes. Only return `Ready(Ok(()))` once we're past the
                     // padding (or in cases where polling the inner reader
                     // returns `Poll::Pending`).
-                    let signalled_size = u64::from_le_bytes(*this.payload_size);
+                    let signalled_size = u64::from_le_bytes(this.payload_size);
                     let total_padding_len = padding_len(signalled_size) as usize;
 
                     let padding_len_remaining = total_padding_len - pos;
@@ -192,15 +186,15 @@ where
                         let mut padding_buf = padding_buf.take(padding_len_remaining);
 
                         // read into padding_buf.
-                        ready!(this.inner.as_mut().poll_read(cx, &mut padding_buf))?;
+                        ready!(Pin::new(&mut this.inner).poll_read(cx, &mut padding_buf))?;
                         let bytes_read = ensure_nonzero_bytes_read(padding_buf.filled().len())?;
 
-                        *this.state = BytesPacketPosition::Padding(pos + bytes_read);
+                        this.state = BytesPacketPosition::Padding(pos + bytes_read);
 
                         // ensure the bytes are not null bytes
                         if !padding_buf.filled().iter().all(|e| *e == b'\0') {
-                            return Err(std::io::Error::new(
-                                std::io::ErrorKind::InvalidData,
+                            return Err(io::Error::new(
+                                io::ErrorKind::InvalidData,
                                 "padding is not all zeroes",
                             ))
                             .into();