From fdecf52a52026c47d393df3316c97d54109f58c4 Mon Sep 17 00:00:00 2001 From: edef Date: Mon, 29 Apr 2024 14:52:34 +0000 Subject: refactor(nix-compat/wire/bytes): fold TrailerReader into BytesReader The TrailerReader has no purpose separate from BytesReader, and the code gets a fair bit simpler this way. EOF handling is simplified, since we just rely on the implicit behaviour of the existing case. Change-Id: Id9b9f022c7c89fbc47968a96032fc43553af8290 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11539 Reviewed-by: Brian Olsen Tested-by: BuildkiteCI Reviewed-by: flokli --- tvix/nix-compat/src/wire/bytes/reader/mod.rs | 35 +++++++-- tvix/nix-compat/src/wire/bytes/reader/trailer.rs | 97 +++--------------------- 2 files changed, 41 insertions(+), 91 deletions(-) diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index 9f7013a4e9b1..5f6081b40451 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -1,4 +1,5 @@ use std::{ + future::Future, io, ops::{Bound, RangeBounds}, pin::Pin, @@ -6,7 +7,7 @@ use std::{ }; use tokio::io::{AsyncRead, ReadBuf}; -use trailer::TrailerReader; +use trailer::{read_trailer, ReadTrailer, Trailer}; mod trailer; /// Reads a "bytes wire packet" from the underlying reader. @@ -33,6 +34,7 @@ pub struct BytesReader { #[derive(Debug)] enum State { + /// The data size is being read. Size { reader: Option, /// Minimum length (inclusive) @@ -42,12 +44,18 @@ enum State { filled: u8, buf: [u8; 8], }, + /// Full 8-byte blocks are being read and released to the caller. Body { reader: Option, consumed: u64, + /// The total length of all user data contained in both the body and trailer. user_len: u64, }, - Trailer(TrailerReader), + /// The trailer is in the process of being read. + ReadTrailer(ReadTrailer), + /// The trailer has been fully read and validated, + /// and data can now be released to the caller. + ReleaseTrailer { consumed: u8, data: Trailer }, } impl BytesReader @@ -100,7 +108,10 @@ where State::Body { consumed, user_len, .. } => Some(user_len - consumed), - State::Trailer(ref r) => Some(r.len() as u64), + State::ReadTrailer(ref fut) => Some(fut.len() as u64), + State::ReleaseTrailer { consumed, ref data } => { + Some(data.len() as u64 - consumed as u64) + } } } } @@ -166,7 +177,7 @@ impl AsyncRead for BytesReader { let reader = if remaining == 0 { let reader = reader.take().unwrap(); let user_len = (*user_len & 7) as u8; - *this = State::Trailer(TrailerReader::new(reader, user_len)); + *this = State::ReadTrailer(read_trailer(reader, user_len)); continue; } else { reader.as_mut().unwrap() @@ -188,8 +199,20 @@ impl AsyncRead for BytesReader { } .into(); } - State::Trailer(reader) => { - return Pin::new(reader).poll_read(cx, buf); + State::ReadTrailer(fut) => { + *this = State::ReleaseTrailer { + consumed: 0, + data: ready!(Pin::new(fut).poll(cx))?, + }; + } + State::ReleaseTrailer { consumed, data } => { + let data = &data[*consumed as usize..]; + let data = &data[..usize::min(data.len(), buf.remaining())]; + + buf.put_slice(data); + *consumed += data.len() as u8; + + return Ok(()).into(); } } } diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs index 61a77678080a..9b8bcaa2de4a 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs @@ -53,7 +53,7 @@ impl Tag for Pad { } #[derive(Debug)] -pub(crate) struct ReadTrailer { +pub(crate) struct ReadTrailer { reader: R, data_len: u8, filled: u8, @@ -90,7 +90,7 @@ impl ReadTrailer { impl Future for ReadTrailer { type Output = io::Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll { let this = &mut *self; loop { @@ -136,72 +136,9 @@ impl Future for ReadTrailer { } } -#[derive(Debug)] -pub(crate) enum TrailerReader { - Reading(ReadTrailer), - Releasing { off: u8, data: Trailer }, - Done, -} - -impl TrailerReader { - pub fn new(reader: R, data_len: u8) -> Self { - Self::Reading(read_trailer(reader, data_len)) - } - - pub fn len(&self) -> u8 { - match self { - TrailerReader::Reading(fut) => fut.len(), - &TrailerReader::Releasing { - off, - data: Trailer { data_len, .. }, - } => data_len - off, - TrailerReader::Done => 0, - } - } -} - -impl AsyncRead for TrailerReader { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut task::Context, - user_buf: &mut ReadBuf, - ) -> Poll> { - let this = &mut *self; - - loop { - match this { - Self::Reading(fut) => { - *this = Self::Releasing { - off: 0, - data: ready!(Pin::new(fut).poll(cx))?, - }; - } - Self::Releasing { off: 8, .. } => { - *this = Self::Done; - } - Self::Releasing { off, data } => { - assert_ne!(user_buf.remaining(), 0); - - let buf = &data[*off as usize..]; - let buf = &buf[..usize::min(buf.len(), user_buf.remaining())]; - - user_buf.put_slice(buf); - *off += buf.len() as u8; - - break; - } - Self::Done => break, - } - } - - Ok(()).into() - } -} - #[cfg(test)] mod tests { use std::time::Duration; - use tokio::io::AsyncReadExt; use super::*; @@ -213,11 +150,8 @@ mod tests { .read(&[0xef, 0x00]) .build(); - let mut reader = TrailerReader::new(reader, 2); - - let mut buf = vec![]; assert_eq!( - reader.read_to_end(&mut buf).await.unwrap_err().kind(), + read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(), io::ErrorKind::UnexpectedEof ); } @@ -231,11 +165,8 @@ mod tests { .wait(Duration::ZERO) .build(); - let mut reader = TrailerReader::new(reader, 2); - - let mut buf = vec![]; assert_eq!( - reader.read_to_end(&mut buf).await.unwrap_err().kind(), + read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(), io::ErrorKind::InvalidData ); } @@ -250,21 +181,17 @@ mod tests { .read(&[0x00, 0x00, 0x00, 0x00, 0x00]) .build(); - let mut reader = TrailerReader::new(reader, 2); - - let mut buf = vec![]; - reader.read_to_end(&mut buf).await.unwrap(); - - assert_eq!(buf, &[0xed, 0xef]); + assert_eq!( + &*read_trailer::<_, Pad>(reader, 2).await.unwrap(), + &[0xed, 0xef] + ); } #[tokio::test] async fn no_padding() { - let reader = tokio_test::io::Builder::new().build(); - let mut reader = TrailerReader::new(reader, 0); - - let mut buf = vec![]; - reader.read_to_end(&mut buf).await.unwrap(); - assert!(buf.is_empty()); + assert!(read_trailer::<_, Pad>(io::empty(), 0) + .await + .unwrap() + .is_empty()); } } -- cgit 1.4.1