diff options
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader/mod.rs | 202 |
1 files changed, 65 insertions, 137 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index b46b0b53396b..ef59e9c160e4 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -1,12 +1,14 @@ use std::{ future::Future, io, - ops::{Bound, RangeBounds, RangeInclusive}, + ops::RangeBounds, pin::Pin, task::{self, ready, Poll}, }; use tokio::io::{AsyncRead, ReadBuf}; +use crate::wire::read_u64; + use trailer::{read_trailer, ReadTrailer, Trailer}; mod trailer; @@ -15,32 +17,21 @@ mod trailer; /// however this structure provides a [AsyncRead] interface, /// allowing to not having to pass around the entire payload in memory. /// -/// After being constructed with the underlying reader and an allowed size, -/// subsequent requests to poll_read will return payload data until the end -/// of the packet is reached. -/// -/// Internally, it will first read over the size packet, filling payload_size, -/// ensuring it fits allowed_size, then return payload data. +/// It is constructed by reading a size with [BytesReader::new], +/// and yields payload data until the end of the packet is reached. /// /// It will not return the final bytes before all padding has been successfully /// consumed as well, but the full length of the reader must be consumed. /// -/// In case of an error due to size constraints, or in case of not reading -/// all the way to the end (and getting a EOF), the underlying reader is no -/// longer usable and might return garbage. +/// If the data is not read all the way to the end, or an error is encountered, +/// the underlying reader is no longer usable and might return garbage. +#[derive(Debug)] pub struct BytesReader<R> { state: State<R>, } #[derive(Debug)] enum State<R> { - /// The data size is being read. - Size { - reader: Option<R>, - allowed_size: RangeInclusive<u64>, - filled: u8, - buf: [u8; 8], - }, /// Full 8-byte blocks are being read and released to the caller. Body { reader: Option<R>, @@ -60,52 +51,37 @@ where R: AsyncRead + Unpin, { /// Constructs a new BytesReader, using the underlying passed reader. - pub fn new<S: RangeBounds<u64>>(reader: R, allowed_size: S) -> Self { - let allowed_size = match allowed_size.start_bound() { - Bound::Included(&n) => n, - Bound::Excluded(&n) => n.saturating_add(1), - Bound::Unbounded => 0, - }..=match allowed_size.end_bound() { - Bound::Included(&n) => n, - Bound::Excluded(&n) => n.checked_sub(1).unwrap(), - Bound::Unbounded => u64::MAX, - }; - - Self { - state: State::Size { - reader: Some(reader), - allowed_size, - filled: 0, - buf: [0; 8], - }, + pub async fn new<S: RangeBounds<u64>>(mut reader: R, allowed_size: S) -> io::Result<Self> { + let size = read_u64(&mut reader).await?; + + if !allowed_size.contains(&size) { + return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size")); } - } - /// Construct a new BytesReader with a known, and already-read size. - pub fn with_size(reader: R, size: u64) -> Self { - Self { + Ok(Self { state: State::Body { reader: Some(reader), consumed: 0, user_len: size, }, - } + }) + } + + /// Returns whether there is any remaining data to be read. + pub fn is_empty(&self) -> bool { + self.len() == 0 } /// Remaining data length, ie not including data already read. /// /// If the size has not been read yet, this is [None]. - #[allow(clippy::len_without_is_empty)] // if size is unknown, we can't answer that - pub fn len(&self) -> Option<u64> { + pub fn len(&self) -> u64 { match self.state { - State::Size { .. } => None, State::Body { consumed, user_len, .. - } => Some(user_len - consumed), - State::ReadTrailer(ref fut) => Some(fut.len() as u64), - State::ReleaseTrailer { consumed, ref data } => { - Some(data.len() as u64 - consumed as u64) - } + } => user_len - consumed, + State::ReadTrailer(ref fut) => fut.len() as u64, + State::ReleaseTrailer { consumed, ref data } => data.len() as u64 - consumed as u64, } } } @@ -120,45 +96,6 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> { loop { match this { - State::Size { - reader, - allowed_size, - filled: 8, - buf, - } => { - let reader = reader.take().unwrap(); - - let data_len = u64::from_le_bytes(*buf); - if !allowed_size.contains(&data_len) { - return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size")) - .into(); - } - - *this = State::Body { - reader: Some(reader), - consumed: 0, - user_len: data_len, - }; - } - State::Size { - reader, - filled, - buf, - .. - } => { - let reader = reader.as_mut().unwrap(); - - let mut read_buf = ReadBuf::new(&mut buf[..]); - read_buf.advance(*filled as usize); - ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?; - - let new_filled = read_buf.filled().len() as u8; - if *filled == new_filled { - return Err(io::ErrorKind::UnexpectedEof.into()).into(); - } - - *filled = new_filled; - } State::Body { reader, consumed, @@ -245,7 +182,7 @@ mod tests { use lazy_static::lazy_static; use rstest::rstest; use tokio::io::AsyncReadExt; - use tokio_test::{assert_err, io::Builder}; + use tokio_test::io::Builder; use super::*; @@ -279,34 +216,9 @@ mod tests { .read(&produce_packet_bytes(payload).await) .build(); - let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64); - let mut buf = Vec::new(); - r.read_to_end(&mut buf).await.expect("must succeed"); - - assert_eq!(payload, &buf[..]); - } - - /// Read bytes packets of various length, and ensure read_to_end returns the - /// expected payload. - #[rstest] - #[case::empty(&[])] // empty bytes packet - #[case::size_1b(&[0xff])] // 1 bytes payload - #[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding) - #[case::size_9b(&hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding) - #[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet - #[tokio::test] - async fn read_payload_correct_known(#[case] payload: &[u8]) { - let packet = produce_packet_bytes(payload).await; - - let size = u64::from_le_bytes({ - let mut buf = [0; 8]; - buf.copy_from_slice(&packet[..8]); - buf - }); - - let mut mock = Builder::new().read(&packet[8..]).build(); - - let mut r = BytesReader::with_size(&mut mock, size); + let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64) + .await + .unwrap(); let mut buf = Vec::new(); r.read_to_end(&mut buf).await.expect("must succeed"); @@ -321,9 +233,13 @@ mod tests { .read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet .build(); - let mut r = BytesReader::new(&mut mock, ..2048); - let mut buf = Vec::new(); - assert_err!(r.read_to_end(&mut buf).await); + assert_eq!( + BytesReader::new(&mut mock, ..2048) + .await + .unwrap_err() + .kind(), + io::ErrorKind::InvalidData + ); } /// Fail if the bytes packet is smaller than allowed @@ -334,9 +250,13 @@ mod tests { .read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet .build(); - let mut r = BytesReader::new(&mut mock, 1024..2048); - let mut buf = Vec::new(); - assert_err!(r.read_to_end(&mut buf).await); + assert_eq!( + BytesReader::new(&mut mock, 1024..2048) + .await + .unwrap_err() + .kind(), + io::ErrorKind::InvalidData + ); } /// Fail if the padding is not all zeroes @@ -348,7 +268,7 @@ mod tests { packet_bytes[12] = 0xff; let mut mock = Builder::new().read(&packet_bytes).build(); // We stop reading after the faulty bit - let mut r = BytesReader::new(&mut mock, ..MAX_LEN); + let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap(); let mut buf = Vec::new(); r.read_to_end(&mut buf).await.expect_err("must fail"); @@ -365,15 +285,13 @@ mod tests { .read(&produce_packet_bytes(payload).await[..4]) .build(); - let mut r = BytesReader::new(&mut mock, ..MAX_LEN); - let mut buf = [0u8; 1]; - assert_eq!( - r.read_exact(&mut buf).await.expect_err("must fail").kind(), - std::io::ErrorKind::UnexpectedEof + BytesReader::new(&mut mock, ..MAX_LEN) + .await + .expect_err("must fail") + .kind(), + io::ErrorKind::UnexpectedEof ); - - assert_eq!(&[0], &buf, "buffer should stay empty"); } /// Start a 9 bytes payload packet, but have the underlying reader return @@ -387,7 +305,7 @@ mod tests { .read(&produce_packet_bytes(payload).await[..8 + 4]) .build(); - let mut r = BytesReader::new(&mut mock, ..MAX_LEN); + let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap(); let mut buf = [0; 9]; r.read_exact(&mut buf[..4]).await.expect("must succeed"); @@ -414,7 +332,7 @@ mod tests { .read(&produce_packet_bytes(payload).await[..offset]) .build(); - let mut r = BytesReader::new(&mut mock, ..MAX_LEN); + let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap(); // read_exact of the payload *body* will succeed, but a subsequent read will // return UnexpectedEof error. @@ -441,10 +359,18 @@ mod tests { .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo")) .build(); - let mut r = BytesReader::new(&mut mock, ..MAX_LEN); - let mut buf = Vec::new(); + // Either length reading or data reading can fail, depending on which test case we're in. + let err: io::Error = async { + let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await?; + let mut buf = Vec::new(); + + r.read_to_end(&mut buf).await?; + + Ok(()) + } + .await + .expect_err("must fail"); - let err = r.read_to_end(&mut buf).await.expect_err("must fail"); assert_eq!( err.kind(), std::io::ErrorKind::Other, @@ -468,7 +394,7 @@ mod tests { .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo")) .build(); - let mut r = BytesReader::new(&mut mock, ..MAX_LEN); + let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap(); let mut buf = Vec::new(); r.read_to_end(&mut buf).await.expect("must succeed"); @@ -492,7 +418,9 @@ mod tests { .read(&produce_packet_bytes(payload).await[offset..]) .build(); - let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64); + let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64) + .await + .unwrap(); let mut buf = Vec::new(); r.read_to_end(&mut buf).await.expect("must succeed"); |