diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader/mod.rs | 75 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader/trailer.rs | 12 |
2 files changed, 72 insertions, 15 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index cd45f78a0c84..4a8cfd1f6599 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -1,6 +1,7 @@ use std::{ future::Future, io, + num::NonZeroU64, ops::RangeBounds, pin::Pin, task::{self, ready, Poll}, @@ -33,14 +34,26 @@ pub struct BytesReader<R, T: Tag = Pad> { state: State<R, T>, } +/// Split the `user_len` into `body_len` and `tail_len`, which are respectively +/// the non-terminal 8-byte blocks, and the ≤8 bytes of user data contained in +/// the trailer block. +#[inline(always)] +fn split_user_len(user_len: NonZeroU64) -> (u64, u8) { + let n = user_len.get() - 1; + let body_len = n & !7; + let tail_len = (n & 7) as u8 + 1; + (body_len, tail_len) +} + #[derive(Debug)] enum State<R, T: Tag> { /// Full 8-byte blocks are being read and released to the caller. + /// NOTE: The final 8-byte block is *always* part of the trailer. Body { reader: Option<R>, consumed: u64, /// The total length of all user data contained in both the body and trailer. - user_len: u64, + user_len: NonZeroU64, }, /// The trailer is in the process of being read. ReadTrailer(ReadTrailer<R, T>), @@ -76,10 +89,16 @@ where } Ok(Self { - state: State::Body { - reader: Some(reader), - consumed: 0, - user_len: size, + state: match NonZeroU64::new(size) { + Some(size) => State::Body { + reader: Some(reader), + consumed: 0, + user_len: size, + }, + None => State::ReleaseTrailer { + consumed: 0, + data: read_trailer::<R, T>(reader, 0).await?, + }, }, }) } @@ -96,7 +115,7 @@ where match self.state { State::Body { consumed, user_len, .. - } => user_len - consumed, + } => user_len.get() - consumed, State::ReadTrailer(ref fut) => fut.len() as u64, State::ReleaseTrailer { consumed, ref data } => data.len() as u64 - consumed as u64, } @@ -119,13 +138,12 @@ impl<R: AsyncRead + Unpin, T: Tag> AsyncRead for BytesReader<R, T> { consumed, user_len, } => { - let body_len = *user_len & !7; + let (body_len, tail_len) = split_user_len(*user_len); let remaining = body_len - *consumed; let reader = if remaining == 0 { let reader = reader.take().unwrap(); - let user_len = (*user_len & 7) as u8; - *this = State::ReadTrailer(read_trailer(reader, user_len)); + *this = State::ReadTrailer(read_trailer(reader, tail_len)); continue; } else { reader.as_mut().unwrap() @@ -277,6 +295,45 @@ mod tests { ); } + /// Read the trailer immediately if there is no payload. + #[tokio::test] + async fn read_trailer_immediately() { + use crate::nar::wire::PadPar; + + let mut mock = Builder::new() + .read(&[0; 8]) + .read(&PadPar::PATTERN[8..]) + .build(); + + BytesReader::<_, PadPar>::new_internal(&mut mock, ..) + .await + .unwrap(); + + // The mock reader will panic if dropped without reading all data. + } + + /// Read the trailer even if we only read the exact payload size. + #[tokio::test] + async fn read_exact_trailer() { + use crate::nar::wire::PadPar; + + let mut mock = Builder::new() + .read(&16u64.to_le_bytes()) + .read(&[0x55; 16]) + .read(&PadPar::PATTERN[8..]) + .build(); + + let mut reader = BytesReader::<_, PadPar>::new_internal(&mut mock, ..) + .await + .unwrap(); + + let mut buf = [0; 16]; + reader.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, [0x55; 16]); + + // The mock reader will panic if dropped without reading all data. + } + /// Fail if the padding is not all zeroes #[tokio::test] async fn read_fail_if_nonzero_padding() { diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs index 82aa2a228095..3a5bb75e7103 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs @@ -9,11 +9,11 @@ use std::{ use tokio::io::{self, AsyncRead, ReadBuf}; -/// Trailer represents up to 7 bytes of data read as part of the trailer block(s) +/// Trailer represents up to 8 bytes of data read as part of the trailer block(s) #[derive(Debug)] pub(crate) struct Trailer { data_len: u8, - buf: [u8; 7], + buf: [u8; 8], } impl Deref for Trailer { @@ -28,7 +28,7 @@ impl Deref for Trailer { pub(crate) trait Tag { /// The expected suffix /// - /// The first 7 bytes may be ignored, and it must be an 8-byte aligned size. + /// The first 8 bytes may be ignored, and it must be an 8-byte aligned size. const PATTERN: &'static [u8]; /// Suitably sized buffer for reading [Self::PATTERN] @@ -67,7 +67,7 @@ pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>( reader: R, data_len: u8, ) -> ReadTrailer<R, T> { - assert!(data_len < 8, "payload in trailer must be less than 8 bytes"); + assert!(data_len <= 8, "payload in trailer must be <= 8 bytes"); let buf = T::make_buf(); assert_eq!(buf.as_ref().len(), T::PATTERN.len()); @@ -108,8 +108,8 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> { } if this.filled as usize == T::PATTERN.len() { - let mut buf = [0; 7]; - buf.copy_from_slice(&this.buf.as_ref()[..7]); + let mut buf = [0; 8]; + buf.copy_from_slice(&this.buf.as_ref()[..8]); return Ok(Trailer { data_len: this.data_len, |