From 36b296609b0d4db633f78a6fa3685f61227af94a Mon Sep 17 00:00:00 2001 From: Florian Klink Date: Wed, 10 Apr 2024 14:54:11 +0300 Subject: refactor(tvix/nix-compat): reorganize wire and bytes Move everything bytes-related into its own module, and re-export both bytes and primitive in a flat space from wire/mod.rs. Expose this if a `wire` feature flag is set. We only have `async` stuff in here. Change-Id: Ia4ce4791f13a5759901cc9d6ce6bd6bbcca587c7 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11389 Autosubmit: flokli Reviewed-by: raitobezarius Tested-by: BuildkiteCI Reviewed-by: Brian Olsen --- tvix/nix-compat/src/wire/bytes_writer.rs | 543 ------------------------------- 1 file changed, 543 deletions(-) delete mode 100644 tvix/nix-compat/src/wire/bytes_writer.rs (limited to 'tvix/nix-compat/src/wire/bytes_writer.rs') diff --git a/tvix/nix-compat/src/wire/bytes_writer.rs b/tvix/nix-compat/src/wire/bytes_writer.rs deleted file mode 100644 index 8bd0a2d00ef6..000000000000 --- a/tvix/nix-compat/src/wire/bytes_writer.rs +++ /dev/null @@ -1,543 +0,0 @@ -use pin_project_lite::pin_project; -use std::task::{ready, Poll}; - -use tokio::io::AsyncWrite; - -use super::bytes::EMPTY_BYTES; - -/// The length of the size field, in bytes is always 8. -pub(crate) const LEN_SIZE: usize = 8; - -pin_project! { - /// Writes a "bytes wire packet" to the underlying writer. - /// The format is the same as in [crate::wire::bytes::write_bytes], - /// however this structure provides a [AsyncWrite] interface, - /// allowing to not having to pass around the entire payload in memory. - /// - /// It internally takes care of writing (non-payload) framing (size and - /// padding). - /// - /// During construction, the expected payload size needs to be provided. - /// - /// After writing the payload to it, the user MUST call flush (or shutdown), - /// which will validate the written payload size to match, and write the - /// necessary padding. - /// - /// In case flush is not called at the end, invalid data might be sent - /// silently. - /// - /// The underlying writer returning `Ok(0)` is considered an EOF situation, - /// which is stronger than the "typically means the underlying object is no - /// longer able to accept bytes" interpretation from the docs. If such a - /// situation occurs, an error is returned. - /// - /// The struct holds three fields, the underlying writer, the (expected) - /// payload length, and an enum, tracking the state. - pub struct BytesWriter - where - W: AsyncWrite, - { - #[pin] - inner: W, - payload_len: u64, - state: BytesPacketPosition, - } -} - -/// Models the position inside a "bytes wire packet" that the reader or writer -/// is in. -/// It can be in three different stages, inside size, payload or padding fields. -/// The number tracks the number of bytes written inside the specific field. -/// There shall be no ambiguous states, at the end of a stage we immediately -/// move to the beginning of the next one: -/// - Size(LEN_SIZE) must be expressed as Payload(0) -/// - Payload(self.payload_len) must be expressed as Padding(0) -/// There's one exception - Size(LEN_SIZE) in the reader represents a failure -/// state we enter in case the allowed size doesn't match the allowed range. -/// -/// Padding(padding_len) means we're at the end of the bytes wire packet. -#[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) enum BytesPacketPosition { - Size(usize), - Payload(u64), - Padding(usize), -} - -impl BytesWriter -where - W: AsyncWrite, -{ - /// Constructs a new BytesWriter, using the underlying passed writer. - pub fn new(w: W, payload_len: u64) -> Self { - Self { - inner: w, - payload_len, - state: BytesPacketPosition::Size(0), - } - } -} - -/// Returns an error if the passed usize is 0. -fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result { - if bytes_written == 0 { - Err(std::io::Error::new( - std::io::ErrorKind::WriteZero, - "underlying writer accepted 0 bytes", - )) - } else { - Ok(bytes_written) - } -} - -impl AsyncWrite for BytesWriter -where - W: AsyncWrite, -{ - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - // Use a loop, so we can deal with (multiple) state transitions. - let mut this = self.project(); - - loop { - match *this.state { - BytesPacketPosition::Size(LEN_SIZE) => unreachable!(), - BytesPacketPosition::Size(pos) => { - let size_field = &this.payload_len.to_le_bytes(); - - let bytes_written = ensure_nonzero_bytes_written(ready!(this - .inner - .as_mut() - .poll_write(cx, &size_field[pos..]))?)?; - - let new_pos = pos + bytes_written; - if new_pos == LEN_SIZE { - *this.state = BytesPacketPosition::Payload(0); - } else { - *this.state = BytesPacketPosition::Size(new_pos); - } - } - BytesPacketPosition::Payload(pos) => { - // Ensure we still have space for more payload - if pos + (buf.len() as u64) > *this.payload_len { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "tried to write excess bytes", - ))); - } - let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?; - ensure_nonzero_bytes_written(bytes_written)?; - let new_pos = pos + (bytes_written as u64); - if new_pos == *this.payload_len { - *this.state = BytesPacketPosition::Padding(0) - } else { - *this.state = BytesPacketPosition::Payload(new_pos) - } - - return Poll::Ready(Ok(bytes_written)); - } - // If we're already in padding state, there should be no more payload left to write! - BytesPacketPosition::Padding(_pos) => { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "tried to write excess bytes", - ))) - } - } - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let mut this = self.project(); - - loop { - match *this.state { - BytesPacketPosition::Size(LEN_SIZE) => unreachable!(), - BytesPacketPosition::Size(pos) => { - // More bytes to write in the size field - let size_field = &this.payload_len.to_le_bytes()[..]; - let bytes_written = ensure_nonzero_bytes_written(ready!(this - .inner - .as_mut() - .poll_write(cx, &size_field[pos..]))?)?; - let new_pos = pos + bytes_written; - if new_pos == LEN_SIZE { - // Size field written, now ready to receive payload - *this.state = BytesPacketPosition::Payload(0); - } else { - *this.state = BytesPacketPosition::Size(new_pos); - } - } - BytesPacketPosition::Payload(_pos) => { - // If we're at position 0 and want to write 0 bytes of payload - // in total, we can transition to padding. - // Otherwise, break, as we're expecting more payload to - // be written. - if *this.payload_len == 0 { - *this.state = BytesPacketPosition::Padding(0); - } else { - break; - } - } - BytesPacketPosition::Padding(pos) => { - // Write remaining padding, if there is padding to write. - let total_padding_len = super::bytes::padding_len(*this.payload_len) as usize; - - if pos != total_padding_len { - let bytes_written = ensure_nonzero_bytes_written(ready!(this - .inner - .as_mut() - .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len]))?)?; - *this.state = BytesPacketPosition::Padding(pos + bytes_written); - } else { - // everything written, break - break; - } - } - } - } - // Flush the underlying writer. - this.inner.as_mut().poll_flush(cx) - } - - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - // Call flush. - ready!(self.as_mut().poll_flush(cx))?; - - let this = self.project(); - - // After a flush, being inside the padding state, and at the end of the padding - // is the only way to prevent a dirty shutdown. - if let BytesPacketPosition::Padding(pos) = *this.state { - let padding_len = super::bytes::padding_len(*this.payload_len) as usize; - if padding_len == pos { - // Shutdown the underlying writer - return this.inner.poll_shutdown(cx); - } - } - - // Shutdown the underlying writer, bubbling up any errors. - ready!(this.inner.poll_shutdown(cx))?; - - // return an error about unclean shutdown - Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "unclean shutdown", - ))) - } -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use crate::wire::bytes::write_bytes; - use hex_literal::hex; - use lazy_static::lazy_static; - use tokio::io::AsyncWriteExt; - use tokio_test::{assert_err, assert_ok, io::Builder}; - - use super::*; - - lazy_static! { - pub static ref LARGE_PAYLOAD: Vec = (0..255).collect::>().repeat(4 * 1024); - } - - /// Helper function, calling the (simpler) write_bytes with the payload. - /// We use this to create data we want to see on the wire. - async fn produce_exp_bytes(payload: &[u8]) -> Vec { - let mut exp = vec![]; - write_bytes(&mut exp, payload).await.unwrap(); - exp - } - - /// Write an empty bytes packet. - #[tokio::test] - async fn write_empty() { - let payload = &[]; - let mut mock = Builder::new() - .write(&produce_exp_bytes(payload).await) - .build(); - - let mut w = BytesWriter::new(&mut mock, 0); - assert_ok!(w.write_all(&[]).await, "write all data"); - assert_ok!(w.flush().await, "flush"); - } - - /// Write an empty bytes packet, not calling write. - #[tokio::test] - async fn write_empty_only_flush() { - let payload = &[]; - let mut mock = Builder::new() - .write(&produce_exp_bytes(payload).await) - .build(); - - let mut w = BytesWriter::new(&mut mock, 0); - assert_ok!(w.flush().await, "flush"); - } - - /// Write an empty bytes packet, not calling write or flush, only shutdown. - #[tokio::test] - async fn write_empty_only_shutdown() { - let payload = &[]; - let mut mock = Builder::new() - .write(&produce_exp_bytes(payload).await) - .build(); - - let mut w = BytesWriter::new(&mut mock, 0); - assert_ok!(w.shutdown().await, "shutdown"); - } - - /// Write a 1 bytes packet - #[tokio::test] - async fn write_1b() { - let payload = &[0xff]; - - let mut mock = Builder::new() - .write(&produce_exp_bytes(payload).await) - .build(); - - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - assert_ok!(w.write_all(payload).await); - assert_ok!(w.flush().await, "flush"); - } - - /// Write a 8 bytes payload (no padding) - #[tokio::test] - async fn write_8b() { - let payload = &hex!("0001020304050607"); - - let mut mock = Builder::new() - .write(&produce_exp_bytes(payload).await) - .build(); - - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - assert_ok!(w.write_all(payload).await); - assert_ok!(w.flush().await, "flush"); - } - - /// Write a 9 bytes payload (7 bytes padding) - #[tokio::test] - async fn write_9b() { - let payload = &hex!("000102030405060708"); - - let mut mock = Builder::new() - .write(&produce_exp_bytes(payload).await) - .build(); - - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - assert_ok!(w.write_all(payload).await); - assert_ok!(w.flush().await, "flush"); - } - - /// Write a 9 bytes packet very granularly, with a lot of flushing in between, - /// and a shutdown at the end. - #[tokio::test] - async fn write_9b_flush() { - let payload = &hex!("000102030405060708"); - let exp_bytes = produce_exp_bytes(payload).await; - - let mut mock = Builder::new().write(&exp_bytes).build(); - - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - assert_ok!(w.flush().await); - - assert_ok!(w.write_all(&payload[..4]).await); - assert_ok!(w.flush().await); - - // empty write, cause why not - assert_ok!(w.write_all(&[]).await); - assert_ok!(w.flush().await); - - assert_ok!(w.write_all(&payload[4..]).await); - assert_ok!(w.flush().await); - assert_ok!(w.shutdown().await); - } - - /// Write a 9 bytes packet, but cause the sink to only accept half of the - /// padding, ensuring we correctly write (only) the rest of the padding later. - /// We write another 2 bytes of "bait", where a faulty implementation (pre - /// cl/11384) would put too many null bytes. - #[tokio::test] - async fn write_9b_write_padding_2steps() { - let payload = &hex!("000102030405060708"); - let exp_bytes = produce_exp_bytes(payload).await; - - let mut mock = Builder::new() - .write(&exp_bytes[0..8]) // size - .write(&exp_bytes[8..17]) // payload - .write(&exp_bytes[17..19]) // padding (2 of 7 bytes) - // insert a wait to prevent Mock from merging the two writes into one - .wait(Duration::from_nanos(1)) - .write(&hex!("0000000000ffff")) // padding (5 of 7 bytes, plus 2 bytes of "bait") - .build(); - - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - assert_ok!(w.write_all(&payload[..]).await); - assert_ok!(w.flush().await); - // Write bait - assert_ok!(mock.write_all(&hex!("ffff")).await); - } - - /// Write a larger bytes packet - #[tokio::test] - async fn write_1m() { - let payload = LARGE_PAYLOAD.as_slice(); - let exp_bytes = produce_exp_bytes(payload).await; - - let mut mock = Builder::new().write(&exp_bytes).build(); - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - - assert_ok!(w.write_all(payload).await); - assert_ok!(w.flush().await, "flush"); - } - - /// Not calling flush at the end, but shutdown is also ok if we wrote all - /// bytes we promised to write (as shutdown implies flush) - #[tokio::test] - async fn write_shutdown_without_flush_end() { - let payload = &[0xf0, 0xff]; - let exp_bytes = produce_exp_bytes(payload).await; - - let mut mock = Builder::new().write(&exp_bytes).build(); - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - - // call flush to write the size field - assert_ok!(w.flush().await); - - // write payload - assert_ok!(w.write_all(payload).await); - - // call shutdown - assert_ok!(w.shutdown().await); - } - - /// Writing more bytes than previously signalled should fail. - #[tokio::test] - async fn write_more_than_signalled_fail() { - let mut buf = Vec::new(); - let mut w = BytesWriter::new(&mut buf, 2); - - assert_err!(w.write_all(&hex!("000102")).await); - } - /// Writing more bytes than previously signalled, but in two parts - #[tokio::test] - async fn write_more_than_signalled_split_fail() { - let mut buf = Vec::new(); - let mut w = BytesWriter::new(&mut buf, 2); - - // write two bytes - assert_ok!(w.write_all(&hex!("0001")).await); - - // write the excess byte. - assert_err!(w.write_all(&hex!("02")).await); - } - - /// Writing more bytes than previously signalled, but flushing after the - /// signalled amount should fail. - #[tokio::test] - async fn write_more_than_signalled_flush_fail() { - let mut buf = Vec::new(); - let mut w = BytesWriter::new(&mut buf, 2); - - // write two bytes, then flush - assert_ok!(w.write_all(&hex!("0001")).await); - assert_ok!(w.flush().await); - - // write the excess byte. - assert_err!(w.write_all(&hex!("02")).await); - } - - /// Calling shutdown while not having written all bytes that were promised - /// returns an error. - /// Note there's still cases of silent corruption if the user doesn't call - /// shutdown explicitly (only drops). - #[tokio::test] - async fn premature_shutdown() { - let payload = &[0xf0, 0xff]; - let mut buf = Vec::new(); - let mut w = BytesWriter::new(&mut buf, payload.len() as u64); - - // call flush to write the size field - assert_ok!(w.flush().await); - - // write half of the payload (!) - assert_ok!(w.write_all(&payload[0..1]).await); - - // call shutdown, ensure it fails - assert_err!(w.shutdown().await); - } - - /// Write to a Writer that fails to write during the size packet (after 4 bytes). - /// Ensure this error gets propagated on the first call to write. - #[tokio::test] - async fn inner_writer_fail_during_size_firstwrite() { - let payload = &[0xf0]; - - let mut mock = Builder::new() - .write(&1u32.to_le_bytes()) - .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿")) - .build(); - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - - assert_err!(w.write_all(payload).await); - } - - /// Write to a Writer that fails to write during the size packet (after 4 bytes). - /// Ensure this error gets propagated during an initial flush - #[tokio::test] - async fn inner_writer_fail_during_size_initial_flush() { - let payload = &[0xf0]; - - let mut mock = Builder::new() - .write(&1u32.to_le_bytes()) - .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿")) - .build(); - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - - assert_err!(w.flush().await); - } - - /// Write to a writer that fails to write during the payload (after 9 bytes). - /// Ensure this error gets propagated when we're writing this byte. - #[tokio::test] - async fn inner_writer_fail_during_write() { - let payload = &hex!("f0ff"); - - let mut mock = Builder::new() - .write(&2u64.to_le_bytes()) - .write(&hex!("f0")) - .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿")) - .build(); - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - - assert_ok!(w.write(&hex!("f0")).await); - assert_err!(w.write(&hex!("ff")).await); - } - - /// Write to a writer that fails to write during the padding (after 10 bytes). - /// Ensure this error gets propagated during a flush. - #[tokio::test] - async fn inner_writer_fail_during_padding_flush() { - let payload = &hex!("f0"); - - let mut mock = Builder::new() - .write(&1u64.to_le_bytes()) - .write(&hex!("f0")) - .write(&hex!("00")) - .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿")) - .build(); - let mut w = BytesWriter::new(&mut mock, payload.len() as u64); - - assert_ok!(w.write(&hex!("f0")).await); - assert_err!(w.flush().await); - } -} -- cgit 1.4.1