about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader.rs32
1 files changed, 20 insertions, 12 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader.rs b/tvix/nix-compat/src/wire/bytes/reader.rs
index 9aea677645e7..9e8bd8704bbe 100644
--- a/tvix/nix-compat/src/wire/bytes/reader.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader.rs
@@ -1,6 +1,6 @@
 use pin_project_lite::pin_project;
 use std::{
-    ops::RangeBounds,
+    ops::{Bound, RangeBounds, RangeInclusive},
     task::{ready, Poll},
 };
 use tokio::io::AsyncRead;
@@ -29,31 +29,40 @@ pin_project! {
     /// In case of an error due to size constraints, or in case of not reading
     /// all the way to the end (and getting a EOF), the underlying reader is no
     /// longer usable and might return garbage.
-    pub struct BytesReader<R, S>
+    pub struct BytesReader<R>
     where
-    R: AsyncRead,
-    S: RangeBounds<u64>,
-
+    R: AsyncRead
     {
         #[pin]
         inner: R,
 
-        allowed_size: S,
+        allowed_size: RangeInclusive<u64>,
         payload_size: [u8; 8],
         state: BytesPacketPosition,
     }
 }
 
-impl<R, S> BytesReader<R, S>
+impl<R> BytesReader<R>
 where
     R: AsyncRead + Unpin,
-    S: RangeBounds<u64>,
 {
     /// Constructs a new BytesReader, using the underlying passed reader.
-    pub fn new(r: R, allowed_size: S) -> Self {
+    pub fn new<S: RangeBounds<u64>>(r: R, allowed_size: S) -> Self {
+        let user_len_min = match allowed_size.start_bound() {
+            Bound::Included(&n) => n,
+            Bound::Excluded(&n) => n.saturating_add(1),
+            Bound::Unbounded => 0,
+        };
+
+        let user_len_max = match allowed_size.end_bound() {
+            Bound::Included(&n) => n,
+            Bound::Excluded(&n) => n.checked_sub(1).unwrap(),
+            Bound::Unbounded => u64::MAX,
+        };
+
         Self {
             inner: r,
-            allowed_size,
+            allowed_size: user_len_min..=user_len_max,
             payload_size: [0; 8],
             state: BytesPacketPosition::Size(0),
         }
@@ -72,10 +81,9 @@ fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error>
     }
 }
 
-impl<R, S> AsyncRead for BytesReader<R, S>
+impl<R> AsyncRead for BytesReader<R>
 where
     R: AsyncRead,
-    S: RangeBounds<u64>,
 {
     fn poll_read(
         self: std::pin::Pin<&mut Self>,