about summary refs log tree commit diff
path: root/tvix
diff options
context:
space:
mode:
Diffstat (limited to 'tvix')
-rw-r--r--tvix/nix-compat/src/wire/bytes.rs71
1 files changed, 62 insertions, 9 deletions
diff --git a/tvix/nix-compat/src/wire/bytes.rs b/tvix/nix-compat/src/wire/bytes.rs
index f2fe30083b1c..a050b161048b 100644
--- a/tvix/nix-compat/src/wire/bytes.rs
+++ b/tvix/nix-compat/src/wire/bytes.rs
@@ -1,6 +1,6 @@
 use std::ops::RangeBounds;
 
-use tokio::io::AsyncReadExt;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
 use super::primitive;
 
@@ -28,12 +28,7 @@ where
 
     // calculate the total length, including padding.
     // byte packets are padded to 8 byte blocks each.
-    let padded_len = if len % 8 == 0 {
-        len
-    } else {
-        len + (8 - len % 8)
-    };
-
+    let padded_len = padding_len(len) as u64 + (len as u64);
     let mut limited_reader = r.take(padded_len);
 
     let mut buf = Vec::new();
@@ -63,6 +58,32 @@ where
     Ok(buf)
 }
 
+/// Writes a sequence of sized bits to a (hopefully buffered)
+/// [AsyncWriteExt] handle.
+///
+/// On the wire, it looks as follows:
+///
+/// 1. Number of bytes contained in the buffer we're about to write on
+///    the wire. (LE-encoded on 64 bits)
+/// 2. Raw payload.
+/// 3. Null padding up until the next 8 bytes alignment block.
+///
+/// Note: if performance matters to you, make sure your
+/// [AsyncWriteExt] handle is buffered. This function is quite
+/// write-intesive.
+pub async fn write_bytes<W: AsyncWriteExt + Unpin>(w: &mut W, b: &[u8]) -> std::io::Result<()> {
+    // We're assuming the handle is buffered: we can afford not
+    // writing all the bytes in one go.
+    let len = b.len();
+    primitive::write_u64(w, len as u64).await?;
+    w.write_all(b).await?;
+    let padding = padding_len(len as u64);
+    if padding != 0 {
+        w.write_all(&vec![0; padding as usize]).await?;
+    }
+    Ok(())
+}
+
 #[allow(dead_code)]
 /// Read an unlimited number of bytes from the AsyncRead.
 /// Note this can exhaust memory.
@@ -72,9 +93,20 @@ pub async fn read_bytes_unchecked<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io
     read_bytes(r, 0u64..).await
 }
 
+/// Computes the number of bytes we should add to len (a length in
+/// bytes) to be alined on 64 bits (8 bytes).
+fn padding_len(len: u64) -> u8 {
+    let modulo = len % 8;
+    if modulo == 0 {
+        0
+    } else {
+        8 - modulo as u8
+    }
+}
+
 #[cfg(test)]
 mod tests {
-    use tokio_test::io::Builder;
+    use tokio_test::{assert_ok, io::Builder};
 
     use super::*;
     use hex_literal::hex;
@@ -120,11 +152,32 @@ mod tests {
     #[tokio::test]
     /// Ensure we don't read any further than the size field if the length
     /// doesn't match the range we want to accept.
-    async fn test_reject_too_large() {
+    async fn test_read_reject_too_large() {
         let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
 
         read_bytes(&mut mock, 10..10)
             .await
             .expect_err("expect this to fail");
     }
+
+    #[tokio::test]
+    async fn test_write_bytes_no_padding() {
+        let input = hex!("6478696f34657661");
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&input)
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
+    #[tokio::test]
+    async fn test_write_bytes_with_padding() {
+        let input = hex!("322e332e3137");
+        let len = input.len() as u64;
+        let mut mock = Builder::new()
+            .write(&len.to_le_bytes())
+            .write(&hex!("322e332e31370000"))
+            .build();
+        assert_ok!(write_bytes(&mut mock, &input).await)
+    }
 }