use std::ops::RangeBounds; use tokio::io::AsyncReadExt; use super::primitive; #[allow(dead_code)] /// Read a limited number of bytes from the AsyncRead. /// Rejects reading more than `allowed_size` bytes of payload. /// Internally takes care of dealing with the padding, so the returned `Vec` /// only contains the payload. /// This always buffers the entire contents into memory, we'll add a streaming /// version later. pub async fn read_bytes(r: &mut R, allowed_size: S) -> std::io::Result> where R: AsyncReadExt + Unpin, S: RangeBounds, { // read the length field let len = primitive::read_u64(r).await?; if !allowed_size.contains(&len) { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "signalled package size not in allowed range", )); } // calculate the total length, including padding. // byte packets are padded to 8 byte blocks each. let padded_len = if len % 8 == 0 { len } else { len + (8 - len % 8) }; let mut limited_reader = r.take(padded_len); let mut buf = Vec::new(); let s = limited_reader.read_to_end(&mut buf).await?; // make sure we got exactly the number of bytes, and not less. if s as u64 != padded_len { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "got less bytes than expected", )); } let (_content, padding) = buf.split_at(len as usize); // ensure the padding is all zeroes. if !padding.iter().all(|e| *e == b'\0') { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "padding is not all zeroes", )); } // return the data without the padding buf.truncate(len as usize); Ok(buf) } #[allow(dead_code)] /// Read an unlimited number of bytes from the AsyncRead. /// Note this can exhaust memory. /// Internally uses [read_bytes], which takes care of dealing with the padding, /// so the returned `Vec` only contains the payload. pub async fn read_bytes_unchecked(r: &mut R) -> std::io::Result> { read_bytes(r, 0u64..).await } #[cfg(test)] mod tests { use tokio_test::io::Builder; use super::*; use hex_literal::hex; #[tokio::test] async fn test_read_8_bytes_unchecked() { let mut mock = Builder::new() .read(&8u64.to_le_bytes()) .read(&12345678u64.to_le_bytes()) .build(); assert_eq!( &12345678u64.to_le_bytes(), read_bytes_unchecked(&mut mock).await.unwrap().as_slice() ); } #[tokio::test] async fn test_read_9_bytes_unchecked() { let mut mock = Builder::new() .read(&9u64.to_le_bytes()) .read(&hex!("01020304050607080900000000000000")) .build(); assert_eq!( hex!("010203040506070809"), read_bytes_unchecked(&mut mock).await.unwrap().as_slice() ); } #[tokio::test] async fn test_read_0_bytes_unchecked() { // A empty byte packet is essentially just the 0 length field. // No data is read, and there's zero padding. let mut mock = Builder::new().read(&0u64.to_le_bytes()).build(); assert_eq!( hex!(""), read_bytes_unchecked(&mut mock).await.unwrap().as_slice() ); } #[tokio::test] /// Ensure we don't read any further than the size field if the length /// doesn't match the range we want to accept. async fn test_reject_too_large() { let mut mock = Builder::new().read(&100u64.to_le_bytes()).build(); read_bytes(&mut mock, 10..10) .await .expect_err("expect this to fail"); } }