diff options
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/framing/framed_read.rs | 189 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/framing/mod.rs | 2 |
2 files changed, 191 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/framing/framed_read.rs b/tvix/nix-compat/src/nix_daemon/framing/framed_read.rs new file mode 100644 index 000000000000..320ab64af049 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/framing/framed_read.rs @@ -0,0 +1,189 @@ +use std::{ + io::Result, + pin::Pin, + task::{ready, Poll}, +}; + +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, ReadBuf}; + +/// State machine for [`NixFramedReader`]. +/// +/// As the reader progresses it linearly cycles through the states. +#[derive(Debug)] +enum NixFramedReaderState { + /// The reader always starts in this state. + /// + /// Before the payload, the client first sends its size. + /// The size is a u64 which is 8 bytes long, while it's likely that we will receive + /// the whole u64 in one read, it's possible that it will arrive in smaller chunks. + /// So in this state we read up to 8 bytes and transition to + /// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero, + /// otherwise we reset filled to 0, and read the next size value. + ReadingSize { buf: [u8; 8], filled: usize }, + /// This is where we read the actual payload that is sent to us. + /// + /// Once we've read the expected number of bytes, we go back to the + /// [`NixFramedReaderState::ReadingSize`] state. + ReadingPayload { + /// Represents the remaining number of bytes we expect to read based on the value + /// read in the previous state. + remaining: u64, + }, +} + +pin_project! { + /// Implements Nix's Framed reader protocol for protocol versions >= 1.23. + /// + /// See serialization.md#framed and [`NixFramedReaderState`] for details. + pub struct NixFramedReader<R> { + #[pin] + reader: R, + state: NixFramedReaderState, + } +} + +impl<R> NixFramedReader<R> { + pub fn new(reader: R) -> Self { + Self { + reader, + state: NixFramedReaderState::ReadingSize { + buf: [0; 8], + filled: 0, + }, + } + } +} + +impl<R: AsyncRead> AsyncRead for NixFramedReader<R> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll<Result<()>> { + let mut this = self.as_mut().project(); + match this.state { + NixFramedReaderState::ReadingSize { buf, filled } => { + if *filled < buf.len() { + let mut size_buf = ReadBuf::new(buf); + size_buf.advance(*filled); + + ready!(this.reader.poll_read(cx, &mut size_buf))?; + let bytes_read = size_buf.filled().len() - *filled; + if bytes_read == 0 { + // oef + return Poll::Ready(Ok(())); + } + *filled += bytes_read; + // Schedule ourselves to run again. + return self.poll_read(cx, read_buf); + } + let size = u64::from_le_bytes(*buf); + if size == 0 { + // eof + *filled = 0; + return Poll::Ready(Ok(())); + } + *this.state = NixFramedReaderState::ReadingPayload { remaining: size }; + self.poll_read(cx, read_buf) + } + NixFramedReaderState::ReadingPayload { remaining } => { + // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms. + let safe_remaining = if *remaining <= usize::MAX as u64 { + *remaining as usize + } else { + usize::MAX + }; + if safe_remaining > 0 { + // The buffer is no larger than the amount of data that we expect. + // Otherwise we will trim the buffer below and come back here. + if read_buf.remaining() <= safe_remaining { + let filled_before = read_buf.filled().len(); + + ready!(this.reader.as_mut().poll_read(cx, read_buf))?; + let bytes_read = read_buf.filled().len() - filled_before; + + *remaining -= bytes_read as u64; + if *remaining == 0 { + *this.state = NixFramedReaderState::ReadingSize { + buf: [0; 8], + filled: 0, + }; + } + return Poll::Ready(Ok(())); + } + // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes + // internal bookkeeping simpler. + let mut smaller_buf = read_buf.take(safe_remaining); + ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?; + + let bytes_read = smaller_buf.filled().len(); + + // SAFETY: we just read this number of bytes into read_buf's backing slice above. + unsafe { read_buf.assume_init(bytes_read) }; + read_buf.advance(bytes_read); + return Poll::Ready(Ok(())); + } + *this.state = NixFramedReaderState::ReadingSize { + buf: [0; 8], + filled: 0, + }; + self.poll_read(cx, read_buf) + } + } + } +} + +#[cfg(test)] +mod nix_framed_tests { + use std::time::Duration; + + use tokio::io::AsyncReadExt; + use tokio_test::io::Builder; + + use crate::nix_daemon::framing::NixFramedReader; + + #[tokio::test] + async fn read_hello_world_in_two_frames() { + let mut mock = Builder::new() + // The client sends len + .read(&5u64.to_le_bytes()) + // Immediately followed by the bytes + .read("hello".as_bytes()) + .wait(Duration::ZERO) + // Send more data separately + .read(&6u64.to_le_bytes()) + .read(" world".as_bytes()) + .build(); + + let mut reader = NixFramedReader::new(&mut mock); + let mut result = String::new(); + reader + .read_to_string(&mut result) + .await + .expect("Could not read into result"); + assert_eq!("hello world", result); + } + #[tokio::test] + async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() { + let mut mock = Builder::new() + // The client sends len + .read(&5u64.to_le_bytes()) + // Immediately followed by the bytes + .read("hello".as_bytes()) + .wait(Duration::ZERO) + // Send more data separately + .read(&6u64.to_le_bytes()) + .read(" world".as_bytes()) + .read(&0u64.to_le_bytes()) + .build(); + + let mut reader = NixFramedReader::new(&mut mock); + let mut result = String::new(); + reader + .read_to_string(&mut result) + .await + .expect("Could not read into result"); + assert_eq!("hello world", result); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/framing/mod.rs b/tvix/nix-compat/src/nix_daemon/framing/mod.rs index d4e19c2bb7db..d78c6c05f7ef 100644 --- a/tvix/nix-compat/src/nix_daemon/framing/mod.rs +++ b/tvix/nix-compat/src/nix_daemon/framing/mod.rs @@ -1,2 +1,4 @@ +mod framed_read; +pub use framed_read::NixFramedReader; mod stderr_read; pub use stderr_read::StderrReadFramedReader; |