about summary refs log tree commit diff
path: root/tvix/nix-compat/src/nix_daemon/framing/framed_read.rs
use std::{
    io::Result,
    pin::Pin,
    task::{ready, Poll},
};

use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, ReadBuf};

/// State machine driving [`NixFramedReader`].
///
/// The reader alternates between the two states: it first reads a frame's
/// size header, then that frame's payload, then the next size header, and so on.
#[derive(Debug)]
enum NixFramedReaderState {
    /// Initial state: accumulating the 8-byte size header of the next frame.
    ///
    /// Each frame is preceded by its payload length, encoded as a little-endian
    /// `u64`. Although the whole header usually arrives in a single read, it may
    /// be split across several, so the bytes received so far are buffered here.
    /// Once all 8 bytes are in, a non-zero size moves the reader to
    /// [`NixFramedReaderState::ReadingPayload`]; a zero size instead resets
    /// `filled` to 0 so that the next size header can be read.
    ReadingSize {
        /// Storage for the (possibly partial) size header.
        buf: [u8; 8],
        /// How many leading bytes of `buf` have been received so far.
        filled: usize,
    },
    /// Streaming the payload of the current frame to the caller.
    ///
    /// When the whole payload has been consumed, the reader returns to
    /// [`NixFramedReaderState::ReadingSize`].
    ReadingPayload {
        /// Payload bytes of the current frame that have not been read yet,
        /// as announced by the preceding size header.
        remaining: u64,
    },
}

pin_project! {
    /// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
    ///
    /// Wraps an underlying [`AsyncRead`] and yields only the payload bytes of
    /// the frames it carries, stripping the 8-byte size header in front of each
    /// frame. A zero-sized frame is reported to the caller as end-of-stream
    /// (a read that completes without producing any bytes).
    ///
    /// See serialization.md#framed and [`NixFramedReaderState`] for details.
    pub struct NixFramedReader<R> {
        // The wrapped byte stream carrying the framed data. Pinned because
        // `poll_read` is forwarded to it through a pinned projection.
        #[pin]
        reader: R,
        // Where we are in the size-header / payload cycle.
        state: NixFramedReaderState,
    }
}

impl<R> NixFramedReader<R> {
    /// Wraps `reader`, whose bytes are expected to be a sequence of Nix frames.
    ///
    /// The returned reader starts out expecting the size header of the first
    /// frame, with nothing buffered yet.
    pub fn new(reader: R) -> Self {
        let initial_state = NixFramedReaderState::ReadingSize {
            filled: 0,
            buf: [0u8; 8],
        };
        Self {
            state: initial_state,
            reader,
        }
    }
}

impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
    /// Reads de-framed payload bytes from the underlying reader into `read_buf`.
    ///
    /// At most the remainder of the current frame is read per call, so frame
    /// size headers never leak into the caller's buffer. Returning
    /// `Poll::Ready(Ok(()))` without adding bytes to a non-empty `read_buf`
    /// signals end-of-stream. That happens when the underlying reader reaches
    /// EOF cleanly on a frame boundary, or when a zero-sized frame is read.
    ///
    /// # Errors
    ///
    /// I/O errors from the underlying reader are propagated unchanged. If the
    /// underlying reader reaches EOF part-way through a size header or a
    /// frame's payload, an [`std::io::ErrorKind::UnexpectedEof`] error is
    /// returned rather than silently truncating the stream.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        read_buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let mut this = self.as_mut().project();
        match this.state {
            NixFramedReaderState::ReadingSize { buf, filled } => {
                if *filled < buf.len() {
                    // Resume filling the size header right after the bytes we already have.
                    let mut size_buf = ReadBuf::new(buf);
                    size_buf.advance(*filled);

                    ready!(this.reader.poll_read(cx, &mut size_buf))?;
                    let bytes_read = size_buf.filled().len() - *filled;
                    if bytes_read == 0 {
                        if *filled == 0 {
                            // Clean EOF between frames.
                            return Poll::Ready(Ok(()));
                        }
                        // EOF after only part of the size header arrived: the stream
                        // was truncated, so don't report it as a normal end-of-stream.
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "unexpected EOF while reading frame size",
                        )));
                    }
                    *filled += bytes_read;
                    // Run again: either keep filling the header or act on the complete one.
                    return self.poll_read(cx, read_buf);
                }
                let size = u64::from_le_bytes(*buf);
                if size == 0 {
                    // A zero-sized frame marks the end of the framed stream. Report EOF to
                    // the caller and get ready to read a fresh size header next time.
                    *filled = 0;
                    return Poll::Ready(Ok(()));
                }
                *this.state = NixFramedReaderState::ReadingPayload { remaining: size };
                self.poll_read(cx, read_buf)
            }
            NixFramedReaderState::ReadingPayload { remaining } => {
                // Never try to read more than fits in a usize, which is only 4 bytes on
                // 32-bit platforms.
                let safe_remaining = if *remaining <= usize::MAX as u64 {
                    *remaining as usize
                } else {
                    usize::MAX
                };
                if safe_remaining > 0 {
                    if read_buf.remaining() <= safe_remaining {
                        // The caller's buffer is no larger than what is left of this frame,
                        // so reading straight into it cannot overrun into the next header.
                        let filled_before = read_buf.filled().len();
                        // A zero-length read into an empty buffer is not EOF.
                        let had_space = read_buf.remaining() > 0;

                        ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
                        let bytes_read = read_buf.filled().len() - filled_before;
                        if bytes_read == 0 && had_space {
                            // The underlying reader hit EOF in the middle of a frame.
                            return Poll::Ready(Err(std::io::Error::new(
                                std::io::ErrorKind::UnexpectedEof,
                                "unexpected EOF while reading frame payload",
                            )));
                        }

                        *remaining -= bytes_read as u64;
                        if *remaining == 0 {
                            *this.state = NixFramedReaderState::ReadingSize {
                                buf: [0; 8],
                                filled: 0,
                            };
                        }
                        return Poll::Ready(Ok(()));
                    }
                    // The caller's buffer extends past the end of this frame. Limit the read
                    // to the frame's remaining bytes so the next size header is not handed
                    // out as payload; the recursive call takes the branch above.
                    let mut smaller_buf = read_buf.take(safe_remaining);
                    ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;

                    let bytes_read = smaller_buf.filled().len();

                    // SAFETY: `smaller_buf` is a view of `read_buf`'s unfilled region, and the
                    // read above just initialized its first `bytes_read` bytes.
                    unsafe { read_buf.assume_init(bytes_read) };
                    read_buf.advance(bytes_read);
                    return Poll::Ready(Ok(()));
                }
                // Defensive: an exhausted payload should already have switched us back to
                // reading a size header, but recover gracefully if it did not.
                *this.state = NixFramedReaderState::ReadingSize {
                    buf: [0; 8],
                    filled: 0,
                };
                self.poll_read(cx, read_buf)
            }
        }
    }
}

#[cfg(test)]
mod nix_framed_tests {
    use std::time::Duration;

    use tokio::io::AsyncReadExt;
    use tokio_test::io::{Builder, Mock};

    use crate::nix_daemon::framing::NixFramedReader;

    /// Drains a [`NixFramedReader`] wrapped around `mock` into a `String`.
    async fn read_all_to_string(mock: &mut Mock) -> String {
        let mut reader = NixFramedReader::new(mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        result
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .build();

        assert_eq!("hello world", read_all_to_string(&mut mock).await);
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .read(&0u64.to_le_bytes())
            .build();

        assert_eq!("hello world", read_all_to_string(&mut mock).await);
    }

    /// The 8-byte size header may arrive in several pieces.
    #[tokio::test]
    async fn read_size_header_split_across_reads() {
        let size = 5u64.to_le_bytes();
        let mut mock = Builder::new()
            .read(&size[..3])
            .read(&size[3..])
            .read("hel".as_bytes())
            .read("lo".as_bytes())
            .build();

        assert_eq!("hello", read_all_to_string(&mut mock).await);
    }

    /// Several frames delivered by a single underlying read must be split
    /// correctly, without leaking size headers into the payload.
    #[tokio::test]
    async fn read_multiple_frames_from_single_read() {
        let mut stream = Vec::new();
        stream.extend_from_slice(&5u64.to_le_bytes());
        stream.extend_from_slice(b"hello");
        stream.extend_from_slice(&6u64.to_le_bytes());
        stream.extend_from_slice(b" world");
        stream.extend_from_slice(&0u64.to_le_bytes());
        let mut mock = Builder::new().read(&stream).build();

        assert_eq!("hello world", read_all_to_string(&mut mock).await);
    }

    /// A caller buffer smaller than the frame receives the payload piecewise.
    #[tokio::test]
    async fn read_with_buffer_smaller_than_frame() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut out = Vec::new();
        let mut chunk = [0u8; 2];
        loop {
            let n = reader.read(&mut chunk).await.expect("Could not read chunk");
            if n == 0 {
                break;
            }
            out.extend_from_slice(&chunk[..n]);
        }
        assert_eq!(b"hello".to_vec(), out);
    }
}