diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader/mod.rs')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader/mod.rs | 75 |
1 files changed, 66 insertions, 9 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index cd45f78a0c84..4a8cfd1f6599 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -1,6 +1,7 @@ use std::{ future::Future, io, + num::NonZeroU64, ops::RangeBounds, pin::Pin, task::{self, ready, Poll}, @@ -33,14 +34,26 @@ pub struct BytesReader<R, T: Tag = Pad> { state: State<R, T>, } +/// Split the `user_len` into `body_len` and `tail_len`, which are respectively +/// the non-terminal 8-byte blocks, and the ≤8 bytes of user data contained in +/// the trailer block. +#[inline(always)] +fn split_user_len(user_len: NonZeroU64) -> (u64, u8) { + let n = user_len.get() - 1; + let body_len = n & !7; + let tail_len = (n & 7) as u8 + 1; + (body_len, tail_len) +} + #[derive(Debug)] enum State<R, T: Tag> { /// Full 8-byte blocks are being read and released to the caller. + /// NOTE: The final 8-byte block is *always* part of the trailer. Body { reader: Option<R>, consumed: u64, /// The total length of all user data contained in both the body and trailer. - user_len: u64, + user_len: NonZeroU64, }, /// The trailer is in the process of being read. ReadTrailer(ReadTrailer<R, T>), @@ -76,10 +89,16 @@ where } Ok(Self { - state: State::Body { - reader: Some(reader), - consumed: 0, - user_len: size, + state: match NonZeroU64::new(size) { + Some(size) => State::Body { + reader: Some(reader), + consumed: 0, + user_len: size, + }, + None => State::ReleaseTrailer { + consumed: 0, + data: read_trailer::<R, T>(reader, 0).await?, + }, }, }) } @@ -96,7 +115,7 @@ where match self.state { State::Body { consumed, user_len, .. - } => user_len - consumed, + } => user_len.get() - consumed, State::ReadTrailer(ref fut) => fut.len() as u64, State::ReleaseTrailer { consumed, ref data } => data.len() as u64 - consumed as u64, } @@ -119,13 +138,12 @@ impl<R: AsyncRead + Unpin, T: Tag> AsyncRead for BytesReader<R, T> { consumed, user_len, } => { - let body_len = *user_len & !7; + let (body_len, tail_len) = split_user_len(*user_len); let remaining = body_len - *consumed; let reader = if remaining == 0 { let reader = reader.take().unwrap(); - let user_len = (*user_len & 7) as u8; - *this = State::ReadTrailer(read_trailer(reader, user_len)); + *this = State::ReadTrailer(read_trailer(reader, tail_len)); continue; } else { reader.as_mut().unwrap() @@ -277,6 +295,45 @@ mod tests { ); } + /// Read the trailer immediately if there is no payload. + #[tokio::test] + async fn read_trailer_immediately() { + use crate::nar::wire::PadPar; + + let mut mock = Builder::new() + .read(&[0; 8]) + .read(&PadPar::PATTERN[8..]) + .build(); + + BytesReader::<_, PadPar>::new_internal(&mut mock, ..) + .await + .unwrap(); + + // The mock reader will panic if dropped without reading all data. + } + + /// Read the trailer even if we only read the exact payload size. + #[tokio::test] + async fn read_exact_trailer() { + use crate::nar::wire::PadPar; + + let mut mock = Builder::new() + .read(&16u64.to_le_bytes()) + .read(&[0x55; 16]) + .read(&PadPar::PATTERN[8..]) + .build(); + + let mut reader = BytesReader::<_, PadPar>::new_internal(&mut mock, ..) + .await + .unwrap(); + + let mut buf = [0; 16]; + reader.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, [0x55; 16]); + + // The mock reader will panic if dropped without reading all data. + } + /// Fail if the padding is not all zeroes #[tokio::test] async fn read_fail_if_nonzero_padding() { |