use pin_project_lite::pin_project;
use std::task::{ready, Poll};
use tokio::io::AsyncWrite;

use super::{padding_len, BytesPacketPosition, EMPTY_BYTES, LEN_SIZE};

pin_project! {
    /// Writes a "bytes wire packet" to the underlying writer.
    /// The format is the same as in [crate::wire::bytes::write_bytes],
    /// however this structure provides an [AsyncWrite] interface,
    /// so the entire payload does not have to be held in memory at once.
    ///
    /// It internally takes care of writing (non-payload) framing (size and
    /// padding).
    ///
    /// During construction, the expected payload size needs to be provided.
    ///
    /// After writing the payload to it, the user MUST call flush (or
    /// shutdown), which writes the necessary padding once the payload is
    /// complete.
    ///
    /// Only shutdown verifies that the promised number of payload bytes was
    /// actually written: it returns an error ("unclean shutdown") otherwise.
    /// Flush succeeds while payload is still outstanding, as more payload may
    /// follow.
    ///
    /// In case neither flush nor shutdown is called at the end, invalid data
    /// might be sent silently.
    ///
    /// The underlying writer returning `Ok(0)` is considered an EOF situation,
    /// which is stronger than the "typically means the underlying object is no
    /// longer able to accept bytes" interpretation from the docs. If such a
    /// situation occurs, an error is returned.
    ///
    /// The struct holds three fields: the underlying writer, the (expected)
    /// payload length, and an enum tracking the state (which part of the
    /// packet is written next, and how far into it we are).
    pub struct BytesWriter<W>
    where
        W: AsyncWrite,
    {
        #[pin]
        inner: W,
        payload_len: u64,
        state: BytesPacketPosition,
    }
}

impl<W> BytesWriter<W>
where
    W: AsyncWrite,
{
    /// Constructs a new BytesWriter, using the underlying passed writer.
    ///
    /// `payload_len` is the exact number of payload bytes that will be
    /// written. Nothing is written to `w` here; the writer starts out at the
    /// beginning of the size field.
    pub fn new(w: W, payload_len: u64) -> Self {
        Self {
            inner: w,
            payload_len,
            state: BytesPacketPosition::Size(0),
        }
    }
}

