diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/mod.rs')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/mod.rs | 179 |
1 files changed, 105 insertions, 74 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index 9487536eb720..74adfb49b6a4 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -1,23 +1,24 @@ +#[cfg(feature = "async")] +use std::mem::MaybeUninit; use std::{ io::{Error, ErrorKind}, - ops::RangeBounds, + ops::RangeInclusive, }; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +#[cfg(feature = "async")] +use tokio::io::ReadBuf; +use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; -mod reader; +pub(crate) mod reader; pub use reader::BytesReader; mod writer; pub use writer::BytesWriter; -use super::primitive; - /// 8 null bytes, used to write out padding. -const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; +pub(crate) const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; /// The length of the size field, in bytes is always 8. const LEN_SIZE: usize = 8; -#[allow(dead_code)] /// Read a "bytes wire packet" from the AsyncRead. /// Rejects reading more than `allowed_size` bytes of payload. /// @@ -33,26 +34,28 @@ const LEN_SIZE: usize = 8; /// On failure (for example if a too large byte packet was sent), the reader /// becomes unusable. /// -/// This buffers the entire payload into memory, a streaming version will be -/// added later. -pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>> +/// This buffers the entire payload into memory, +/// a streaming version is available at [crate::wire::bytes::BytesReader]. +pub async fn read_bytes<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<Vec<u8>> where - R: AsyncReadExt + Unpin, - S: RangeBounds<u64>, + R: AsyncReadExt + Unpin + ?Sized, { // read the length field - let len = primitive::read_u64(r).await?; - - if !allowed_size.contains(&len) { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "signalled package size not in allowed range", - )); - } + let len = r.read_u64_le().await?; + let len: usize = len + .try_into() + .ok() + .filter(|len| allowed_size.contains(len)) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "signalled package size not in allowed range", + ) + })?; // calculate the total length, including padding. // byte packets are padded to 8 byte blocks each. - let padded_len = padding_len(len) as u64 + (len as u64); + let padded_len = padding_len(len as u64) as u64 + (len as u64); let mut limited_reader = r.take(padded_len); let mut buf = Vec::new(); @@ -61,34 +64,89 @@ where // make sure we got exactly the number of bytes, and not less. if s as u64 != padded_len { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "got less bytes than expected", - )); + return Err(io::ErrorKind::UnexpectedEof.into()); } - let (_content, padding) = buf.split_at(len as usize); + let (_content, padding) = buf.split_at(len); // ensure the padding is all zeroes. - if !padding.iter().all(|e| *e == b'\0') { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, + if padding.iter().any(|&b| b != 0) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, "padding is not all zeroes", )); } // return the data without the padding - buf.truncate(len as usize); + buf.truncate(len); Ok(buf) } +#[cfg(feature = "async")] +pub(crate) async fn read_bytes_buf<'a, const N: usize, R>( + reader: &mut R, + buf: &'a mut [MaybeUninit<u8>; N], + allowed_size: RangeInclusive<usize>, +) -> io::Result<&'a [u8]> +where + R: AsyncReadExt + Unpin + ?Sized, +{ + assert_eq!(N % 8, 0); + assert!(*allowed_size.end() <= N); + + let len = reader.read_u64_le().await?; + let len: usize = len + .try_into() + .ok() + .filter(|len| allowed_size.contains(len)) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "signalled package size not in allowed range", + ) + })?; + + let buf_len = (len + 7) & !7; + let buf = { + let mut read_buf = ReadBuf::uninit(&mut buf[..buf_len]); + + while read_buf.filled().len() < buf_len { + reader.read_buf(&mut read_buf).await?; + } + + // ReadBuf::filled does not pass the underlying buffer's lifetime through, + // so we must make a trip to hell. + // + // SAFETY: `read_buf` is filled up to `buf_len`, and we verify that it is + // still pointing at the same underlying buffer. + unsafe { + assert_eq!(read_buf.filled().as_ptr(), buf.as_ptr() as *const u8); + assume_init_bytes(&buf[..buf_len]) + } + }; + + if buf[len..buf_len].iter().any(|&b| b != 0) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "padding is not all zeroes", + )); + } + + Ok(&buf[..len]) +} + +/// SAFETY: The bytes have to actually be initialized. +#[cfg(feature = "async")] +unsafe fn assume_init_bytes(slice: &[MaybeUninit<u8>]) -> &[u8] { + &*(slice as *const [MaybeUninit<u8>] as *const [u8]) +} + /// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string. /// Internally uses [read_bytes]. /// Rejects reading more than `allowed_size` bytes of payload. -pub async fn read_string<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<String> +pub async fn read_string<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<String> where R: AsyncReadExt + Unpin, - S: RangeBounds<u64>, { let bytes = read_bytes(r, allowed_size).await?; String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e)) @@ -106,9 +164,9 @@ where pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>( w: &mut W, b: B, -) -> std::io::Result<()> { +) -> io::Result<()> { // write the size packet. - primitive::write_u64(w, b.as_ref().len() as u64).await?; + w.write_u64_le(b.as_ref().len() as u64).await?; // write the payload w.write_all(b.as_ref()).await?; @@ -122,33 +180,10 @@ pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>( } /// Computes the number of bytes we should add to len (a length in -/// bytes) to be alined on 64 bits (8 bytes). +/// bytes) to be aligned on 64 bits (8 bytes). fn padding_len(len: u64) -> u8 { - let modulo = len % 8; - if modulo == 0 { - 0 - } else { - 8 - modulo as u8 - } -} - -/// Models the position inside a "bytes wire packet" that the reader or writer -/// is in. -/// It can be in three different stages, inside size, payload or padding fields. -/// The number tracks the number of bytes written inside the specific field. -/// There shall be no ambiguous states, at the end of a stage we immediately -/// move to the beginning of the next one: -/// - Size(LEN_SIZE) must be expressed as Payload(0) -/// - Payload(self.payload_len) must be expressed as Padding(0) -/// There's one exception - Size(LEN_SIZE) in the reader represents a failure -/// state we enter in case the allowed size doesn't match the allowed range. -/// -/// Padding(padding_len) means we're at the end of the bytes wire packet. -#[derive(Clone, Debug, PartialEq, Eq)] -enum BytesPacketPosition { - Size(usize), - Payload(u64), - Padding(usize), + let aligned = len.wrapping_add(7) & !7; + aligned.wrapping_sub(len) as u8 } #[cfg(test)] @@ -160,7 +195,7 @@ mod tests { /// The maximum length of bytes packets we're willing to accept in the test /// cases. - const MAX_LEN: u64 = 1024; + const MAX_LEN: usize = 1024; #[tokio::test] async fn test_read_8_bytes() { @@ -171,10 +206,7 @@ mod tests { assert_eq!( &12345678u64.to_le_bytes(), - read_bytes(&mut mock, 0u64..MAX_LEN) - .await - .unwrap() - .as_slice() + read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } @@ -187,10 +219,7 @@ mod tests { assert_eq!( hex!("010203040506070809"), - read_bytes(&mut mock, 0u64..MAX_LEN) - .await - .unwrap() - .as_slice() + read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } @@ -202,10 +231,7 @@ mod tests { assert_eq!( hex!(""), - read_bytes(&mut mock, 0u64..MAX_LEN) - .await - .unwrap() - .as_slice() + read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice() ); } @@ -215,7 +241,7 @@ mod tests { async fn test_read_reject_too_large() { let mut mock = Builder::new().read(&100u64.to_le_bytes()).build(); - read_bytes(&mut mock, 10..10) + read_bytes(&mut mock, 10..=10) .await .expect_err("expect this to fail"); } @@ -251,4 +277,9 @@ mod tests { .build(); assert_ok!(write_bytes(&mut mock, &input).await) } + + #[test] + fn padding_len_u64_max() { + assert_eq!(padding_len(u64::MAX), 1); + } } |