about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/mod.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/mod.rs56
1 files changed, 27 insertions, 29 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs
index fc777bafe20f..740a7ebfd03e 100644
--- a/tvix/nix-compat/src/wire/bytes/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/mod.rs
@@ -1,6 +1,6 @@
 use std::{
     io::{Error, ErrorKind},
-    ops::RangeBounds,
+    ops::RangeInclusive,
 };
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
@@ -33,24 +33,29 @@ const LEN_SIZE: usize = 8;
 ///
 /// This buffers the entire payload into memory,
 /// a streaming version is available at [crate::wire::bytes::BytesReader].
-pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>>
+pub async fn read_bytes<R>(
+    r: &mut R,
+    allowed_size: RangeInclusive<usize>,
+) -> std::io::Result<Vec<u8>>
 where
     R: AsyncReadExt + Unpin,
-    S: RangeBounds<u64>,
 {
     // read the length field
     let len = r.read_u64_le().await?;
-
-    if !allowed_size.contains(&len) {
-        return Err(std::io::Error::new(
-            std::io::ErrorKind::InvalidData,
-            "signalled package size not in allowed range",
-        ));
-    }
+    let len: usize = len
+        .try_into()
+        .ok()
+        .filter(|len| allowed_size.contains(len))
+        .ok_or_else(|| {
+            std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                "signalled package size not in allowed range",
+            )
+        })?;
 
     // calculate the total length, including padding.
     // byte packets are padded to 8 byte blocks each.
-    let padded_len = padding_len(len) as u64 + len;
+    let padded_len = padding_len(len as u64) as u64 + (len as u64);
     let mut limited_reader = r.take(padded_len);
 
     let mut buf = Vec::new();
@@ -62,7 +67,7 @@ where
         return Err(std::io::ErrorKind::UnexpectedEof.into());
     }
 
-    let (_content, padding) = buf.split_at(len as usize);
+    let (_content, padding) = buf.split_at(len);
 
     // ensure the padding is all zeroes.
     if !padding.iter().all(|e| *e == b'\0') {
@@ -73,17 +78,19 @@ where
     }
 
     // return the data without the padding
-    buf.truncate(len as usize);
+    buf.truncate(len);
     Ok(buf)
 }
 
 /// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
 /// Internally uses [read_bytes].
 /// Rejects reading more than `allowed_size` bytes of payload.
-pub async fn read_string<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<String>
+pub async fn read_string<R>(
+    r: &mut R,
+    allowed_size: RangeInclusive<usize>,
+) -> std::io::Result<String>
 where
     R: AsyncReadExt + Unpin,
-    S: RangeBounds<u64>,
 {
     let bytes = read_bytes(r, allowed_size).await?;
     String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e))
@@ -132,7 +139,7 @@ mod tests {
 
     /// The maximum length of bytes packets we're willing to accept in the test
     /// cases.
-    const MAX_LEN: u64 = 1024;
+    const MAX_LEN: usize = 1024;
 
     #[tokio::test]
     async fn test_read_8_bytes() {
@@ -143,10 +150,7 @@ mod tests {
 
         assert_eq!(
             &12345678u64.to_le_bytes(),
-            read_bytes(&mut mock, 0u64..MAX_LEN)
-                .await
-                .unwrap()
-                .as_slice()
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
         );
     }
 
@@ -159,10 +163,7 @@ mod tests {
 
         assert_eq!(
             hex!("010203040506070809"),
-            read_bytes(&mut mock, 0u64..MAX_LEN)
-                .await
-                .unwrap()
-                .as_slice()
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
         );
     }
 
@@ -174,10 +175,7 @@ mod tests {
 
         assert_eq!(
             hex!(""),
-            read_bytes(&mut mock, 0u64..MAX_LEN)
-                .await
-                .unwrap()
-                .as_slice()
+            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
         );
     }
 
@@ -187,7 +185,7 @@ mod tests {
     async fn test_read_reject_too_large() {
         let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
 
-        read_bytes(&mut mock, 10..10)
+        read_bytes(&mut mock, 10..=10)
             .await
             .expect_err("expect this to fail");
     }