use std::{ io::{Error, ErrorKind}, ops::RangeInclusive, }; use tokio::io::{AsyncReadExt, AsyncWriteExt}; mod reader; pub use reader::BytesReader; mod writer; pub use writer::BytesWriter; /// 8 null bytes, used to write out padding. const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; /// The length of the size field, in bytes is always 8. const LEN_SIZE: usize = 8; #[allow(dead_code)] /// Read a "bytes wire packet" from the AsyncRead. /// Rejects reading more than `allowed_size` bytes of payload. /// /// The packet is made up of three parts: /// - a length header, u64, LE-encoded /// - the payload itself /// - null bytes to the next 8 byte boundary /// /// Ensures the payload size fits into the `allowed_size` passed, /// and that the padding is actual null bytes. /// /// On success, the returned `Vec` only contains the payload itself. /// On failure (for example if a too large byte packet was sent), the reader /// becomes unusable. /// /// This buffers the entire payload into memory, /// a streaming version is available at [crate::wire::bytes::BytesReader]. pub async fn read_bytes( r: &mut R, allowed_size: RangeInclusive, ) -> std::io::Result> where R: AsyncReadExt + Unpin, { // read the length field let len = r.read_u64_le().await?; let len: usize = len .try_into() .ok() .filter(|len| allowed_size.contains(len)) .ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::InvalidData, "signalled package size not in allowed range", ) })?; // calculate the total length, including padding. // byte packets are padded to 8 byte blocks each. let padded_len = padding_len(len as u64) as u64 + (len as u64); let mut limited_reader = r.take(padded_len); let mut buf = Vec::new(); let s = limited_reader.read_to_end(&mut buf).await?; // make sure we got exactly the number of bytes, and not less. if s as u64 != padded_len { return Err(std::io::ErrorKind::UnexpectedEof.into()); } let (_content, padding) = buf.split_at(len); // ensure the padding is all zeroes. if !padding.iter().all(|e| *e == b'\0') { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "padding is not all zeroes", )); } // return the data without the padding buf.truncate(len); Ok(buf) } /// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string. /// Internally uses [read_bytes]. /// Rejects reading more than `allowed_size` bytes of payload. pub async fn read_string( r: &mut R, allowed_size: RangeInclusive, ) -> std::io::Result where R: AsyncReadExt + Unpin, { let bytes = read_bytes(r, allowed_size).await?; String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e)) } /// Writes a "bytes wire packet" to a (hopefully buffered) [AsyncWriteExt]. /// /// Accepts anything implementing AsRef<[u8]> as payload. /// /// See [read_bytes] for a description of the format. /// /// Note: if performance matters to you, make sure your /// [AsyncWriteExt] handle is buffered. This function is quite /// write-intesive. pub async fn write_bytes>( w: &mut W, b: B, ) -> std::io::Result<()> { // write the size packet. w.write_u64_le(b.as_ref().len() as u64).await?; // write the payload w.write_all(b.as_ref()).await?; // write padding if needed let padding_len = padding_len(b.as_ref().len() as u64) as usize; if padding_len != 0 { w.write_all(&EMPTY_BYTES[..padding_len]).await?; } Ok(()) } /// Computes the number of bytes we should add to len (a length in /// bytes) to be aligned on 64 bits (8 bytes). fn padding_len(len: u64) -> u8 { let aligned = len.wrapping_add(7) & !7; aligned.wrapping_sub(len) as u8 } #[cfg(test)] mod tests { use tokio_test::{assert_ok, io::Builder}; use super::*; use hex_literal::hex; /// The maximum length of bytes packets we're willing to accept in the test /// cases. const MAX_LEN: usize = 1024; #[tokio::test] async fn test_read_8_bytes() { let mut mock = Builder::new() .read(&8u64.to_le_bytes()) .read(&12345678u64.to_le_bytes()) .build(); assert_eq!( &12345678u64.to_le_bytes(), read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } #[tokio::test] async fn test_read_9_bytes() { let mut mock = Builder::new() .read(&9u64.to_le_bytes()) .read(&hex!("01020304050607080900000000000000")) .build(); assert_eq!( hex!("010203040506070809"), read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } #[tokio::test] async fn test_read_0_bytes() { // A empty byte packet is essentially just the 0 length field. // No data is read, and there's zero padding. let mut mock = Builder::new().read(&0u64.to_le_bytes()).build(); assert_eq!( hex!(""), read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } #[tokio::test] /// Ensure we don't read any further than the size field if the length /// doesn't match the range we want to accept. async fn test_read_reject_too_large() { let mut mock = Builder::new().read(&100u64.to_le_bytes()).build(); read_bytes(&mut mock, 10..=10) .await .expect_err("expect this to fail"); } #[tokio::test] async fn test_write_bytes_no_padding() { let input = hex!("6478696f34657661"); let len = input.len() as u64; let mut mock = Builder::new() .write(&len.to_le_bytes()) .write(&input) .build(); assert_ok!(write_bytes(&mut mock, &input).await) } #[tokio::test] async fn test_write_bytes_with_padding() { let input = hex!("322e332e3137"); let len = input.len() as u64; let mut mock = Builder::new() .write(&len.to_le_bytes()) .write(&hex!("322e332e31370000")) .build(); assert_ok!(write_bytes(&mut mock, &input).await) } #[tokio::test] async fn test_write_string() { let input = "Hello, World!"; let len = input.len() as u64; let mut mock = Builder::new() .write(&len.to_le_bytes()) .write(&hex!("48656c6c6f2c20576f726c6421000000")) .build(); assert_ok!(write_bytes(&mut mock, &input).await) } #[test] fn padding_len_u64_max() { assert_eq!(padding_len(u64::MAX), 1); } }