about summary refs log tree commit diff
path: root/tvix/nix-compat
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat')
-rw-r--r--tvix/nix-compat/src/nix_daemon/framing/framed_read.rs189
-rw-r--r--tvix/nix-compat/src/nix_daemon/framing/mod.rs2
2 files changed, 191 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/framing/framed_read.rs b/tvix/nix-compat/src/nix_daemon/framing/framed_read.rs
new file mode 100644
index 000000000000..320ab64af049
--- /dev/null
+++ b/tvix/nix-compat/src/nix_daemon/framing/framed_read.rs
@@ -0,0 +1,189 @@
+use std::{
+    io::Result,
+    pin::Pin,
+    task::{ready, Poll},
+};
+
+use pin_project_lite::pin_project;
+use tokio::io::{AsyncRead, ReadBuf};
+
+/// State machine for [`NixFramedReader`].
+///
+/// As the reader progresses it linearly cycles through the states.
+#[derive(Debug)]
+enum NixFramedReaderState {
+    /// The reader always starts in this state.
+    ///
+    /// Before the payload, the client first sends its size.
+    /// The size is a u64 which is 8 bytes long, while it's likely that we will receive
+    /// the whole u64 in one read, it's possible that it will arrive in smaller chunks.
+    /// So in this state we read up to 8 bytes and transition to
+    /// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero,
+    /// otherwise we reset filled to 0, and read the next size value.
+    ReadingSize { buf: [u8; 8], filled: usize },
+    /// This is where we read the actual payload that is sent to us.
+    ///
+    /// Once we've read the expected number of bytes, we go back to the
+    /// [`NixFramedReaderState::ReadingSize`] state.
+    ReadingPayload {
+        /// Represents the remaining number of bytes we expect to read based on the value
+        /// read in the previous state.
+        remaining: u64,
+    },
+}
+
+pin_project! {
+    /// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
+    ///
+    /// See serialization.md#framed and [`NixFramedReaderState`] for details.
+    pub struct NixFramedReader<R> {
+        #[pin]
+        reader: R,
+        state: NixFramedReaderState,
+    }
+}
+
+impl<R> NixFramedReader<R> {
+    pub fn new(reader: R) -> Self {
+        Self {
+            reader,
+            state: NixFramedReaderState::ReadingSize {
+                buf: [0; 8],
+                filled: 0,
+            },
+        }
+    }
+}
+
+impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        read_buf: &mut ReadBuf<'_>,
+    ) -> Poll<Result<()>> {
+        let mut this = self.as_mut().project();
+        match this.state {
+            NixFramedReaderState::ReadingSize { buf, filled } => {
+                if *filled < buf.len() {
+                    let mut size_buf = ReadBuf::new(buf);
+                    size_buf.advance(*filled);
+
+                    ready!(this.reader.poll_read(cx, &mut size_buf))?;
+                    let bytes_read = size_buf.filled().len() - *filled;
+                    if bytes_read == 0 {
+                        // oef
+                        return Poll::Ready(Ok(()));
+                    }
+                    *filled += bytes_read;
+                    // Schedule ourselves to run again.
+                    return self.poll_read(cx, read_buf);
+                }
+                let size = u64::from_le_bytes(*buf);
+                if size == 0 {
+                    // eof
+                    *filled = 0;
+                    return Poll::Ready(Ok(()));
+                }
+                *this.state = NixFramedReaderState::ReadingPayload { remaining: size };
+                self.poll_read(cx, read_buf)
+            }
+            NixFramedReaderState::ReadingPayload { remaining } => {
+                // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
+                let safe_remaining = if *remaining <= usize::MAX as u64 {
+                    *remaining as usize
+                } else {
+                    usize::MAX
+                };
+                if safe_remaining > 0 {
+                    // The buffer is no larger than the amount of data that we expect.
+                    // Otherwise we will trim the buffer below and come back here.
+                    if read_buf.remaining() <= safe_remaining {
+                        let filled_before = read_buf.filled().len();
+
+                        ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
+                        let bytes_read = read_buf.filled().len() - filled_before;
+
+                        *remaining -= bytes_read as u64;
+                        if *remaining == 0 {
+                            *this.state = NixFramedReaderState::ReadingSize {
+                                buf: [0; 8],
+                                filled: 0,
+                            };
+                        }
+                        return Poll::Ready(Ok(()));
+                    }
+                    // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
+                    // internal bookkeeping simpler.
+                    let mut smaller_buf = read_buf.take(safe_remaining);
+                    ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
+
+                    let bytes_read = smaller_buf.filled().len();
+
+                    // SAFETY: we just read this number of bytes into read_buf's backing slice above.
+                    unsafe { read_buf.assume_init(bytes_read) };
+                    read_buf.advance(bytes_read);
+                    return Poll::Ready(Ok(()));
+                }
+                *this.state = NixFramedReaderState::ReadingSize {
+                    buf: [0; 8],
+                    filled: 0,
+                };
+                self.poll_read(cx, read_buf)
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod nix_framed_tests {
+    use std::time::Duration;
+
+    use tokio::io::AsyncReadExt;
+    use tokio_test::io::Builder;
+
+    use crate::nix_daemon::framing::NixFramedReader;
+
+    #[tokio::test]
+    async fn read_hello_world_in_two_frames() {
+        let mut mock = Builder::new()
+            // The client sends len
+            .read(&5u64.to_le_bytes())
+            // Immediately followed by the bytes
+            .read("hello".as_bytes())
+            .wait(Duration::ZERO)
+            // Send more data separately
+            .read(&6u64.to_le_bytes())
+            .read(" world".as_bytes())
+            .build();
+
+        let mut reader = NixFramedReader::new(&mut mock);
+        let mut result = String::new();
+        reader
+            .read_to_string(&mut result)
+            .await
+            .expect("Could not read into result");
+        assert_eq!("hello world", result);
+    }
+    #[tokio::test]
+    async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() {
+        let mut mock = Builder::new()
+            // The client sends len
+            .read(&5u64.to_le_bytes())
+            // Immediately followed by the bytes
+            .read("hello".as_bytes())
+            .wait(Duration::ZERO)
+            // Send more data separately
+            .read(&6u64.to_le_bytes())
+            .read(" world".as_bytes())
+            .read(&0u64.to_le_bytes())
+            .build();
+
+        let mut reader = NixFramedReader::new(&mut mock);
+        let mut result = String::new();
+        reader
+            .read_to_string(&mut result)
+            .await
+            .expect("Could not read into result");
+        assert_eq!("hello world", result);
+    }
+}
diff --git a/tvix/nix-compat/src/nix_daemon/framing/mod.rs b/tvix/nix-compat/src/nix_daemon/framing/mod.rs
index d4e19c2bb7db..d78c6c05f7ef 100644
--- a/tvix/nix-compat/src/nix_daemon/framing/mod.rs
+++ b/tvix/nix-compat/src/nix_daemon/framing/mod.rs
@@ -1,2 +1,4 @@
+mod framed_read;
+pub use framed_read::NixFramedReader;
 mod stderr_read;
 pub use stderr_read::StderrReadFramedReader;