diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/reader.rs')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader.rs | 32 |
1 files changed, 20 insertions, 12 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader.rs b/tvix/nix-compat/src/wire/bytes/reader.rs index 9aea677645e7..9e8bd8704bbe 100644 --- a/tvix/nix-compat/src/wire/bytes/reader.rs +++ b/tvix/nix-compat/src/wire/bytes/reader.rs @@ -1,6 +1,6 @@ use pin_project_lite::pin_project; use std::{ - ops::RangeBounds, + ops::{Bound, RangeBounds, RangeInclusive}, task::{ready, Poll}, }; use tokio::io::AsyncRead; @@ -29,31 +29,40 @@ pin_project! { /// In case of an error due to size constraints, or in case of not reading /// all the way to the end (and getting a EOF), the underlying reader is no /// longer usable and might return garbage. - pub struct BytesReader<R, S> + pub struct BytesReader<R> where - R: AsyncRead, - S: RangeBounds<u64>, - + R: AsyncRead { #[pin] inner: R, - allowed_size: S, + allowed_size: RangeInclusive<u64>, payload_size: [u8; 8], state: BytesPacketPosition, } } -impl<R, S> BytesReader<R, S> +impl<R> BytesReader<R> where R: AsyncRead + Unpin, - S: RangeBounds<u64>, { /// Constructs a new BytesReader, using the underlying passed reader. - pub fn new(r: R, allowed_size: S) -> Self { + pub fn new<S: RangeBounds<u64>>(r: R, allowed_size: S) -> Self { + let user_len_min = match allowed_size.start_bound() { + Bound::Included(&n) => n, + Bound::Excluded(&n) => n.saturating_add(1), + Bound::Unbounded => 0, + }; + + let user_len_max = match allowed_size.end_bound() { + Bound::Included(&n) => n, + Bound::Excluded(&n) => n.checked_sub(1).unwrap(), + Bound::Unbounded => u64::MAX, + }; + Self { inner: r, - allowed_size, + allowed_size: user_len_min..=user_len_max, payload_size: [0; 8], state: BytesPacketPosition::Size(0), } @@ -72,10 +81,9 @@ fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error> } } -impl<R, S> AsyncRead for BytesReader<R, S> +impl<R> AsyncRead for BytesReader<R> where R: AsyncRead, - S: RangeBounds<u64>, { fn poll_read( self: std::pin::Pin<&mut Self>, |