diff options
author | edef <edef@edef.eu> | 2024-04-25T22·43+0000 |
---|---|---|
committer | edef <edef@edef.eu> | 2024-04-25T23·47+0000 |
commit | 70c679eac4485d31b1676531a19b170e1661a04d (patch) | |
tree | 5c40d7c574abd252d29cb8bd23bc92ac50d04c9a /tvix/nix-compat/src/wire | |
parent | 859bfcb68b3ba333e343a98a3fb93ca9bf3d0006 (diff) |
feat(nix-compat/wire/bytes): allow specifying a pre-read size r/8011
Change-Id: I9c94239c308cfbc2e6dae871ba77fb33507433c9 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11517 Tested-by: BuildkiteCI Reviewed-by: flokli <flokli@flokli.de>
Diffstat (limited to 'tvix/nix-compat/src/wire')
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/reader.rs | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/bytes/reader.rs b/tvix/nix-compat/src/wire/bytes/reader.rs index 4239b40fec2d..18a8c6c686f0 100644 --- a/tvix/nix-compat/src/wire/bytes/reader.rs +++ b/tvix/nix-compat/src/wire/bytes/reader.rs @@ -61,6 +61,20 @@ where state: BytesPacketPosition::Size(0), } } + + /// Construct a new BytesReader with a known, and already-read size. + pub fn with_size(r: R, size: u64) -> Self { + Self { + inner: r, + allowed_size: size..=size, + payload_size: u64::to_le_bytes(size), + state: if size != 0 { + BytesPacketPosition::Payload(0) + } else { + BytesPacketPosition::Padding(0) + }, + } + } } /// Returns an error if the passed usize is 0. #[inline] @@ -261,6 +275,33 @@ mod tests { assert_eq!(payload, &buf[..]); } + /// Read bytes packets of various length, and ensure read_to_end returns the + /// expected payload. + #[rstest] + #[case::empty(&[])] // empty bytes packet + #[case::size_1b(&[0xff])] // 1 bytes payload + #[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding) + #[case::size_9b( &hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding) + #[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet + #[tokio::test] + async fn read_payload_correct_known(#[case] payload: &[u8]) { + let packet = produce_packet_bytes(payload).await; + + let size = u64::from_le_bytes({ + let mut buf = [0; 8]; + buf.copy_from_slice(&packet[..8]); + buf + }); + + let mut mock = Builder::new().read(&packet[8..]).build(); + + let mut r = BytesReader::with_size(&mut mock, size); + let mut buf = Vec::new(); + r.read_to_end(&mut buf).await.expect("must succeed"); + + assert_eq!(payload, &buf[..]); + } + /// Fail if the bytes packet is larger than allowed #[tokio::test] async fn read_bigger_than_allowed_fail() { |