use std::{ pin::Pin, task::{self, ready, Poll}, }; use tokio::io::{self, AsyncRead, ReadBuf}; #[derive(Debug)] pub enum TrailerReader { Reading { reader: R, user_len: u8, filled: u8, buf: [u8; 8], }, Releasing { off: u8, len: u8, buf: [u8; 8], }, Done, } impl TrailerReader { pub fn new(reader: R, user_len: u8) -> Self { if user_len == 0 { return Self::Done; } assert!(user_len < 8, "payload in trailer must be less than 8 bytes"); Self::Reading { reader, user_len, filled: 0, buf: [0; 8], } } } impl AsyncRead for TrailerReader { fn poll_read( self: Pin<&mut Self>, cx: &mut task::Context, user_buf: &mut ReadBuf, ) -> Poll> { let this = self.get_mut(); loop { match this { &mut Self::Reading { reader: _, user_len, filled: 8, buf, } => { *this = Self::Releasing { off: 0, len: user_len, buf, }; } Self::Reading { reader, user_len, filled, buf, } => { let mut read_buf = ReadBuf::new(&mut buf[..]); read_buf.advance(*filled as usize); ready!(Pin::new(reader).poll_read(cx, &mut read_buf))?; let new_filled = read_buf.filled().len() as u8; if *filled == new_filled { return Err(io::ErrorKind::UnexpectedEof.into()).into(); } *filled = new_filled; // ensure the padding is all zeroes if (u64::from_le_bytes(*buf) >> (*user_len * 8)) != 0 { return Err(io::ErrorKind::InvalidData.into()).into(); } } Self::Releasing { off: 8, .. } => { *this = Self::Done; } Self::Releasing { off, len, buf } => { assert_ne!(user_buf.remaining(), 0); let buf = &buf[*off as usize..*len as usize]; let buf = &buf[..usize::min(buf.len(), user_buf.remaining())]; user_buf.put_slice(buf); *off += buf.len() as u8; break; } Self::Done => break, } } Ok(()).into() } } #[cfg(test)] mod tests { use std::time::Duration; use tokio::io::AsyncReadExt; use super::*; #[tokio::test] async fn unexpected_eof() { let reader = tokio_test::io::Builder::new() .read(&[0xed]) .wait(Duration::ZERO) .read(&[0xef, 0x00]) .build(); let mut reader = TrailerReader::new(reader, 2); let mut buf = vec![]; assert_eq!( reader.read_to_end(&mut buf).await.unwrap_err().kind(), io::ErrorKind::UnexpectedEof ); } #[tokio::test] async fn invalid_padding() { let reader = tokio_test::io::Builder::new() .read(&[0xed]) .wait(Duration::ZERO) .read(&[0xef, 0x01, 0x00]) .wait(Duration::ZERO) .build(); let mut reader = TrailerReader::new(reader, 2); let mut buf = vec![]; assert_eq!( reader.read_to_end(&mut buf).await.unwrap_err().kind(), io::ErrorKind::InvalidData ); } #[tokio::test] async fn success() { let reader = tokio_test::io::Builder::new() .read(&[0xed]) .wait(Duration::ZERO) .read(&[0xef, 0x00]) .wait(Duration::ZERO) .read(&[0x00, 0x00, 0x00, 0x00, 0x00]) .build(); let mut reader = TrailerReader::new(reader, 2); let mut buf = vec![]; reader.read_to_end(&mut buf).await.unwrap(); assert_eq!(buf, &[0xed, 0xef]); } #[tokio::test] async fn no_padding() { let reader = tokio_test::io::Builder::new().build(); let mut reader = TrailerReader::new(reader, 0); let mut buf = vec![]; reader.read_to_end(&mut buf).await.unwrap(); assert!(buf.is_empty()); } }