/// Returns an error if the passed usize is 0.
#[inline] fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> { if bytes_written == 0 { Err(std::io::Error::new( std::io::ErrorKind::WriteZero, "underlying writer accepted 0 bytes", )) } else { Ok(bytes_written) } } impl<W> AsyncWrite for BytesWriter<W> where W: AsyncWrite, { fn poll_write( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll<Result<usize, std::io::Error>> { // Use a loop, so we can deal with (multiple) state transitions. let mut this = self.project(); loop { match *this.state { BytesPacketPosition::Size(LEN_SIZE) => unreachable!(), BytesPacketPosition::Size(pos) => { let size_field = &this.payload_len.to_le_bytes(); let bytes_written = ensure_nonzero_bytes_written(ready!(this .inner .as_mut() .poll_write(cx, &size_field[pos..]))?)?; let new_pos = pos + bytes_written; if new_pos == LEN_SIZE { *this.state = BytesPacketPosition::Payload(0); } else { *this.state = BytesPacketPosition::Size(new_pos); } } BytesPacketPosition::Payload(pos) => { // Ensure we still have space for more payload if pos + (buf.len() as u64) > *this.payload_len { return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "tried to write excess bytes", ))); } let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?; ensure_nonzero_bytes_written(bytes_written)?; let new_pos = pos + (bytes_written as u64); if new_pos == *this.payload_len { *this.state = BytesPacketPosition::Padding(0) } else { *this.state = BytesPacketPosition::Payload(new_pos) } return Poll::Ready(Ok(bytes_written)); } // If we're already in padding state, there should be no more payload left to write! 
BytesPacketPosition::Padding(_pos) => { return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "tried to write excess bytes", ))) } } } } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll<Result<(), std::io::Error>> { let mut this = self.project(); loop { match *this.state { BytesPacketPosition::Size(LEN_SIZE) => unreachable!(), BytesPacketPosition::Size(pos) => { // More bytes to write in the size field let size_field = &this.payload_len.to_le_bytes()[..]; let bytes_written = ensure_nonzero_bytes_written(ready!(this .inner .as_mut() .poll_write(cx, &size_field[pos..]))?)?; let new_pos = pos + bytes_written; if new_pos == LEN_SIZE { // Size field written, now ready to receive payload *this.state = BytesPacketPosition::Payload(0); } else { *this.state = BytesPacketPosition::Size(new_pos); } } BytesPacketPosition::Payload(_pos) => { // If we're at position 0 and want to write 0 bytes of payload // in total, we can transition to padding. // Otherwise, break, as we're expecting more payload to // be written. if *this.payload_len == 0 { *this.state = BytesPacketPosition::Padding(0); } else { break; } } BytesPacketPosition::Padding(pos) => { // Write remaining padding, if there is padding to write. let total_padding_len = padding_len(*this.payload_len) as usize; if pos != total_padding_len { let bytes_written = ensure_nonzero_bytes_written(ready!(this .inner .as_mut() .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len]))?)?; *this.state = BytesPacketPosition::Padding(pos + bytes_written); } else { // everything written, break break; } } } } // Flush the underlying writer. this.inner.as_mut().poll_flush(cx) } fn poll_shutdown( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll<Result<(), std::io::Error>> { // Call flush. 
ready!(self.as_mut().poll_flush(cx))?; let this = self.project(); // After a flush, being inside the padding state, and at the end of the padding // is the only way to prevent a dirty shutdown. if let BytesPacketPosition::Padding(pos) = *this.state { let padding_len = padding_len(*this.payload_len) as usize; if padding_len == pos { // Shutdown the underlying writer return this.inner.poll_shutdown(cx); } } // Shutdown the underlying writer, bubbling up any errors. ready!(this.inner.poll_shutdown(cx))?; // return an error about unclean shutdown Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::BrokenPipe, "unclean shutdown", ))) } } #[cfg(test)] mod tests { use std::time::Duration; use crate::wire::bytes::write_bytes; use hex_literal::hex; use lazy_static::lazy_static; use tokio::io::AsyncWriteExt; use tokio_test::{assert_err, assert_ok, io::Builder}; use super::*; lazy_static! { pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024); } /// Helper function, calling the (simpler) write_bytes with the payload. /// We use this to create data we want to see on the wire. async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> { let mut exp = vec![]; write_bytes(&mut exp, payload).await.unwrap(); exp } /// Write an empty bytes packet. #[tokio::test] async fn write_empty() { let payload = &[]; let mut mock = Builder::new() .write(&produce_exp_bytes(payload).await) .build(); let mut w = BytesWriter::new(&mut mock, 0); assert_ok!(w.write_all(&[]).await, "write all data"); assert_ok!(w.flush().await, "flush"); } /// Write an empty bytes packet, not calling write. #[tokio::test] async fn write_empty_only_flush() { let payload = &[]; let mut mock = Builder::new() .write(&produce_exp_bytes(payload).await) .build(); let mut w = BytesWriter::new(&mut mock, 0); assert_ok!(w.flush().await, "flush"); } /// Write an empty bytes packet, not calling write or flush, only shutdown. 
#[tokio::test] async fn write_empty_only_shutdown() { let payload = &[]; let mut mock = Builder::new() .write(&produce_exp_bytes(payload).await) .build(); let mut w = BytesWriter::new(&mut mock, 0); assert_ok!(w.shutdown().await, "shutdown"); } /// Write a 1 bytes packet #[tokio::test] async fn write_1b() { let payload = &[0xff]; let mut mock = Builder::new() .write(&produce_exp_bytes(payload).await) .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.write_all(payload).await); assert_ok!(w.flush().await, "flush"); } /// Write a 8 bytes payload (no padding) #[tokio::test] async fn write_8b() { let payload = &hex!("0001020304050607"); let mut mock = Builder::new() .write(&produce_exp_bytes(payload).await) .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.write_all(payload).await); assert_ok!(w.flush().await, "flush"); } /// Write a 9 bytes payload (7 bytes padding) #[tokio::test] async fn write_9b() { let payload = &hex!("000102030405060708"); let mut mock = Builder::new() .write(&produce_exp_bytes(payload).await) .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.write_all(payload).await); assert_ok!(w.flush().await, "flush"); } /// Write a 9 bytes packet very granularly, with a lot of flushing in between, /// and a shutdown at the end. 
#[tokio::test] async fn write_9b_flush() { let payload = &hex!("000102030405060708"); let exp_bytes = produce_exp_bytes(payload).await; let mut mock = Builder::new().write(&exp_bytes).build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.flush().await); assert_ok!(w.write_all(&payload[..4]).await); assert_ok!(w.flush().await); // empty write, cause why not assert_ok!(w.write_all(&[]).await); assert_ok!(w.flush().await); assert_ok!(w.write_all(&payload[4..]).await); assert_ok!(w.flush().await); assert_ok!(w.shutdown().await); } /// Write a 9 bytes packet, but cause the sink to only accept half of the /// padding, ensuring we correctly write (only) the rest of the padding later. /// We write another 2 bytes of "bait", where a faulty implementation (pre /// cl/11384) would put too many null bytes. #[tokio::test] async fn write_9b_write_padding_2steps() { let payload = &hex!("000102030405060708"); let exp_bytes = produce_exp_bytes(payload).await; let mut mock = Builder::new() .write(&exp_bytes[0..8]) // size .write(&exp_bytes[8..17]) // payload .write(&exp_bytes[17..19]) // padding (2 of 7 bytes) // insert a wait to prevent Mock from merging the two writes into one .wait(Duration::from_nanos(1)) .write(&hex!("0000000000ffff")) // padding (5 of 7 bytes, plus 2 bytes of "bait") .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.write_all(&payload[..]).await); assert_ok!(w.flush().await); // Write bait assert_ok!(mock.write_all(&hex!("ffff")).await); } /// Write a larger bytes packet #[tokio::test] async fn write_1m() { let payload = LARGE_PAYLOAD.as_slice(); let exp_bytes = produce_exp_bytes(payload).await; let mut mock = Builder::new().write(&exp_bytes).build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.write_all(payload).await); assert_ok!(w.flush().await, "flush"); } /// Not calling flush at the end, but shutdown is also ok if we wrote all /// bytes we promised to write 
(as shutdown implies flush) #[tokio::test] async fn write_shutdown_without_flush_end() { let payload = &[0xf0, 0xff]; let exp_bytes = produce_exp_bytes(payload).await; let mut mock = Builder::new().write(&exp_bytes).build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); // call flush to write the size field assert_ok!(w.flush().await); // write payload assert_ok!(w.write_all(payload).await); // call shutdown assert_ok!(w.shutdown().await); } /// Writing more bytes than previously signalled should fail. #[tokio::test] async fn write_more_than_signalled_fail() { let mut buf = Vec::new(); let mut w = BytesWriter::new(&mut buf, 2); assert_err!(w.write_all(&hex!("000102")).await); } /// Writing more bytes than previously signalled, but in two parts #[tokio::test] async fn write_more_than_signalled_split_fail() { let mut buf = Vec::new(); let mut w = BytesWriter::new(&mut buf, 2); // write two bytes assert_ok!(w.write_all(&hex!("0001")).await); // write the excess byte. assert_err!(w.write_all(&hex!("02")).await); } /// Writing more bytes than previously signalled, but flushing after the /// signalled amount should fail. #[tokio::test] async fn write_more_than_signalled_flush_fail() { let mut buf = Vec::new(); let mut w = BytesWriter::new(&mut buf, 2); // write two bytes, then flush assert_ok!(w.write_all(&hex!("0001")).await); assert_ok!(w.flush().await); // write the excess byte. assert_err!(w.write_all(&hex!("02")).await); } /// Calling shutdown while not having written all bytes that were promised /// returns an error. /// Note there's still cases of silent corruption if the user doesn't call /// shutdown explicitly (only drops). #[tokio::test] async fn premature_shutdown() { let payload = &[0xf0, 0xff]; let mut buf = Vec::new(); let mut w = BytesWriter::new(&mut buf, payload.len() as u64); // call flush to write the size field assert_ok!(w.flush().await); // write half of the payload (!) 
assert_ok!(w.write_all(&payload[0..1]).await); // call shutdown, ensure it fails assert_err!(w.shutdown().await); } /// Write to a Writer that fails to write during the size packet (after 4 bytes). /// Ensure this error gets propagated on the first call to write. #[tokio::test] async fn inner_writer_fail_during_size_firstwrite() { let payload = &[0xf0]; let mut mock = Builder::new() .write(&1u32.to_le_bytes()) .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐ฟ")) .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_err!(w.write_all(payload).await); } /// Write to a Writer that fails to write during the size packet (after 4 bytes). /// Ensure this error gets propagated during an initial flush #[tokio::test] async fn inner_writer_fail_during_size_initial_flush() { let payload = &[0xf0]; let mut mock = Builder::new() .write(&1u32.to_le_bytes()) .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐ฟ")) .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_err!(w.flush().await); } /// Write to a writer that fails to write during the payload (after 9 bytes). /// Ensure this error gets propagated when we're writing this byte. #[tokio::test] async fn inner_writer_fail_during_write() { let payload = &hex!("f0ff"); let mut mock = Builder::new() .write(&2u64.to_le_bytes()) .write(&hex!("f0")) .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐ฟ")) .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.write(&hex!("f0")).await); assert_err!(w.write(&hex!("ff")).await); } /// Write to a writer that fails to write during the padding (after 10 bytes). /// Ensure this error gets propagated during a flush. 
#[tokio::test] async fn inner_writer_fail_during_padding_flush() { let payload = &hex!("f0"); let mut mock = Builder::new() .write(&1u64.to_le_bytes()) .write(&hex!("f0")) .write(&hex!("00")) .write_error(std::io::Error::new(std::io::ErrorKind::Other, "๐ฟ")) .build(); let mut w = BytesWriter::new(&mut mock, payload.len() as u64); assert_ok!(w.write(&hex!("f0")).await); assert_err!(w.flush().await); } }