diff options
Diffstat (limited to 'tvix/nix-compat/src/wire/bytes/mod.rs')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/mod.rs | 254 |
1 files changed, 254 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs new file mode 100644 index 000000000000..9487536eb720 --- /dev/null +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -0,0 +1,254 @@ +use std::{ + io::{Error, ErrorKind}, + ops::RangeBounds, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +mod reader; +pub use reader::BytesReader; +mod writer; +pub use writer::BytesWriter; + +use super::primitive; + +/// 8 null bytes, used to write out padding. +const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; + +/// The length of the size field, in bytes is always 8. +const LEN_SIZE: usize = 8; + +#[allow(dead_code)] +/// Read a "bytes wire packet" from the AsyncRead. +/// Rejects reading more than `allowed_size` bytes of payload. +/// +/// The packet is made up of three parts: +/// - a length header, u64, LE-encoded +/// - the payload itself +/// - null bytes to the next 8 byte boundary +/// +/// Ensures the payload size fits into the `allowed_size` passed, +/// and that the padding is actual null bytes. +/// +/// On success, the returned `Vec<u8>` only contains the payload itself. +/// On failure (for example if a too large byte packet was sent), the reader +/// becomes unusable. +/// +/// This buffers the entire payload into memory, a streaming version will be +/// added later. +pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>> +where + R: AsyncReadExt + Unpin, + S: RangeBounds<u64>, +{ + // read the length field + let len = primitive::read_u64(r).await?; + + if !allowed_size.contains(&len) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "signalled package size not in allowed range", + )); + } + + // calculate the total length, including padding. + // byte packets are padded to 8 byte blocks each. + let padded_len = padding_len(len) as u64 + (len as u64); + let mut limited_reader = r.take(padded_len); + + let mut buf = Vec::new(); + + let s = limited_reader.read_to_end(&mut buf).await?; + + // make sure we got exactly the number of bytes, and not less. + if s as u64 != padded_len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "got less bytes than expected", + )); + } + + let (_content, padding) = buf.split_at(len as usize); + + // ensure the padding is all zeroes. + if !padding.iter().all(|e| *e == b'\0') { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "padding is not all zeroes", + )); + } + + // return the data without the padding + buf.truncate(len as usize); + Ok(buf) +} + +/// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string. +/// Internally uses [read_bytes]. +/// Rejects reading more than `allowed_size` bytes of payload. +pub async fn read_string<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<String> +where + R: AsyncReadExt + Unpin, + S: RangeBounds<u64>, +{ + let bytes = read_bytes(r, allowed_size).await?; + String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e)) +} + +/// Writes a "bytes wire packet" to a (hopefully buffered) [AsyncWriteExt]. +/// +/// Accepts anything implementing AsRef<[u8]> as payload. +/// +/// See [read_bytes] for a description of the format. +/// +/// Note: if performance matters to you, make sure your +/// [AsyncWriteExt] handle is buffered. This function is quite +/// write-intesive. +pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>( + w: &mut W, + b: B, +) -> std::io::Result<()> { + // write the size packet. + primitive::write_u64(w, b.as_ref().len() as u64).await?; + + // write the payload + w.write_all(b.as_ref()).await?; + + // write padding if needed + let padding_len = padding_len(b.as_ref().len() as u64) as usize; + if padding_len != 0 { + w.write_all(&EMPTY_BYTES[..padding_len]).await?; + } + Ok(()) +} + +/// Computes the number of bytes we should add to len (a length in +/// bytes) to be alined on 64 bits (8 bytes). +fn padding_len(len: u64) -> u8 { + let modulo = len % 8; + if modulo == 0 { + 0 + } else { + 8 - modulo as u8 + } +} + +/// Models the position inside a "bytes wire packet" that the reader or writer +/// is in. +/// It can be in three different stages, inside size, payload or padding fields. +/// The number tracks the number of bytes written inside the specific field. +/// There shall be no ambiguous states, at the end of a stage we immediately +/// move to the beginning of the next one: +/// - Size(LEN_SIZE) must be expressed as Payload(0) +/// - Payload(self.payload_len) must be expressed as Padding(0) +/// There's one exception - Size(LEN_SIZE) in the reader represents a failure +/// state we enter in case the allowed size doesn't match the allowed range. +/// +/// Padding(padding_len) means we're at the end of the bytes wire packet. +#[derive(Clone, Debug, PartialEq, Eq)] +enum BytesPacketPosition { + Size(usize), + Payload(u64), + Padding(usize), +} + +#[cfg(test)] +mod tests { + use tokio_test::{assert_ok, io::Builder}; + + use super::*; + use hex_literal::hex; + + /// The maximum length of bytes packets we're willing to accept in the test + /// cases. + const MAX_LEN: u64 = 1024; + + #[tokio::test] + async fn test_read_8_bytes() { + let mut mock = Builder::new() + .read(&8u64.to_le_bytes()) + .read(&12345678u64.to_le_bytes()) + .build(); + + assert_eq!( + &12345678u64.to_le_bytes(), + read_bytes(&mut mock, 0u64..MAX_LEN) + .await + .unwrap() + .as_slice() + ); + } + + #[tokio::test] + async fn test_read_9_bytes() { + let mut mock = Builder::new() + .read(&9u64.to_le_bytes()) + .read(&hex!("01020304050607080900000000000000")) + .build(); + + assert_eq!( + hex!("010203040506070809"), + read_bytes(&mut mock, 0u64..MAX_LEN) + .await + .unwrap() + .as_slice() + ); + } + + #[tokio::test] + async fn test_read_0_bytes() { + // A empty byte packet is essentially just the 0 length field. + // No data is read, and there's zero padding. + let mut mock = Builder::new().read(&0u64.to_le_bytes()).build(); + + assert_eq!( + hex!(""), + read_bytes(&mut mock, 0u64..MAX_LEN) + .await + .unwrap() + .as_slice() + ); + } + + #[tokio::test] + /// Ensure we don't read any further than the size field if the length + /// doesn't match the range we want to accept. + async fn test_read_reject_too_large() { + let mut mock = Builder::new().read(&100u64.to_le_bytes()).build(); + + read_bytes(&mut mock, 10..10) + .await + .expect_err("expect this to fail"); + } + + #[tokio::test] + async fn test_write_bytes_no_padding() { + let input = hex!("6478696f34657661"); + let len = input.len() as u64; + let mut mock = Builder::new() + .write(&len.to_le_bytes()) + .write(&input) + .build(); + assert_ok!(write_bytes(&mut mock, &input).await) + } + #[tokio::test] + async fn test_write_bytes_with_padding() { + let input = hex!("322e332e3137"); + let len = input.len() as u64; + let mut mock = Builder::new() + .write(&len.to_le_bytes()) + .write(&hex!("322e332e31370000")) + .build(); + assert_ok!(write_bytes(&mut mock, &input).await) + } + + #[tokio::test] + async fn test_write_string() { + let input = "Hello, World!"; + let len = input.len() as u64; + let mut mock = Builder::new() + .write(&len.to_le_bytes()) + .write(&hex!("48656c6c6f2c20576f726c6421000000")) + .build(); + assert_ok!(write_bytes(&mut mock, &input).await) + } +} |