diff options
Diffstat (limited to 'tvix/nix-compat/src/wire')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes.rs | 71 |
1 files changed, 62 insertions, 9 deletions
diff --git a/tvix/nix-compat/src/wire/bytes.rs b/tvix/nix-compat/src/wire/bytes.rs index f2fe30083b1c..a050b161048b 100644 --- a/tvix/nix-compat/src/wire/bytes.rs +++ b/tvix/nix-compat/src/wire/bytes.rs @@ -1,6 +1,6 @@ use std::ops::RangeBounds; -use tokio::io::AsyncReadExt; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::primitive; @@ -28,12 +28,7 @@ where // calculate the total length, including padding. // byte packets are padded to 8 byte blocks each. - let padded_len = if len % 8 == 0 { - len - } else { - len + (8 - len % 8) - }; - + let padded_len = padding_len(len) as u64 + (len as u64); let mut limited_reader = r.take(padded_len); let mut buf = Vec::new(); @@ -63,6 +58,32 @@ where Ok(buf) } +/// Writes a sequence of sized bits to a (hopefully buffered) +/// [AsyncWriteExt] handle. +/// +/// On the wire, it looks as follows: +/// +/// 1. Number of bytes contained in the buffer we're about to write on +/// the wire. (LE-encoded on 64 bits) +/// 2. Raw payload. +/// 3. Null padding up until the next 8 bytes alignment block. +/// +/// Note: if performance matters to you, make sure your +/// [AsyncWriteExt] handle is buffered. This function is quite +/// write-intesive. +pub async fn write_bytes<W: AsyncWriteExt + Unpin>(w: &mut W, b: &[u8]) -> std::io::Result<()> { + // We're assuming the handle is buffered: we can afford not + // writing all the bytes in one go. + let len = b.len(); + primitive::write_u64(w, len as u64).await?; + w.write_all(b).await?; + let padding = padding_len(len as u64); + if padding != 0 { + w.write_all(&vec![0; padding as usize]).await?; + } + Ok(()) +} + #[allow(dead_code)] /// Read an unlimited number of bytes from the AsyncRead. /// Note this can exhaust memory. @@ -72,9 +93,20 @@ pub async fn read_bytes_unchecked<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io read_bytes(r, 0u64..).await } +/// Computes the number of bytes we should add to len (a length in +/// bytes) to be alined on 64 bits (8 bytes). +fn padding_len(len: u64) -> u8 { + let modulo = len % 8; + if modulo == 0 { + 0 + } else { + 8 - modulo as u8 + } +} + #[cfg(test)] mod tests { - use tokio_test::io::Builder; + use tokio_test::{assert_ok, io::Builder}; use super::*; use hex_literal::hex; @@ -120,11 +152,32 @@ mod tests { #[tokio::test] /// Ensure we don't read any further than the size field if the length /// doesn't match the range we want to accept. - async fn test_reject_too_large() { + async fn test_read_reject_too_large() { let mut mock = Builder::new().read(&100u64.to_le_bytes()).build(); read_bytes(&mut mock, 10..10) .await .expect_err("expect this to fail"); } + + #[tokio::test] + async fn test_write_bytes_no_padding() { + let input = hex!("6478696f34657661"); + let len = input.len() as u64; + let mut mock = Builder::new() + .write(&len.to_le_bytes()) + .write(&input) + .build(); + assert_ok!(write_bytes(&mut mock, &input).await) + } + #[tokio::test] + async fn test_write_bytes_with_padding() { + let input = hex!("322e332e3137"); + let len = input.len() as u64; + let mut mock = Builder::new() + .write(&len.to_le_bytes()) + .write(&hex!("322e332e31370000")) + .build(); + assert_ok!(write_bytes(&mut mock, &input).await) + } } |