use std::{ io::Result, pin::Pin, task::{ready, Poll}, }; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, ReadBuf}; /// State machine for [`NixFramedReader`]. /// /// As the reader progresses it linearly cycles through the states. #[derive(Debug)] enum NixFramedReaderState { /// The reader always starts in this state. /// /// Before the payload, the client first sends its size. /// The size is a u64 which is 8 bytes long, while it's likely that we will receive /// the whole u64 in one read, it's possible that it will arrive in smaller chunks. /// So in this state we read up to 8 bytes and transition to /// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero, /// otherwise we reset filled to 0, and read the next size value. ReadingSize { buf: [u8; 8], filled: usize }, /// This is where we read the actual payload that is sent to us. /// /// Once we've read the expected number of bytes, we go back to the /// [`NixFramedReaderState::ReadingSize`] state. ReadingPayload { /// Represents the remaining number of bytes we expect to read based on the value /// read in the previous state. remaining: u64, }, } pin_project! { /// Implements Nix's Framed reader protocol for protocol versions >= 1.23. /// /// See serialization.md#framed and [`NixFramedReaderState`] for details. pub struct NixFramedReader { #[pin] reader: R, state: NixFramedReaderState, } } impl NixFramedReader { pub fn new(reader: R) -> Self { Self { reader, state: NixFramedReaderState::ReadingSize { buf: [0; 8], filled: 0, }, } } } impl AsyncRead for NixFramedReader { fn poll_read( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, read_buf: &mut ReadBuf<'_>, ) -> Poll> { let mut this = self.as_mut().project(); match this.state { NixFramedReaderState::ReadingSize { buf, filled } => { if *filled < buf.len() { let mut size_buf = ReadBuf::new(buf); size_buf.advance(*filled); ready!(this.reader.poll_read(cx, &mut size_buf))?; let bytes_read = size_buf.filled().len() - *filled; if bytes_read == 0 { // oef return Poll::Ready(Ok(())); } *filled += bytes_read; // Schedule ourselves to run again. return self.poll_read(cx, read_buf); } let size = u64::from_le_bytes(*buf); if size == 0 { // eof *filled = 0; return Poll::Ready(Ok(())); } *this.state = NixFramedReaderState::ReadingPayload { remaining: size }; self.poll_read(cx, read_buf) } NixFramedReaderState::ReadingPayload { remaining } => { // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms. let safe_remaining = if *remaining <= usize::MAX as u64 { *remaining as usize } else { usize::MAX }; if safe_remaining > 0 { // The buffer is no larger than the amount of data that we expect. // Otherwise we will trim the buffer below and come back here. if read_buf.remaining() <= safe_remaining { let filled_before = read_buf.filled().len(); ready!(this.reader.as_mut().poll_read(cx, read_buf))?; let bytes_read = read_buf.filled().len() - filled_before; *remaining -= bytes_read as u64; if *remaining == 0 { *this.state = NixFramedReaderState::ReadingSize { buf: [0; 8], filled: 0, }; } return Poll::Ready(Ok(())); } // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes // internal bookkeeping simpler. let mut smaller_buf = read_buf.take(safe_remaining); ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?; let bytes_read = smaller_buf.filled().len(); // SAFETY: we just read this number of bytes into read_buf's backing slice above. unsafe { read_buf.assume_init(bytes_read) }; read_buf.advance(bytes_read); return Poll::Ready(Ok(())); } *this.state = NixFramedReaderState::ReadingSize { buf: [0; 8], filled: 0, }; self.poll_read(cx, read_buf) } } } } #[cfg(test)] mod nix_framed_tests { use std::time::Duration; use tokio::io::AsyncReadExt; use tokio_test::io::Builder; use crate::nix_daemon::framing::NixFramedReader; #[tokio::test] async fn read_hello_world_in_two_frames() { let mut mock = Builder::new() // The client sends len .read(&5u64.to_le_bytes()) // Immediately followed by the bytes .read("hello".as_bytes()) .wait(Duration::ZERO) // Send more data separately .read(&6u64.to_le_bytes()) .read(" world".as_bytes()) .build(); let mut reader = NixFramedReader::new(&mut mock); let mut result = String::new(); reader .read_to_string(&mut result) .await .expect("Could not read into result"); assert_eq!("hello world", result); } #[tokio::test] async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() { let mut mock = Builder::new() // The client sends len .read(&5u64.to_le_bytes()) // Immediately followed by the bytes .read("hello".as_bytes()) .wait(Duration::ZERO) // Send more data separately .read(&6u64.to_le_bytes()) .read(" world".as_bytes()) .read(&0u64.to_le_bytes()) .build(); let mut reader = NixFramedReader::new(&mut mock); let mut result = String::new(); reader .read_to_string(&mut result) .await .expect("Could not read into result"); assert_eq!("hello world", result); } }