diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader/trailer.rs | 190 |
1 files changed, 134 insertions, 56 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs index 1a084d0eeb01..958cead42d8f 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs @@ -1,42 +1,148 @@ use std::{ + future::Future, + marker::PhantomData, + ops::Deref, pin::Pin, task::{self, ready, Poll}, }; use tokio::io::{self, AsyncRead, ReadBuf}; +/// Trailer represents up to 7 bytes of data read as part of the trailer block(s) #[derive(Debug)] -pub enum TrailerReader<R> { - Reading { - reader: R, - user_len: u8, - filled: u8, - buf: [u8; 8], - }, - Releasing { - off: u8, - len: u8, - buf: [u8; 8], - }, - Done, +pub(crate) struct Trailer { + data_len: u8, + buf: [u8; 7], } -impl<R: AsyncRead + Unpin> TrailerReader<R> { - pub fn new(reader: R, user_len: u8) -> Self { - if user_len == 0 { - return Self::Done; - } +impl Deref for Trailer { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.buf[..self.data_len as usize] + } +} + +/// Tag defines a "trailer tag": specific, fixed bytes that must follow wire data. +pub(crate) trait Tag { + /// The expected suffix + /// + /// The first 7 bytes may be ignored, and it must be an 8-byte aligned size. + const PATTERN: &'static [u8]; + + /// Suitably sized buffer for reading [Self::PATTERN] + /// + /// HACK: This is a workaround for const generics limitations. + type Buf: AsRef<[u8]> + AsMut<[u8]> + Unpin; + + /// Make an instance of [Self::Buf] + fn make_buf() -> Self::Buf; +} + +#[derive(Debug)] +pub(crate) enum Pad {} + +impl Tag for Pad { + const PATTERN: &'static [u8] = &[0; 8]; + + type Buf = [u8; 8]; + + fn make_buf() -> Self::Buf { + [0; 8] + } +} + +#[derive(Debug)] +pub(crate) struct ReadTrailer<R, T: Tag> { + reader: R, + data_len: u8, + filled: u8, + buf: T::Buf, + _phantom: PhantomData<*const T>, +} + +/// read_trailer returns a [Future] that reads a trailer with a given [Tag] from `reader` +pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>( + reader: R, + data_len: u8, +) -> ReadTrailer<R, T> { + assert!(data_len < 8, "payload in trailer must be less than 8 bytes"); + + let buf = T::make_buf(); + assert_eq!(buf.as_ref().len(), T::PATTERN.len()); + assert_eq!(T::PATTERN.len() % 8, 0); + + ReadTrailer { + reader, + data_len, + filled: if data_len != 0 { 0 } else { 8 }, + buf, + _phantom: PhantomData, + } +} + +impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> { + type Output = io::Result<Trailer>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> { + let this = &mut *self; - assert!(user_len < 8, "payload in trailer must be less than 8 bytes"); - Self::Reading { - reader, - user_len, - filled: 0, - buf: [0; 8], + loop { + if this.filled >= this.data_len { + let check_range = || this.data_len as usize..this.filled as usize; + + if this.buf.as_ref()[check_range()] != T::PATTERN[check_range()] { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid trailer", + )) + .into(); + } + } + + if this.filled as usize == T::PATTERN.len() { + let mut buf = [0; 7]; + buf.copy_from_slice(&this.buf.as_ref()[..7]); + + return Ok(Trailer { + data_len: this.data_len, + buf, + }) + .into(); + } + + let mut buf = ReadBuf::new(this.buf.as_mut()); + buf.advance(this.filled as usize); + + ready!(Pin::new(&mut this.reader).poll_read(cx, &mut buf))?; + + this.filled = { + let prev_filled = this.filled; + let filled = buf.filled().len() as u8; + + if filled == prev_filled { + return Err(io::ErrorKind::UnexpectedEof.into()).into(); + } + + filled + }; } } } +#[derive(Debug)] +pub(crate) enum TrailerReader<R> { + Reading(ReadTrailer<R, Pad>), + Releasing { off: u8, data: Trailer }, + Done, +} + +impl<R: AsyncRead + Unpin> TrailerReader<R> { + pub fn new(reader: R, data_len: u8) -> Self { + Self::Reading(read_trailer(reader, data_len)) + } +} + impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> { fn poll_read( mut self: Pin<&mut Self>, @@ -47,47 +153,19 @@ impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> { loop { match this { - &mut Self::Reading { - reader: _, - user_len, - filled: 8, - buf, - } => { + Self::Reading(fut) => { *this = Self::Releasing { off: 0, - len: user_len, - buf, + data: ready!(Pin::new(fut).poll(cx))?, }; } - Self::Reading { - reader, - user_len, - filled, - buf, - } => { - let mut read_buf = ReadBuf::new(&mut buf[..]); - read_buf.advance(*filled as usize); - ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?; - - let new_filled = read_buf.filled().len() as u8; - if *filled == new_filled { - return Err(io::ErrorKind::UnexpectedEof.into()).into(); - } - - *filled = new_filled; - - // ensure the padding is all zeroes - if (u64::from_le_bytes(*buf) >> (*user_len * 8)) != 0 { - return Err(io::ErrorKind::InvalidData.into()).into(); - } - } Self::Releasing { off: 8, .. } => { *this = Self::Done; } - Self::Releasing { off, len, buf } => { + Self::Releasing { off, data } => { assert_ne!(user_buf.remaining(), 0); - let buf = &buf[*off as usize..*len as usize]; + let buf = &data[*off as usize..]; let buf = &buf[..usize::min(buf.len(), user_buf.remaining())]; user_buf.put_slice(buf); |