diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader/trailer.rs')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader/trailer.rs | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs new file mode 100644 index 0000000000..3a5bb75e71 --- /dev/null +++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs @@ -0,0 +1,197 @@ +use std::{ + fmt::Debug, + future::Future, + marker::PhantomData, + ops::Deref, + pin::Pin, + task::{self, ready, Poll}, +}; + +use tokio::io::{self, AsyncRead, ReadBuf}; + +/// Trailer represents up to 8 bytes of data read as part of the trailer block(s) +#[derive(Debug)] +pub(crate) struct Trailer { + data_len: u8, + buf: [u8; 8], +} + +impl Deref for Trailer { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.buf[..self.data_len as usize] + } +} + +/// Tag defines a "trailer tag": specific, fixed bytes that must follow wire data. +pub(crate) trait Tag { + /// The expected suffix + /// + /// The first 8 bytes may be ignored, and it must be an 8-byte aligned size. + const PATTERN: &'static [u8]; + + /// Suitably sized buffer for reading [Self::PATTERN] + /// + /// HACK: This is a workaround for const generics limitations. + type Buf: AsRef<[u8]> + AsMut<[u8]> + Debug + Unpin; + + /// Make an instance of [Self::Buf] + fn make_buf() -> Self::Buf; +} + +#[derive(Debug)] +pub enum Pad {} + +impl Tag for Pad { + const PATTERN: &'static [u8] = &[0; 8]; + + type Buf = [u8; 8]; + + fn make_buf() -> Self::Buf { + [0; 8] + } +} + +#[derive(Debug)] +pub(crate) struct ReadTrailer<R, T: Tag> { + reader: R, + data_len: u8, + filled: u8, + buf: T::Buf, + _phantom: PhantomData<fn(T) -> T>, +} + +/// read_trailer returns a [Future] that reads a trailer with a given [Tag] from `reader` +pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>( + reader: R, + data_len: u8, +) -> ReadTrailer<R, T> { + assert!(data_len <= 8, "payload in trailer must be <= 8 bytes"); + + let buf = T::make_buf(); + assert_eq!(buf.as_ref().len(), T::PATTERN.len()); + assert_eq!(T::PATTERN.len() % 8, 0); + + ReadTrailer { + reader, + data_len, + filled: if data_len != 0 { 0 } else { 8 }, + buf, + _phantom: PhantomData, + } +} + +impl<R, T: Tag> ReadTrailer<R, T> { + pub fn len(&self) -> u8 { + self.data_len + } +} + +impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> { + type Output = io::Result<Trailer>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> { + let this = &mut *self; + + loop { + if this.filled >= this.data_len { + let check_range = || this.data_len as usize..this.filled as usize; + + if this.buf.as_ref()[check_range()] != T::PATTERN[check_range()] { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid trailer", + )) + .into(); + } + } + + if this.filled as usize == T::PATTERN.len() { + let mut buf = [0; 8]; + buf.copy_from_slice(&this.buf.as_ref()[..8]); + + return Ok(Trailer { + data_len: this.data_len, + buf, + }) + .into(); + } + + let mut buf = ReadBuf::new(this.buf.as_mut()); + buf.advance(this.filled as usize); + + ready!(Pin::new(&mut this.reader).poll_read(cx, &mut buf))?; + + this.filled = { + let filled = buf.filled().len() as u8; + + if filled == this.filled { + return Err(io::ErrorKind::UnexpectedEof.into()).into(); + } + + filled + }; + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[tokio::test] + async fn unexpected_eof() { + let reader = tokio_test::io::Builder::new() + .read(&[0xed]) + .wait(Duration::ZERO) + .read(&[0xef, 0x00]) + .build(); + + assert_eq!( + read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(), + io::ErrorKind::UnexpectedEof + ); + } + + #[tokio::test] + async fn invalid_padding() { + let reader = tokio_test::io::Builder::new() + .read(&[0xed]) + .wait(Duration::ZERO) + .read(&[0xef, 0x01, 0x00]) + .wait(Duration::ZERO) + .build(); + + assert_eq!( + read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(), + io::ErrorKind::InvalidData + ); + } + + #[tokio::test] + async fn success() { + let reader = tokio_test::io::Builder::new() + .read(&[0xed]) + .wait(Duration::ZERO) + .read(&[0xef, 0x00]) + .wait(Duration::ZERO) + .read(&[0x00, 0x00, 0x00, 0x00, 0x00]) + .build(); + + assert_eq!( + &*read_trailer::<_, Pad>(reader, 2).await.unwrap(), + &[0xed, 0xef] + ); + } + + #[tokio::test] + async fn no_padding() { + assert!(read_trailer::<_, Pad>(io::empty(), 0) + .await + .unwrap() + .is_empty()); + } +} |