From 70c679eac4485d31b1676531a19b170e1661a04d Mon Sep 17 00:00:00 2001 From: edef Date: Thu, 25 Apr 2024 22:43:59 +0000 Subject: feat(nix-compat/wire/bytes): allow specifying a pre-read size Change-Id: I9c94239c308cfbc2e6dae871ba77fb33507433c9 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11517 Tested-by: BuildkiteCI Reviewed-by: flokli --- tvix/nix-compat/src/wire/bytes/reader.rs | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tvix/nix-compat/src/wire/bytes/reader.rs b/tvix/nix-compat/src/wire/bytes/reader.rs index 4239b40fec2d..18a8c6c686f0 100644 --- a/tvix/nix-compat/src/wire/bytes/reader.rs +++ b/tvix/nix-compat/src/wire/bytes/reader.rs @@ -61,6 +61,20 @@ where state: BytesPacketPosition::Size(0), } } + + /// Construct a new BytesReader with a known, and already-read size. + pub fn with_size(r: R, size: u64) -> Self { + Self { + inner: r, + allowed_size: size..=size, + payload_size: u64::to_le_bytes(size), + state: if size != 0 { + BytesPacketPosition::Payload(0) + } else { + BytesPacketPosition::Padding(0) + }, + } + } } /// Returns an error if the passed usize is 0. #[inline] @@ -261,6 +275,33 @@ mod tests { assert_eq!(payload, &buf[..]); } + /// Read bytes packets of various length, and ensure read_to_end returns the + /// expected payload. + #[rstest] + #[case::empty(&[])] // empty bytes packet + #[case::size_1b(&[0xff])] // 1 bytes payload + #[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding) + #[case::size_9b( &hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding) + #[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet + #[tokio::test] + async fn read_payload_correct_known(#[case] payload: &[u8]) { + let packet = produce_packet_bytes(payload).await; + + let size = u64::from_le_bytes({ + let mut buf = [0; 8]; + buf.copy_from_slice(&packet[..8]); + buf + }); + + let mut mock = Builder::new().read(&packet[8..]).build(); + + let mut r = BytesReader::with_size(&mut mock, size); + let mut buf = Vec::new(); + r.read_to_end(&mut buf).await.expect("must succeed"); + + assert_eq!(payload, &buf[..]); + } + /// Fail if the bytes packet is larger than allowed #[tokio::test] async fn read_bigger_than_allowed_fail() { -- cgit 1.4.1