From 84b27760d0e3260434656e15e82362dcea39319f Mon Sep 17 00:00:00 2001 From: edef Date: Mon, 29 Apr 2024 14:56:20 +0000 Subject: refactor(tvix/nix-compat/wire/bytes): use RangeInclusive for limits The (min, max) pair is already a RangeInclusive in essence, so we might as well represent it that way. Change-Id: I2f67f3c47dc36b87e866ff5dc2e0cd28f01fbb04 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11540 Tested-by: BuildkiteCI Reviewed-by: flokli --- tvix/nix-compat/src/wire/bytes/reader/mod.rs | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) (limited to 'tvix') diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index 5f6081b40451..b46b0b53396b 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -1,7 +1,7 @@ use std::{ future::Future, io, - ops::{Bound, RangeBounds}, + ops::{Bound, RangeBounds, RangeInclusive}, pin::Pin, task::{self, ready, Poll}, }; @@ -37,10 +37,7 @@ enum State { /// The data size is being read. Size { reader: Option, - /// Minimum length (inclusive) - user_len_min: u64, - /// Maximum length (inclusive) - user_len_max: u64, + allowed_size: RangeInclusive, filled: u8, buf: [u8; 8], }, @@ -64,13 +61,11 @@ where { /// Constructs a new BytesReader, using the underlying passed reader. pub fn new>(reader: R, allowed_size: S) -> Self { - let user_len_min = match allowed_size.start_bound() { + let allowed_size = match allowed_size.start_bound() { Bound::Included(&n) => n, Bound::Excluded(&n) => n.saturating_add(1), Bound::Unbounded => 0, - }; - - let user_len_max = match allowed_size.end_bound() { + }..=match allowed_size.end_bound() { Bound::Included(&n) => n, Bound::Excluded(&n) => n.checked_sub(1).unwrap(), Bound::Unbounded => u64::MAX, @@ -79,8 +74,7 @@ where Self { state: State::Size { reader: Some(reader), - user_len_min, - user_len_max, + allowed_size, filled: 0, buf: [0; 8], }, @@ -128,15 +122,14 @@ impl AsyncRead for BytesReader { match this { State::Size { reader, - user_len_min, - user_len_max, + allowed_size, filled: 8, buf, } => { let reader = reader.take().unwrap(); let data_len = u64::from_le_bytes(*buf); - if data_len < *user_len_min || data_len > *user_len_max { + if !allowed_size.contains(&data_len) { return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size")) .into(); } -- cgit 1.4.1