diff options
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon/framing/stderr_read.rs')
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/framing/stderr_read.rs | 393 |
1 files changed, 393 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/framing/stderr_read.rs b/tvix/nix-compat/src/nix_daemon/framing/stderr_read.rs new file mode 100644 index 000000000000..0074df901a2a --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/framing/stderr_read.rs @@ -0,0 +1,393 @@ +use std::{ + io::Result, + pin::Pin, + task::{ready, Poll}, +}; + +use bytes::{BufMut, BytesMut}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::worker_protocol::STDERR_READ; + +#[derive(Debug)] +struct U64WriteState { + bytes: [u8; 8], + written: usize, +} + +impl U64WriteState { + fn remaining(&self) -> &[u8] { + &self.bytes[self.written..] + } +} + +/// State machine for [`StderrReadFramedReader`]. +/// +/// As the reader progresses it linearly cycles through the states. +#[derive(Debug)] +enum StderrReaderState { + /// Represents the state indicating that we are about to request a new frame. + /// + /// When poll_read is called, it writes STDERR_READ into the writer and + /// progresses to the [`StderrReaderState::RequestingFrameLen`] state + /// + /// The reader always starts in this state and is reached after every frame has + /// been fully read. + RequestingNextFrame { write_state: U64WriteState }, + /// At this point the reader writes the desired payload length we want to receive + /// based on read_buf.remaining(). + RequestingFrameLen { + // We need to write 8 bytes of the length u64 value, + // this variable stores how many we've written so far. + write_state: U64WriteState, + }, + /// At this point the reader just flushes the writer and gets ready to receive + /// the actual payload size that is about to be sent to us by transitioning to + /// the [`StderrReaderState::ReadingSize`] state. + FrameLenRequested, + /// The size is a u64 which is 8 bytes long, while it's likely that we will receive + /// the whole u64 in one read, it's possible that it will arrive in smaller chunks. + /// So in this state we read up to 8 bytes and transition to + /// [`StderrReaderState::ReadingPayload`] when done. + ReadingSize { buf: [u8; 8], filled: usize }, + /// This is where we read the actual payload that is sent to us. + /// All of the previous states were just internal bookkeeping where we did not return + /// any data to the conumer, and only returned Poll::Pending. + /// + /// Having read the full payload, progresses to the [`StderrReaderState::RequestingNextFrame`] + /// state to read the next frame when/if requested. + ReadingPayload { + /// Represents the remaining number of bytes we expect to read based on the value + /// read in the previous state. + remaining: u64, + /// Represents the remaining of padding we expect to read before switching back + /// to the RequestingNextFrame state. + pad: usize, + /// In an ideal case this reader does not allocate, but in the scenario where + /// we've read the whol payload frame but still have padding remaining, it's not + /// safe to return the payload to the consumer as there is risk that the reader + /// won't be called again, leaving dangling padding. In this case we store the + /// payload in this buffer until we've read the padding, and then return the data + /// from here. + tmp_buf: BytesMut, + }, +} + +impl StderrReaderState { + fn request_next_frame() -> Self { + Self::RequestingNextFrame { + write_state: U64WriteState { + bytes: STDERR_READ.to_le_bytes(), + written: 0, + }, + } + } + + fn read_written(len: u64) -> Self { + Self::RequestingFrameLen { + write_state: U64WriteState { + bytes: len.to_le_bytes(), + written: 0, + }, + } + } +} + +pin_project! { + /// Implements the reader protocol for STDERR_READ in nix protocol version 1.21..1.23. + /// + /// See logging.md#stderr_read and [`StderrReaderState`] for details. + /// + /// FUTUREWORK: As per the nix protocol, it should be possible to send logging messages + /// concurrently with reads, however this reader currently monopolizes the writer until eof is + /// reached or the writer is dropped. It's important we don't allow certain interleavings of + /// log writes, i.e. it's not ok to issue a log message right after we've requested + /// STDERR_READ, but before requesting the length. + pub struct StderrReadFramedReader<R, W> { + #[pin] + reader: R, + #[pin] + writer: W, + state: StderrReaderState + } +} + +impl<R, W> StderrReadFramedReader<R, W> { + pub fn new(reader: R, writer: W) -> Self { + Self { + reader, + writer, + state: StderrReaderState::request_next_frame(), + } + } +} + +impl<R: AsyncRead, W: AsyncWrite> AsyncRead for StderrReadFramedReader<R, W> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll<Result<()>> { + loop { + let mut this = self.as_mut().project(); + match this.state { + StderrReaderState::RequestingNextFrame { write_state } => { + write_state.written += + ready!(this.writer.poll_write(cx, write_state.remaining()))?; + if write_state.written == 8 { + *this.state = StderrReaderState::read_written(read_buf.remaining() as u64); + } + } + StderrReaderState::RequestingFrameLen { write_state } => { + write_state.written += + ready!(this.writer.poll_write(cx, write_state.remaining()))?; + if write_state.written == 8 { + *this.state = StderrReaderState::FrameLenRequested; + } + } + StderrReaderState::FrameLenRequested => { + ready!(this.writer.poll_flush(cx))?; + *this.state = StderrReaderState::ReadingSize { + buf: [0u8; 8], + filled: 0, + }; + } + StderrReaderState::ReadingSize { buf, filled } => { + if *filled < buf.len() { + let mut size_buf = ReadBuf::new(buf); + size_buf.advance(*filled); + + ready!(this.reader.poll_read(cx, &mut size_buf))?; + let bytes_read = size_buf.filled().len() - *filled; + if bytes_read == 0 { + // oef + return Poll::Ready(Ok(())); + } + *filled += bytes_read; + continue; + } + let size = u64::from_le_bytes(*buf); + if size == 0 { + // eof + *this.state = StderrReaderState::request_next_frame(); + return Poll::Ready(Ok(())); + } + let pad = (8 - (size % 8) as usize) % 8; + *this.state = StderrReaderState::ReadingPayload { + remaining: size, + pad, + tmp_buf: BytesMut::new(), + }; + } + StderrReaderState::ReadingPayload { + remaining, + pad, + tmp_buf, + } => { + // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms. + let safe_remaining = if *remaining <= (usize::MAX - *pad) as u64 { + *remaining as usize + *pad + } else { + usize::MAX + }; + if safe_remaining - *pad > 0 { + // The buffer is no larger than the amount of data that we expect. + // Otherwise we will trim the buffer below and come back here. + if read_buf.remaining() <= safe_remaining { + let filled_before = read_buf.filled().len(); + + ready!(this.reader.as_mut().poll_read(cx, read_buf))?; + let bytes_read = read_buf.filled().len() - filled_before; + let payload_size = std::cmp::min(bytes_read, safe_remaining - *pad); + + // we don't want to include padding bytes in the result, so we remove them from read_buf. + read_buf.set_filled(filled_before + payload_size); + + *remaining -= payload_size as u64; + if *remaining > 0 { + // We have more data to read so we just return ok, knowing that the consumer + // will read again. + return Poll::Ready(Ok(())); + } + + // If we don't have any remaining data to read, consume any padding we may have just read. + *pad -= bytes_read - payload_size; + if *pad != 0 { + // We haven't read all the padding yet, so we stash it away to return to the caller + // once we've read the remaining padding. + tmp_buf.clear(); + tmp_buf.put_slice(&read_buf.filled()[filled_before..payload_size]); + read_buf.set_filled(filled_before); + continue; + } + *this.state = StderrReaderState::request_next_frame(); + return Poll::Ready(Ok(())); + } + + // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes + // internal bookkeeping simpler. + let mut smaller_buf = read_buf.take(safe_remaining); + ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?; + + let bytes_read = smaller_buf.filled().len(); + + // SAFETY: we just read this number of bytes into read_buf's backing slice above. + unsafe { read_buf.assume_init(bytes_read) }; + read_buf.advance(bytes_read); + return Poll::Ready(Ok(())); + } else if *pad > 0 { + // if we've read the whole payload but there is still padding remaining, + // we read it into a stack allocated array + let mut pad_arr = [0u8; 7]; + let mut pad_buf = ReadBuf::new(&mut pad_arr); + pad_buf.advance(7 - *pad); + ready!(this.reader.poll_read(cx, &mut pad_buf))?; + *pad = pad_buf.remaining(); + if *pad != 0 { + continue; + } + } + // now it's finally time to hand out the read data to the caller and reset to the RequestingNextFrame state. + read_buf.put_slice(tmp_buf); + tmp_buf.clear(); + *this.state = StderrReaderState::request_next_frame(); + return Poll::Ready(Ok(())); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use hex_literal::hex; + use tokio::io::{split, AsyncReadExt, BufReader}; + use tokio_test::io::Builder; + + use crate::{nix_daemon::framing::StderrReadFramedReader, worker_protocol::STDERR_READ}; + + #[tokio::test] + async fn test_single_two_byte_read_with_desired_size_ten() { + let mock = Builder::new() + // The reader should first write STDERR_READ and requested number of bytes into the writer + .write(&STDERR_READ.to_le_bytes()) + .write(&10u64.to_le_bytes()) + .wait(Duration::ZERO) + // The client sent not 10 but 2 bytes + .read(&2u64.to_le_bytes()) + // Immediately followed by the bytes and padding + .read("hi".as_bytes()) + .read(&hex!("0000 0000 0000")) + .build(); + let (r, w) = split(mock); + let mut reader = StderrReadFramedReader::new(r, w); + + let mut result = [0u8; 2]; + let mut buf_reader = BufReader::with_capacity(10, &mut reader); + let n = buf_reader.read_exact(&mut result).await.unwrap(); + + assert_eq!(2, n); + assert_eq!("hi".as_bytes(), result); + } + + #[tokio::test] + async fn test_single_read_with_padding_delayed() { + let mock = Builder::new() + // The reader should first write STDERR_READ and requested number of bytes into the writer + .write(&STDERR_READ.to_le_bytes()) + .write(&10u64.to_le_bytes()) + // The client sent 9 bytes not 10. + .read(&9u64.to_le_bytes()) + // Immeditaly followed by bytes + .read(&hex!("0202 0104 ffff ffaa 00")) + // Followed by a delayed padding + .wait(Duration::ZERO) + .read(&hex!("0000 0000 0000 00")) + .build(); + let (r, w) = split(mock); + let mut reader = StderrReadFramedReader::new(r, w); + + let mut result = [0u8; 9]; + let mut buf_reader = BufReader::with_capacity(10, &mut reader); + let n = buf_reader.read_exact(&mut result).await.unwrap(); + + assert_eq!(9, n); + assert_eq!(hex!("0202 0104 ffff ffaa 00"), result); + } + + #[tokio::test] + async fn test_multiple_consecutive_reads_with_arbitrary_delays() { + let mock = Builder::new() + // The reader should first write STDERR_READ and requested number of bytes into the writer + .write(&STDERR_READ.to_le_bytes()) + .write(&8192u64.to_le_bytes()) + .wait(Duration::ZERO) + // The client sends 6 bytes 'hello ' plus padding + .read(&6u64.to_le_bytes()) + .wait(Duration::ZERO) + .read("hello ".as_bytes()) + .read(&hex!("0000")) + // The reader sends desired length again + .write(&STDERR_READ.to_le_bytes()) + .write(&8192u64.to_le_bytes()) + // The client sends 11 bytes 'racerunners' with 's' and padding delayed + .wait(Duration::ZERO) + .read(&11u64.to_le_bytes()) + .read("racerunner".as_bytes()) + .wait(Duration::ZERO) + .read("s".as_bytes()) + .read(&hex!("0000 0000")) + .wait(Duration::ZERO) + .read(&hex!("00")) + .write(&STDERR_READ.to_le_bytes()) + .write(&8192u64.to_le_bytes()) + .wait(Duration::ZERO) + .read(&0u64.to_le_bytes()) + .build(); + let (r, w) = split(mock); + let mut reader = StderrReadFramedReader::new(r, w); + + let mut res = String::new(); + let mut buf_reader = BufReader::with_capacity(8192, &mut reader); + let n = buf_reader.read_to_string(&mut res).await.unwrap(); + + assert_eq!(17, n); + assert_eq!("hello racerunners", &res); + } + + #[tokio::test] + async fn test_single_read_where_writing_stderr_and_desired_size_take_more_than_one_write() { + let stderr_bytes = STDERR_READ.to_le_bytes(); + let length_bytes = 10u64.to_le_bytes(); + let mock = Builder::new() + .write(&stderr_bytes[..4]) + .wait(Duration::ZERO) + .write(&stderr_bytes[4..]) + .wait(Duration::ZERO) + .write(&length_bytes[..4]) + .wait(Duration::ZERO) + .write(&length_bytes[4..]) + .wait(Duration::ZERO) + // The client sent not 10 but 2 bytes + .read(&2u64.to_le_bytes()) + // Immediately followed by the bytes and padding + .read("hi".as_bytes()) + .read(&hex!("0000 0000 0000")) + .build(); + let (r, w) = split(mock); + let mut reader = StderrReadFramedReader::new(r, w); + + let mut result = [0u8; 2]; + let mut buf_reader = BufReader::with_capacity(10, &mut reader); + let n = buf_reader.read_exact(&mut result).await.unwrap(); + + assert_eq!(2, n); + assert_eq!("hi".as_bytes(), result); + } + + #[tokio::test] + async fn hello() {} +} |