about summary refs log tree commit diff
path: root/tvix/nix-compat/src/wire/bytes/reader/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader/mod.rs')
-rw-r--r--tvix/nix-compat/src/wire/bytes/reader/mod.rs21
1 files changed, 7 insertions, 14 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
index 5f6081b40451..b46b0b53396b 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs
@@ -1,7 +1,7 @@
 use std::{
     future::Future,
     io,
-    ops::{Bound, RangeBounds},
+    ops::{Bound, RangeBounds, RangeInclusive},
     pin::Pin,
     task::{self, ready, Poll},
 };
@@ -37,10 +37,7 @@ enum State<R> {
     /// The data size is being read.
     Size {
         reader: Option<R>,
-        /// Minimum length (inclusive)
-        user_len_min: u64,
-        /// Maximum length (inclusive)
-        user_len_max: u64,
+        allowed_size: RangeInclusive<u64>,
         filled: u8,
         buf: [u8; 8],
     },
@@ -64,13 +61,11 @@ where
 {
     /// Constructs a new BytesReader, using the underlying passed reader.
     pub fn new<S: RangeBounds<u64>>(reader: R, allowed_size: S) -> Self {
-        let user_len_min = match allowed_size.start_bound() {
+        let allowed_size = match allowed_size.start_bound() {
             Bound::Included(&n) => n,
             Bound::Excluded(&n) => n.saturating_add(1),
             Bound::Unbounded => 0,
-        };
-
-        let user_len_max = match allowed_size.end_bound() {
+        }..=match allowed_size.end_bound() {
             Bound::Included(&n) => n,
             Bound::Excluded(&n) => n.checked_sub(1).unwrap(),
             Bound::Unbounded => u64::MAX,
@@ -79,8 +74,7 @@ where
         Self {
             state: State::Size {
                 reader: Some(reader),
-                user_len_min,
-                user_len_max,
+                allowed_size,
                 filled: 0,
                 buf: [0; 8],
             },
@@ -128,15 +122,14 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
             match this {
                 State::Size {
                     reader,
-                    user_len_min,
-                    user_len_max,
+                    allowed_size,
                     filled: 8,
                     buf,
                 } => {
                     let reader = reader.take().unwrap();
 
                     let data_len = u64::from_le_bytes(*buf);
-                    if data_len < *user_len_min || data_len > *user_len_max {
+                    if !allowed_size.contains(&data_len) {
                         return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size"))
                             .into();
                     }