diff options
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader/mod.rs | 21 |
1 files changed, 7 insertions, 14 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index 5f6081b40451..b46b0b53396b 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -1,7 +1,7 @@ use std::{ future::Future, io, - ops::{Bound, RangeBounds}, + ops::{Bound, RangeBounds, RangeInclusive}, pin::Pin, task::{self, ready, Poll}, }; @@ -37,10 +37,7 @@ enum State<R> { /// The data size is being read. Size { reader: Option<R>, - /// Minimum length (inclusive) - user_len_min: u64, - /// Maximum length (inclusive) - user_len_max: u64, + allowed_size: RangeInclusive<u64>, filled: u8, buf: [u8; 8], }, @@ -64,13 +61,11 @@ where { /// Constructs a new BytesReader, using the underlying passed reader. pub fn new<S: RangeBounds<u64>>(reader: R, allowed_size: S) -> Self { - let user_len_min = match allowed_size.start_bound() { + let allowed_size = match allowed_size.start_bound() { Bound::Included(&n) => n, Bound::Excluded(&n) => n.saturating_add(1), Bound::Unbounded => 0, - }; - - let user_len_max = match allowed_size.end_bound() { + }..=match allowed_size.end_bound() { Bound::Included(&n) => n, Bound::Excluded(&n) => n.checked_sub(1).unwrap(), Bound::Unbounded => u64::MAX, @@ -79,8 +74,7 @@ where Self { state: State::Size { reader: Some(reader), - user_len_min, - user_len_max, + allowed_size, filled: 0, buf: [0; 8], }, @@ -128,15 +122,14 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> { match this { State::Size { reader, - user_len_min, - user_len_max, + allowed_size, filled: 8, buf, } => { let reader = reader.take().unwrap(); let data_len = u64::from_le_bytes(*buf); - if data_len < *user_len_min || data_len > *user_len_max { + if !allowed_size.contains(&data_len) { return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size")) .into(); } |