diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/mod.rs')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/mod.rs | 56 |
1 files changed, 27 insertions, 29 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index fc777bafe20f..740a7ebfd03e 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -1,6 +1,6 @@ use std::{ io::{Error, ErrorKind}, - ops::RangeBounds, + ops::RangeInclusive, }; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -33,24 +33,29 @@ const LEN_SIZE: usize = 8; /// /// This buffers the entire payload into memory, /// a streaming version is available at [crate::wire::bytes::BytesReader]. -pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>> +pub async fn read_bytes<R>( + r: &mut R, + allowed_size: RangeInclusive<usize>, +) -> std::io::Result<Vec<u8>> where R: AsyncReadExt + Unpin, - S: RangeBounds<u64>, { // read the length field let len = r.read_u64_le().await?; - - if !allowed_size.contains(&len) { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "signalled package size not in allowed range", - )); - } + let len: usize = len + .try_into() + .ok() + .filter(|len| allowed_size.contains(len)) + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "signalled package size not in allowed range", + ) + })?; // calculate the total length, including padding. // byte packets are padded to 8 byte blocks each. - let padded_len = padding_len(len) as u64 + len; + let padded_len = padding_len(len as u64) as u64 + (len as u64); let mut limited_reader = r.take(padded_len); let mut buf = Vec::new(); @@ -62,7 +67,7 @@ where return Err(std::io::ErrorKind::UnexpectedEof.into()); } - let (_content, padding) = buf.split_at(len as usize); + let (_content, padding) = buf.split_at(len); // ensure the padding is all zeroes. if !padding.iter().all(|e| *e == b'\0') { @@ -73,17 +78,19 @@ where } // return the data without the padding - buf.truncate(len as usize); + buf.truncate(len); Ok(buf) } /// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string. /// Internally uses [read_bytes]. /// Rejects reading more than `allowed_size` bytes of payload. -pub async fn read_string<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<String> +pub async fn read_string<R>( + r: &mut R, + allowed_size: RangeInclusive<usize>, +) -> std::io::Result<String> where R: AsyncReadExt + Unpin, - S: RangeBounds<u64>, { let bytes = read_bytes(r, allowed_size).await?; String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e)) @@ -132,7 +139,7 @@ mod tests { /// The maximum length of bytes packets we're willing to accept in the test /// cases. - const MAX_LEN: u64 = 1024; + const MAX_LEN: usize = 1024; #[tokio::test] async fn test_read_8_bytes() { @@ -143,10 +150,7 @@ mod tests { assert_eq!( &12345678u64.to_le_bytes(), - read_bytes(&mut mock, 0u64..MAX_LEN) - .await - .unwrap() - .as_slice() + read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } @@ -159,10 +163,7 @@ mod tests { assert_eq!( hex!("010203040506070809"), - read_bytes(&mut mock, 0u64..MAX_LEN) - .await - .unwrap() - .as_slice() + read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } @@ -174,10 +175,7 @@ mod tests { assert_eq!( hex!(""), - read_bytes(&mut mock, 0u64..MAX_LEN) - .await - .unwrap() - .as_slice() + read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } @@ -187,7 +185,7 @@ mod tests { async fn test_read_reject_too_large() { let mut mock = Builder::new().read(&100u64.to_le_bytes()).build(); - read_bytes(&mut mock, 10..10) + read_bytes(&mut mock, 10..=10) .await .expect_err("expect this to fail"); } |