diff options
Diffstat (limited to 'tvix/nix-compat/src/wire')
-rw-r--r-- | tvix/nix-compat/src/wire/de/bytes.rs | 70 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/de/collections.rs | 105 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/de/int.rs | 100 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/de/mock.rs | 261 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/de/mod.rs | 225 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/de/reader.rs | 526 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/mod.rs | 6 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/protocol_version.rs | 139 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/ser/bytes.rs | 89 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/ser/collections.rs | 94 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/ser/display.rs | 8 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/ser/int.rs | 108 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/ser/mock.rs | 672 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/ser/mod.rs | 124 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/ser/writer.rs | 306 |
15 files changed, 2833 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/wire/de/bytes.rs b/tvix/nix-compat/src/wire/de/bytes.rs new file mode 100644 index 000000000000..4c64247f7051 --- /dev/null +++ b/tvix/nix-compat/src/wire/de/bytes.rs @@ -0,0 +1,70 @@ +use bytes::Bytes; + +use super::{Error, NixDeserialize, NixRead}; + +impl NixDeserialize for Bytes { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + reader.try_read_bytes().await + } +} + +impl NixDeserialize for String { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + if let Some(buf) = reader.try_read_bytes().await? { + String::from_utf8(buf.to_vec()) + .map_err(R::Error::invalid_data) + .map(Some) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod test { + use std::io; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::wire::de::{NixRead, NixReader}; + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn 
test_read_string(#[case] expected: &str, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: String = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_read_string_invalid() { + let mock = Builder::new() + .read(&hex!("0300 0000 0000 0000 EDA0 8000 0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + assert_eq!( + io::ErrorKind::InvalidData, + reader.read_value::<String>().await.unwrap_err().kind() + ); + } +} diff --git a/tvix/nix-compat/src/wire/de/collections.rs b/tvix/nix-compat/src/wire/de/collections.rs new file mode 100644 index 000000000000..e1271635e4e6 --- /dev/null +++ b/tvix/nix-compat/src/wire/de/collections.rs @@ -0,0 +1,105 @@ +use std::{collections::BTreeMap, future::Future}; + +use super::{NixDeserialize, NixRead}; + +#[allow(clippy::manual_async_fn)] +impl<T> NixDeserialize for Vec<T> +where + T: NixDeserialize + Send, +{ + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + if let Some(len) = reader.try_read_value::<usize>().await? { + let mut ret = Vec::with_capacity(len); + for _ in 0..len { + ret.push(reader.read_value().await?); + } + Ok(Some(ret)) + } else { + Ok(None) + } + } + } +} + +#[allow(clippy::manual_async_fn)] +impl<K, V> NixDeserialize for BTreeMap<K, V> +where + K: NixDeserialize + Ord + Send, + V: NixDeserialize + Send, +{ + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + if let Some(len) = reader.try_read_value::<usize>().await? 
{ + let mut ret = BTreeMap::new(); + for _ in 0..len { + let key = reader.read_value().await?; + let value = reader.read_value().await?; + ret.insert(key, value); + } + Ok(Some(ret)) + } else { + Ok(None) + } + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::wire::de::{NixDeserialize, NixRead, NixReader}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_read_small_vec(#[case] expected: Vec<usize>, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: Vec<usize> = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + fn empty_map() -> BTreeMap<usize, u64> { + BTreeMap::new() + } + macro_rules! 
map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_read_small_btree_map<E>(#[case] expected: E, #[case] data: &[u8]) + where + E: NixDeserialize + PartialEq + fmt::Debug, + { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: E = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } +} diff --git a/tvix/nix-compat/src/wire/de/int.rs b/tvix/nix-compat/src/wire/de/int.rs new file mode 100644 index 000000000000..d505de9b1b24 --- /dev/null +++ b/tvix/nix-compat/src/wire/de/int.rs @@ -0,0 +1,100 @@ +use super::{Error, NixDeserialize, NixRead}; + +impl NixDeserialize for u64 { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + reader.try_read_number().await + } +} + +impl NixDeserialize for usize { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + if let Some(value) = reader.try_read_number().await? 
{ + value.try_into().map_err(R::Error::invalid_data).map(Some) + } else { + Ok(None) + } + } +} + +impl NixDeserialize for bool { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + Ok(reader.try_read_number().await?.map(|v| v != 0)) + } +} +impl NixDeserialize for i64 { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + Ok(reader.try_read_number().await?.map(|v| v as i64)) + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::wire::de::{NixRead, NixReader}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[case::other_true(true, &hex!("1234 5600 0000 0000"))] + #[case::max_true(true, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_read_bool(#[case] expected: bool, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: bool = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_read_u64(#[case] expected: u64, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: u64 = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] + #[tokio::test] + async fn test_read_usize(#[case] expected: usize, 
#[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: usize = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + // FUTUREWORK: Test this on supported hardware + #[tokio::test] + #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] + async fn test_read_usize_overflow() { + let mock = Builder::new().read(&u64::MAX.to_le_bytes()).build(); + let mut reader = NixReader::new(mock); + assert_eq!( + std::io::ErrorKind::InvalidData, + reader.read_value::<usize>().await.unwrap_err().kind() + ); + } +} diff --git a/tvix/nix-compat/src/wire/de/mock.rs b/tvix/nix-compat/src/wire/de/mock.rs new file mode 100644 index 000000000000..8a1fb817743c --- /dev/null +++ b/tvix/nix-compat/src/wire/de/mock.rs @@ -0,0 +1,261 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +use bytes::Bytes; +use thiserror::Error; + +use crate::wire::ProtocolVersion; + +use super::NixRead; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("invalid data '{0}'")] + InvalidData(String), + #[error("missing data '{0}'")] + MissingData(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong read: expected {0} got {1}")] + WrongRead(OperationType, OperationType), +} + +impl Error { + pub fn expected_read_number() -> Error { + Error::WrongRead(OperationType::ReadNumber, OperationType::ReadBytes) + } + + pub fn expected_read_bytes() -> Error { + Error::WrongRead(OperationType::ReadBytes, OperationType::ReadNumber) + } +} + +impl super::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Self::Custom(msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + Self::IO(err.kind(), err.to_string()) + } + + fn invalid_data<T: fmt::Display>(msg: T) -> Self { + Self::InvalidData(msg.to_string()) + } + + fn missing_data<T: fmt::Display>(msg: 
T) -> Self { + Self::MissingData(msg.to_string()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OperationType { + ReadNumber, + ReadBytes, +} + +impl fmt::Display for OperationType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ReadNumber => write!(f, "read_number"), + Self::ReadBytes => write!(f, "read_bytess"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum Operation { + ReadNumber(Result<u64, Error>), + ReadBytes(Result<Bytes, Error>), +} + +impl From<Operation> for OperationType { + fn from(value: Operation) -> Self { + match value { + Operation::ReadNumber(_) => OperationType::ReadNumber, + Operation::ReadBytes(_) => OperationType::ReadBytes, + } + } +} + +pub struct Builder { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + version: Default::default(), + ops: VecDeque::new(), + } + } + + pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { + self.version = version.into(); + self + } + + pub fn read_number(&mut self, value: u64) -> &mut Self { + self.ops.push_back(Operation::ReadNumber(Ok(value))); + self + } + + pub fn read_number_error(&mut self, err: Error) -> &mut Self { + self.ops.push_back(Operation::ReadNumber(Err(err))); + self + } + + pub fn read_bytes(&mut self, value: Bytes) -> &mut Self { + self.ops.push_back(Operation::ReadBytes(Ok(value))); + self + } + + pub fn read_slice(&mut self, data: &[u8]) -> &mut Self { + let value = Bytes::copy_from_slice(data); + self.ops.push_back(Operation::ReadBytes(Ok(value))); + self + } + + pub fn read_bytes_error(&mut self, err: Error) -> &mut Self { + self.ops.push_back(Operation::ReadBytes(Err(err))); + self + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct Mock { + version: ProtocolVersion, + 
ops: VecDeque<Operation>, +} + +impl NixRead for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { + match self.ops.pop_front() { + Some(Operation::ReadNumber(ret)) => ret.map(Some), + Some(Operation::ReadBytes(_)) => Err(Error::expected_read_bytes()), + None => Ok(None), + } + } + + async fn try_read_bytes_limited( + &mut self, + _limit: std::ops::RangeInclusive<usize>, + ) -> Result<Option<Bytes>, Self::Error> { + match self.ops.pop_front() { + Some(Operation::ReadBytes(ret)) => ret.map(Some), + Some(Operation::ReadNumber(_)) => Err(Error::expected_read_number()), + None => Ok(None), + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + // No need to panic again + if thread::panicking() { + return; + } + if let Some(op) = self.ops.front() { + panic!("reader dropped with {op:?} operation still unread") + } + } +} + +#[cfg(test)] +mod test { + use bytes::Bytes; + use hex_literal::hex; + + use crate::wire::de::NixRead; + + use super::{Builder, Error}; + + #[tokio::test] + async fn read_slice() { + let mut mock = Builder::new() + .read_number(10) + .read_slice(&[]) + .read_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); + assert_eq!( + &hex!("0000 1234 5678 9ABC DEFF"), + &mock.read_bytes().await.unwrap()[..] 
+ ); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn read_bytes() { + let mut mock = Builder::new() + .read_number(10) + .read_bytes(Bytes::from_static(&[])) + .read_bytes(Bytes::from_static(&hex!("0000 1234 5678 9ABC DEFF"))) + .build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); + assert_eq!( + &hex!("0000 1234 5678 9ABC DEFF"), + &mock.read_bytes().await.unwrap()[..] + ); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn read_number() { + let mut mock = Builder::new().read_number(10).build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn expect_number() { + let mut mock = Builder::new().read_number(10).build(); + assert_eq!( + Error::expected_read_number(), + mock.read_bytes().await.unwrap_err() + ); + } + + #[tokio::test] + async fn expect_bytes() { + let mut mock = Builder::new().read_slice(&[]).build(); + assert_eq!( + Error::expected_read_bytes(), + mock.read_number().await.unwrap_err() + ); + } + + #[test] + #[should_panic] + fn operations_left() { + let _ = Builder::new().read_number(10).build(); + } +} diff --git a/tvix/nix-compat/src/wire/de/mod.rs b/tvix/nix-compat/src/wire/de/mod.rs new file mode 100644 index 000000000000..f85ccd8fea0e --- /dev/null +++ b/tvix/nix-compat/src/wire/de/mod.rs @@ -0,0 +1,225 @@ +use std::error::Error as StdError; +use std::future::Future; +use std::ops::RangeInclusive; +use std::{fmt, io}; + +use ::bytes::Bytes; + +use super::ProtocolVersion; + +mod bytes; +mod collections; +mod int; +#[cfg(any(test, feature = "test"))] +pub mod mock; +mod reader; + +pub use reader::{NixReader, 
NixReaderBuilder}; + +/// Like serde the `Error` trait allows `NixRead` implementations to add +/// custom error handling for `NixDeserialize`. +pub trait Error: Sized + StdError { + /// A totally custom non-specific error. + fn custom<T: fmt::Display>(msg: T) -> Self; + + /// Some kind of std::io::Error occurred. + fn io_error(err: std::io::Error) -> Self { + Self::custom(format_args!("There was an I/O error {}", err)) + } + + /// The data read from `NixRead` is invalid. + /// This could be that some bytes were supposed to be valid UTF-8 but weren't. + fn invalid_data<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } + + /// Required data is missing. This is mostly like an EOF + fn missing_data<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } +} + +impl Error for io::Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::Other, msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + err + } + + fn invalid_data<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::InvalidData, msg.to_string()) + } + + fn missing_data<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::UnexpectedEof, msg.to_string()) + } +} + +/// A reader of data from the Nix daemon protocol. +/// Basically there are two basic types in the Nix daemon protocol +/// u64 and a bytes buffer. Everything else is more or less built on +/// top of these two types. +pub trait NixRead: Send { + type Error: Error + Send; + + /// Some types are serialized differently depending on the version + /// of the protocol and so this can be used for implementing that. + fn version(&self) -> ProtocolVersion; + + /// Read a single u64 from the protocol. + /// This returns an Option to support graceful shutdown. + fn try_read_number( + &mut self, + ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_; + + /// Read bytes from the protocol. + /// A size limit on the returned bytes has to be specified. 
+ /// This returns an Option to support graceful shutdown. + fn try_read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_; + + /// Read bytes from the protocol without a limit. + /// The default implementation just calls `try_read_bytes_limited` with a + /// limit of `0..=usize::MAX` but other implementations are free to have a + /// reader wide limit. + /// This returns an Option to support graceful shutdown. + fn try_read_bytes( + &mut self, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + self.try_read_bytes_limited(0..=usize::MAX) + } + + /// Read a single u64 from the protocol. + /// This will return an error if the number could not be read. + fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ { + async move { + match self.try_read_number().await? { + Some(v) => Ok(v), + None => Err(Self::Error::missing_data("unexpected end-of-file")), + } + } + } + + /// Read bytes from the protocol. + /// A size limit on the returned bytes has to be specified. + /// This will return an error if the bytes could not be read. + fn read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + async move { + match self.try_read_bytes_limited(limit).await? { + Some(v) => Ok(v), + None => Err(Self::Error::missing_data("unexpected end-of-file")), + } + } + } + + /// Read bytes from the protocol. + /// The default implementation just calls `read_bytes_limited` with a + /// limit of `0..=usize::MAX` but other implementations are free to have a + /// reader wide limit. + /// This will return an error if the bytes could not be read. + fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + self.read_bytes_limited(0..=usize::MAX) + } + + /// Read a value from the protocol. 
+ /// Uses `NixDeserialize::deserialize` to read a value. + fn read_value<V: NixDeserialize>( + &mut self, + ) -> impl Future<Output = Result<V, Self::Error>> + Send + '_ { + V::deserialize(self) + } + + /// Read a value from the protocol. + /// Uses `NixDeserialize::try_deserialize` to read a value. + /// This returns an Option to support graceful shutdown. + fn try_read_value<V: NixDeserialize>( + &mut self, + ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ { + V::try_deserialize(self) + } +} + +impl<T: ?Sized + NixRead> NixRead for &mut T { + type Error = T::Error; + + fn version(&self) -> ProtocolVersion { + (**self).version() + } + + fn try_read_number( + &mut self, + ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_ { + (**self).try_read_number() + } + + fn try_read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + (**self).try_read_bytes_limited(limit) + } + + fn try_read_bytes( + &mut self, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + (**self).try_read_bytes() + } + + fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ { + (**self).read_number() + } + + fn read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + (**self).read_bytes_limited(limit) + } + + fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + (**self).read_bytes() + } + + fn try_read_value<V: NixDeserialize>( + &mut self, + ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ { + (**self).try_read_value() + } + + fn read_value<V: NixDeserialize>( + &mut self, + ) -> impl Future<Output = Result<V, Self::Error>> + Send + '_ { + (**self).read_value() + } +} + +/// A data structure that can be deserialized from the Nix daemon +/// worker protocol. 
+pub trait NixDeserialize: Sized { + /// Read a value from the reader. + /// This returns an Option to support graceful shutdown. + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send; + + fn deserialize<R>(reader: &mut R) -> impl Future<Output = Result<Self, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + match Self::try_deserialize(reader).await? { + Some(v) => Ok(v), + None => Err(R::Error::missing_data("unexpected end-of-file")), + } + } + } +} diff --git a/tvix/nix-compat/src/wire/de/reader.rs b/tvix/nix-compat/src/wire/de/reader.rs new file mode 100644 index 000000000000..b7825f393c4e --- /dev/null +++ b/tvix/nix-compat/src/wire/de/reader.rs @@ -0,0 +1,526 @@ +use std::future::poll_fn; +use std::io::{self, Cursor}; +use std::ops::RangeInclusive; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf}; + +use crate::wire::{ProtocolVersion, EMPTY_BYTES}; + +use super::{Error, NixRead}; + +pub struct NixReaderBuilder { + buf: Option<BytesMut>, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, +} + +impl Default for NixReaderBuilder { + fn default() -> Self { + Self { + buf: Default::default(), + reserved_buf_size: 8192, + max_buf_size: 8192, + version: Default::default(), + } + } +} + +impl NixReaderBuilder { + pub fn set_buffer(mut self, buf: BytesMut) -> Self { + self.buf = Some(buf); + self + } + + pub fn set_reserved_buf_size(mut self, size: usize) -> Self { + self.reserved_buf_size = size; + self + } + + pub fn set_max_buf_size(mut self, size: usize) -> Self { + self.max_buf_size = size; + self + } + + pub fn set_version(mut self, version: ProtocolVersion) -> Self { + self.version = version; + self + } + + pub fn build<R>(self, reader: R) -> 
NixReader<R> { + let buf = self.buf.unwrap_or_else(|| BytesMut::with_capacity(0)); + NixReader { + buf, + inner: reader, + reserved_buf_size: self.reserved_buf_size, + max_buf_size: self.max_buf_size, + version: self.version, + } + } +} + +pin_project! { + pub struct NixReader<R> { + #[pin] + inner: R, + buf: BytesMut, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, + } +} + +impl NixReader<Cursor<Vec<u8>>> { + pub fn builder() -> NixReaderBuilder { + NixReaderBuilder::default() + } +} + +impl<R> NixReader<R> +where + R: AsyncReadExt, +{ + pub fn new(reader: R) -> NixReader<R> { + NixReader::builder().build(reader) + } + + pub fn buffer(&self) -> &[u8] { + &self.buf[..] + } + + #[cfg(test)] + pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut { + &mut self.buf + } + + /// Remaining capacity in internal buffer + pub fn remaining_mut(&self) -> usize { + self.buf.capacity() - self.buf.len() + } + + fn poll_force_fill_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<usize>> { + // Ensure that buffer has space for at least reserved_buf_size bytes + if self.remaining_mut() < self.reserved_buf_size { + let me = self.as_mut().project(); + me.buf.reserve(*me.reserved_buf_size); + } + let me = self.project(); + let n = { + let dst = me.buf.spare_capacity_mut(); + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(me.inner.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // SAFETY: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. 
+ unsafe { + me.buf.advance_mut(n); + } + Poll::Ready(Ok(n)) + } +} + +impl<R> NixReader<R> +where + R: AsyncReadExt + Unpin, +{ + async fn force_fill(&mut self) -> io::Result<usize> { + let mut p = Pin::new(self); + let read = poll_fn(|cx| p.as_mut().poll_force_fill_buf(cx)).await?; + Ok(read) + } +} + +impl<R> NixRead for NixReader<R> +where + R: AsyncReadExt + Send + Unpin, +{ + type Error = io::Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { + let mut buf = [0u8; 8]; + let read = self.read_buf(&mut &mut buf[..]).await?; + if read == 0 { + return Ok(None); + } + if read < 8 { + self.read_exact(&mut buf[read..]).await?; + } + let num = Buf::get_u64_le(&mut &buf[..]); + Ok(Some(num)) + } + + async fn try_read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> Result<Option<Bytes>, Self::Error> { + assert!( + *limit.end() <= self.max_buf_size, + "The limit must be smaller than {}", + self.max_buf_size + ); + match self.try_read_number().await? { + Some(raw_len) => { + // Check that length is in range and convert to usize + let len = raw_len + .try_into() + .ok() + .filter(|v| limit.contains(v)) + .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?; + + // Calculate 64bit aligned length and convert to usize + let aligned: usize = raw_len + .checked_add(7) + .map(|v| v & !7) + .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))? + .try_into() + .map_err(Self::Error::invalid_data)?; + + // Ensure that there is enough space in buffer for contents + if self.buf.len() + self.remaining_mut() < aligned { + self.buf.reserve(aligned - self.buf.len()); + } + while self.buf.len() < aligned { + if self.force_fill().await? 
== 0 { + return Err(Self::Error::missing_data( + "unexpected end-of-file reading bytes", + )); + } + } + let mut contents = self.buf.split_to(aligned); + + let padding = aligned - len; + // Ensure padding is all zeros + if contents[len..] != EMPTY_BYTES[..padding] { + return Err(Self::Error::invalid_data("non-zero padding")); + } + + contents.truncate(len); + Ok(Some(contents.freeze())) + } + None => Ok(None), + } + } + + fn try_read_bytes( + &mut self, + ) -> impl std::future::Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + self.try_read_bytes_limited(0..=self.max_buf_size) + } + + fn read_bytes( + &mut self, + ) -> impl std::future::Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + self.read_bytes_limited(0..=self.max_buf_size) + } +} + +impl<R: AsyncRead> AsyncRead for NixReader<R> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[0..amt]); + self.consume(amt); + Poll::Ready(Ok(())) + } +} + +impl<R: AsyncRead> AsyncBufRead for NixReader<R> { + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + if self.as_ref().project_ref().buf.is_empty() { + ready!(self.as_mut().poll_force_fill_buf(cx))?; + } + let me = self.project(); + Poll::Ready(Ok(&me.buf[..])) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let me = self.project(); + me.buf.advance(amt) + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use super::*; + use crate::wire::de::NixRead; + + #[tokio::test] + async fn test_read_u64() { + let mock = Builder::new().read(&hex!("0100 0000 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!(""), 
reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!(""), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_rest() { + let mock = Builder::new() + .read(&hex!("0100 0000 0000 0000 0123 4567 89AB CDEF")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!("0123 4567 89AB CDEF"), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_partial() { + let mock = Builder::new() + .read(&hex!("0100 0000")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000 0123 4567 89AB CDEF")) + .wait(Duration::ZERO) + .read(&hex!("0100 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!("0123 4567 89AB CDEF 0100 0000"), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_eof() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.read_number().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_u64_none() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!(None, reader.try_read_number().await.unwrap()); + } + + #[tokio::test] + async fn test_try_read_u64_eof() { + let mock = Builder::new().read(&hex!("0100 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_number().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_u64_eof2() { + let mock = Builder::new() + .read(&hex!("0100")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000")) + 
/// Decodes byte strings of every length from 0 to 9.
///
/// Each case pairs the expected payload with its wire encoding: an 8-byte
/// little-endian length followed by the payload, zero-padded up to a
/// multiple of 8 bytes.
#[rstest]
#[case::empty(b"", &hex!("0000 0000 0000 0000"))]
#[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
#[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
#[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
#[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
#[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
#[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
#[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
#[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
#[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
#[tokio::test]
async fn test_read_bytes(#[case] expected: &[u8], #[case] data: &[u8]) {
    // The mock hands the whole encoded message over in a single read.
    let mock = Builder::new().read(data).build();
    let mut reader = NixReader::new(mock);
    let actual = reader.read_bytes().await.unwrap();
    // Only the payload is returned; length prefix and padding are stripped.
    assert_eq!(&actual[..], expected);
}
// FUTUREWORK: Test this on supported hardware
/// A byte-string length of `usize::MAX - 6` fits in `usize`, but rounding it
/// up to the next multiple of 8 for the padding does not; reading such a byte
/// string must be rejected as invalid data.
///
/// This is the 16/32-bit counterpart of `test_read_bytes_length_overflow`.
/// The length has to be read through `read_bytes_limited` (with the size
/// limits lifted) to reach the alignment step: the value itself fits in
/// `usize`, so a plain `read_value::<usize>()` is not expected to fail on it.
#[tokio::test]
#[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
async fn test_bytes_aligned_length_conversion_overflow() {
    let len = (usize::MAX - 6) as u64;
    let mock = Builder::new().read(&len.to_le_bytes()).build();
    // Lift the buffer limit so the length passes the range check and the
    // failure comes from the aligned-length computation.
    let mut reader = NixReader::builder()
        .set_max_buf_size(usize::MAX)
        .build(mock);
    assert_eq!(
        std::io::ErrorKind::InvalidData,
        reader
            .read_bytes_limited(0..=usize::MAX)
            .await
            .unwrap_err()
            .kind()
    );
}
+/// This is not aware of any endianness, use [crate::wire::read_u64] to get an +/// u64 first, and the try_from() impl from here if you're receiving over the +/// Nix Worker protocol. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct ProtocolVersion(u16); + +impl ProtocolVersion { + pub const fn from_parts(major: u8, minor: u8) -> Self { + Self(((major as u16) << 8) | minor as u16) + } + + pub fn major(&self) -> u8 { + ((self.0 & 0xff00) >> 8) as u8 + } + + pub fn minor(&self) -> u8 { + (self.0 & 0x00ff) as u8 + } +} + +impl Default for ProtocolVersion { + fn default() -> Self { + DEFAULT_PROTOCOL_VERSION + } +} + +impl PartialOrd for ProtocolVersion { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for ProtocolVersion { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match self.major().cmp(&other.major()) { + std::cmp::Ordering::Less => std::cmp::Ordering::Less, + std::cmp::Ordering::Greater => std::cmp::Ordering::Greater, + std::cmp::Ordering::Equal => { + // same major, compare minor + self.minor().cmp(&other.minor()) + } + } + } +} + +impl From<u16> for ProtocolVersion { + fn from(value: u16) -> Self { + Self::from_parts(((value & 0xff00) >> 8) as u8, (value & 0x00ff) as u8) + } +} + +#[cfg(any(test, feature = "test"))] +impl From<(u8, u8)> for ProtocolVersion { + fn from((major, minor): (u8, u8)) -> Self { + Self::from_parts(major, minor) + } +} + +impl TryFrom<u64> for ProtocolVersion { + type Error = &'static str; + + fn try_from(value: u64) -> Result<Self, Self::Error> { + if value & !0xffff != 0 { + return Err("only two least significant bits might be populated"); + } + + Ok((value as u16).into()) + } +} + +impl From<ProtocolVersion> for u16 { + fn from(value: ProtocolVersion) -> Self { + value.0 + } +} + +impl From<ProtocolVersion> for u64 { + fn from(value: ProtocolVersion) -> Self { + value.0 as u64 + } +} + +impl std::fmt::Display for ProtocolVersion { + fn fmt(&self, f: 
&mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}", self.major(), self.minor()) + } +} + +#[cfg(test)] +mod tests { + use super::ProtocolVersion; + + #[test] + fn from_parts() { + let version = ProtocolVersion::from_parts(1, 37); + assert_eq!(version.major(), 1, "correct major"); + assert_eq!(version.minor(), 37, "correct minor"); + assert_eq!("1.37", &version.to_string(), "to_string"); + + assert_eq!(0x0125, Into::<u16>::into(version)); + assert_eq!(0x0125, Into::<u64>::into(version)); + } + + #[test] + fn from_u16() { + let version = ProtocolVersion::from(0x0125_u16); + assert_eq!("1.37", &version.to_string()); + } + + #[test] + fn from_u64() { + let version = ProtocolVersion::try_from(0x0125_u64).expect("must succeed"); + assert_eq!("1.37", &version.to_string()); + } + + /// This contains data in higher bits, which should fail. + #[test] + fn from_u64_fail() { + ProtocolVersion::try_from(0xaa0125_u64).expect_err("must fail"); + } + + #[test] + fn ord() { + let v0_37 = ProtocolVersion::from_parts(0, 37); + let v1_37 = ProtocolVersion::from_parts(1, 37); + let v1_40 = ProtocolVersion::from_parts(1, 40); + + assert!(v0_37 < v1_37); + assert!(v1_37 > v0_37); + assert!(v1_37 < v1_40); + assert!(v1_40 > v1_37); + assert!(v1_40 <= v1_40); + } +} diff --git a/tvix/nix-compat/src/wire/ser/bytes.rs b/tvix/nix-compat/src/wire/ser/bytes.rs new file mode 100644 index 000000000000..4338d3f8761e --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/bytes.rs @@ -0,0 +1,89 @@ +use bytes::Bytes; + +use super::{NixSerialize, NixWrite}; + +impl NixSerialize for Bytes { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl<'a> NixSerialize for &'a [u8] { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl NixSerialize for String { + async fn serialize<W>(&self, writer: &mut W) -> 
/// Serializes a string slice by handing its UTF-8 bytes to
/// [`NixWrite::write_slice`], the same way the `String` impl does.
impl NixSerialize for str {
    async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error>
    where
        W: NixWrite,
    {
        // No validation or transformation: the bytes are written as-is.
        writer.write_slice(self.as_bytes()).await
    }
}
#[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_string(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value.to_string()).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/collections.rs b/tvix/nix-compat/src/wire/ser/collections.rs new file mode 100644 index 000000000000..478e1d04d809 --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/collections.rs @@ -0,0 +1,94 @@ +use std::collections::BTreeMap; +use std::future::Future; + +use super::{NixSerialize, NixWrite}; + +impl<T> NixSerialize for Vec<T> +where + T: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for value in self.iter() { + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +impl<K, V> NixSerialize for BTreeMap<K, V> +where + K: NixSerialize + Ord + Send + Sync, + V: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + 
writer.write_value(&self.len()).await?; + for (key, value) in self.iter() { + writer.write_value(key).await?; + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixSerialize, NixWrite, NixWriter}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_vec(#[case] value: Vec<usize>, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + fn empty_map() -> BTreeMap<usize, u64> { + BTreeMap::new() + } + macro_rules! 
map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_btree_map<E>(#[case] value: E, #[case] data: &[u8]) + where + E: NixSerialize + Send + PartialEq + fmt::Debug, + { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/display.rs b/tvix/nix-compat/src/wire/ser/display.rs new file mode 100644 index 000000000000..a3438d50d8ff --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/display.rs @@ -0,0 +1,8 @@ +use nix_compat_derive::nix_serialize_remote; + +use crate::nixhash; + +nix_serialize_remote!( + #[nix(display)] + nixhash::HashAlgo +); diff --git a/tvix/nix-compat/src/wire/ser/int.rs b/tvix/nix-compat/src/wire/ser/int.rs new file mode 100644 index 000000000000..e68179c71dc7 --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/int.rs @@ -0,0 +1,108 @@ +#[cfg(feature = "nix-compat-derive")] +use nix_compat_derive::nix_serialize_remote; + +use super::{Error, NixSerialize, NixWrite}; + +impl NixSerialize for u64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self).await + } +} + +impl NixSerialize for usize { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + let v = (*self).try_into().map_err(W::Error::unsupported_data)?; + writer.write_number(v).await + } +} + +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u8 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u16 +); +#[cfg(feature = "nix-compat-derive")] 
+nix_serialize_remote!( + #[nix(into = "u64")] + u32 +); + +impl NixSerialize for bool { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + if *self { + writer.write_number(1).await + } else { + writer.write_number(0).await + } + } +} + +impl NixSerialize for i64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self as u64).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[tokio::test] + async fn test_write_bool(#[case] value: bool, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_u64(#[case] value: u64, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] + #[tokio::test] + async fn test_write_usize(#[case] value: usize, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + 
writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/mock.rs b/tvix/nix-compat/src/wire/ser/mock.rs new file mode 100644 index 000000000000..7104a94238ff --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/mock.rs @@ -0,0 +1,672 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +#[cfg(test)] +use ::proptest::prelude::TestCaseError; +use thiserror::Error; + +use crate::wire::ProtocolVersion; + +use super::NixWrite; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("unsupported data error '{0}'")] + UnsupportedData(String), + #[error("Invalid enum: {0}")] + InvalidEnum(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong write: expected {0} got {1}")] + WrongWrite(OperationType, OperationType), + #[error("unexpected write: got an extra {0}")] + ExtraWrite(OperationType), + #[error("got an unexpected number {0} in write_number")] + UnexpectedNumber(u64), + #[error("got an unexpected slice '{0:?}' in write_slice")] + UnexpectedSlice(Vec<u8>), + #[error("got an unexpected display '{0:?}' in write_slice")] + UnexpectedDisplay(String), +} + +impl Error { + pub fn unexpected_write_number(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteNumber) + } + + pub fn extra_write_number() -> Error { + Error::ExtraWrite(OperationType::WriteNumber) + } + + pub fn unexpected_write_slice(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteSlice) + } + + pub fn unexpected_write_display(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteDisplay) + } +} + +impl super::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Self::Custom(msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + Self::IO(err.kind(), err.to_string()) + } + + fn 
/// The kind of write a mock operation stands for, without its payload.
// Every variant shares the `Write` prefix on purpose: the names mirror the
// `write_*` methods they describe.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    WriteNumber,
    WriteSlice,
    WriteDisplay,
}

/// Displays the operation as the name of the corresponding writer method.
impl fmt::Display for OperationType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let method_name = match self {
            Self::WriteNumber => "write_number",
            Self::WriteSlice => "write_slice",
            Self::WriteDisplay => "write_display",
        };
        f.write_str(method_name)
    }
}
Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Err(err))); + self + } + + pub fn write_display<D>(&mut self, value: D) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Ok(()))); + self + } + + pub fn write_display_error<D>(&mut self, value: D, err: Error) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Err(err))); + self + } + + #[cfg(test)] + fn write_operation_type(&mut self, op: OperationType) -> &mut Self { + match op { + OperationType::WriteNumber => self.write_number(10), + OperationType::WriteSlice => self.write_slice(b"testing"), + OperationType::WriteDisplay => self.write_display("testing"), + } + } + + #[cfg(test)] + fn write_operation(&mut self, op: &Operation) -> &mut Self { + match op { + Operation::WriteNumber(value, Ok(_)) => self.write_number(*value), + Operation::WriteNumber(value, Err(Error::UnexpectedNumber(_))) => { + self.write_number(*value) + } + Operation::WriteNumber(_, Err(Error::ExtraWrite(OperationType::WriteNumber))) => self, + Operation::WriteNumber(_, Err(Error::WrongWrite(op, OperationType::WriteNumber))) => { + self.write_operation_type(*op) + } + Operation::WriteNumber(value, Err(Error::Custom(msg))) => { + self.write_number_error(*value, Error::Custom(msg.clone())) + } + Operation::WriteNumber(value, Err(Error::IO(kind, msg))) => { + self.write_number_error(*value, Error::IO(*kind, msg.clone())) + } + Operation::WriteSlice(value, Ok(_)) => self.write_slice(value), + Operation::WriteSlice(value, Err(Error::UnexpectedSlice(_))) => self.write_slice(value), + Operation::WriteSlice(_, Err(Error::ExtraWrite(OperationType::WriteSlice))) => self, + Operation::WriteSlice(_, Err(Error::WrongWrite(op, OperationType::WriteSlice))) => { + self.write_operation_type(*op) + } + Operation::WriteSlice(value, Err(Error::Custom(msg))) => { + self.write_slice_error(value, 
Error::Custom(msg.clone())) + } + Operation::WriteSlice(value, Err(Error::IO(kind, msg))) => { + self.write_slice_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Ok(_)) => self.write_display(value), + Operation::WriteDisplay(value, Err(Error::Custom(msg))) => { + self.write_display_error(value, Error::Custom(msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::IO(kind, msg))) => { + self.write_display_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::UnexpectedDisplay(_))) => { + self.write_display(value) + } + Operation::WriteDisplay(_, Err(Error::ExtraWrite(OperationType::WriteDisplay))) => self, + Operation::WriteDisplay(_, Err(Error::WrongWrite(op, OperationType::WriteDisplay))) => { + self.write_operation_type(*op) + } + s => panic!("Invalid operation {:?}", s), + } + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct Mock { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Mock { + #[cfg(test)] + #[allow(dead_code)] + async fn assert_operation(&mut self, op: Operation) { + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + assert_eq!(self.write_display(value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + assert_eq!(self.write_display(value).await, 
res); + } + } + } + + #[cfg(test)] + async fn prop_assert_operation(&mut self, op: Operation) -> Result<(), TestCaseError> { + use ::proptest::prop_assert_eq; + + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + prop_assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + prop_assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + prop_assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + prop_assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + prop_assert_eq!(self.write_display(&value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + prop_assert_eq!(self.write_display(&value).await, res); + } + } + Ok(()) + } +} + +impl NixWrite for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteNumber(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedNumber(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_number(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteNumber)), + } + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteSlice(expected, ret)) => { + if buf != expected { + return Err(Error::UnexpectedSlice(buf.to_vec())); + } + ret + } + Some(op) => Err(Error::unexpected_write_slice(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteSlice)), + } + } + + async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + let value = msg.to_string(); + 
match self.ops.pop_front() { + Some(Operation::WriteDisplay(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedDisplay(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_display(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteDisplay)), + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + // No need to panic again + if thread::panicking() { + return; + } + if let Some(op) = self.ops.front() { + panic!("reader dropped with {op:?} operation still unread") + } + } +} + +#[cfg(test)] +mod proptest { + use std::io; + + use proptest::{ + prelude::{any, Arbitrary, BoxedStrategy, Just, Strategy}, + prop_oneof, + }; + + use super::{Error, Operation, OperationType}; + + pub fn arb_write_number_operation() -> impl Strategy<Value = Operation> { + ( + any::<u64>(), + prop_oneof![ + Just(Ok(())), + any::<u64>().prop_map(|v| Err(Error::UnexpectedNumber(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteNumber + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteNumber + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same number", |(v, res)| match res { + Err(Error::UnexpectedNumber(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteNumber(v, res)) + } + + pub fn arb_write_slice_operation() -> impl Strategy<Value = Operation> { + ( + any::<Vec<u8>>(), + prop_oneof![ + Just(Ok(())), + any::<Vec<u8>>().prop_map(|v| Err(Error::UnexpectedSlice(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteSlice + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteSlice + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], 
+ ) + .prop_filter("same slice", |(v, res)| match res { + Err(Error::UnexpectedSlice(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteSlice(v, res)) + } + + #[allow(dead_code)] + pub fn arb_extra_write() -> impl Strategy<Value = Operation> { + prop_oneof![ + any::<u64>().prop_map(|msg| { + Operation::WriteNumber(msg, Err(Error::ExtraWrite(OperationType::WriteNumber))) + }), + any::<Vec<u8>>().prop_map(|msg| { + Operation::WriteSlice(msg, Err(Error::ExtraWrite(OperationType::WriteSlice))) + }), + any::<String>().prop_map(|msg| { + Operation::WriteDisplay(msg, Err(Error::ExtraWrite(OperationType::WriteDisplay))) + }), + ] + } + + pub fn arb_write_display_operation() -> impl Strategy<Value = Operation> { + ( + any::<String>(), + prop_oneof![ + Just(Ok(())), + any::<String>().prop_map(|v| Err(Error::UnexpectedDisplay(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteDisplay + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteDisplay + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same string", |(v, res)| match res { + Err(Error::UnexpectedDisplay(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteDisplay(v, res)) + } + + pub fn arb_operation() -> impl Strategy<Value = Operation> { + prop_oneof![ + arb_write_number_operation(), + arb_write_slice_operation(), + arb_write_display_operation(), + ] + } + + impl Arbitrary for Operation { + type Parameters = (); + type Strategy = BoxedStrategy<Operation>; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + arb_operation().boxed() + } + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use proptest::prelude::any; + use proptest::prelude::TestCaseError; + use proptest::proptest; + + use crate::wire::ser::mock::proptest::arb_extra_write; 
+    use crate::wire::ser::mock::Operation;
+    use crate::wire::ser::mock::OperationType;
+    use crate::wire::ser::Error as _;
+    use crate::wire::ser::NixWrite;
+
+    use super::{Builder, Error};
+
+    // A write that matches the single expected number succeeds.
+    #[tokio::test]
+    async fn write_number() {
+        let mut mock = Builder::new().write_number(10).build();
+        mock.write_number(10).await.unwrap();
+    }
+
+    // An expectation registered together with an error makes the matching write return
+    // that error.
+    #[tokio::test]
+    async fn write_number_error() {
+        let mut mock = Builder::new()
+            .write_number_error(10, Error::custom("bad number"))
+            .build();
+        assert_eq!(
+            Err(Error::custom("bad number")),
+            mock.write_number(10).await
+        );
+    }
+
+    // Writing a number while a slice is expected is reported through
+    // `unexpected_write_number` with the expected operation type.
+    #[tokio::test]
+    async fn write_number_unexpected() {
+        let mut mock = Builder::new().write_slice(b"").build();
+        assert_eq!(
+            Err(Error::unexpected_write_number(OperationType::WriteSlice)),
+            mock.write_number(11).await
+        );
+    }
+
+    // Writing 11 while 10 is expected yields `UnexpectedNumber` carrying the written value.
+    #[tokio::test]
+    async fn write_number_unexpected_number() {
+        let mut mock = Builder::new().write_number(10).build();
+        assert_eq!(
+            Err(Error::UnexpectedNumber(11)),
+            mock.write_number(11).await
+        );
+    }
+
+    // A number written to a mock without expectations yields `ExtraWrite`.
+    #[tokio::test]
+    async fn extra_write_number() {
+        let mut mock = Builder::new().build();
+        assert_eq!(
+            Err(Error::ExtraWrite(OperationType::WriteNumber)),
+            mock.write_number(11).await
+        );
+    }
+
+    // Two expected slices (one empty, one ten bytes long) are accepted in order.
+    #[tokio::test]
+    async fn write_slice() {
+        let mut mock = Builder::new()
+            .write_slice(&[])
+            .write_slice(&hex!("0000 1234 5678 9ABC DEFF"))
+            .build();
+        mock.write_slice(&[]).await.expect("write_slice empty");
+        mock.write_slice(&hex!("0000 1234 5678 9ABC DEFF"))
+            .await
+            .expect("write_slice");
+    }
+
+    // A slice expectation registered together with an error returns that error.
+    #[tokio::test]
+    async fn write_slice_error() {
+        let mut mock = Builder::new()
+            .write_slice_error(&[], Error::custom("bad slice"))
+            .build();
+        assert_eq!(Err(Error::custom("bad slice")), mock.write_slice(&[]).await);
+    }
+
+    // Writing a slice while a number is expected is reported through
+    // `unexpected_write_slice` with the expected operation type.
+    #[tokio::test]
+    async fn write_slice_unexpected() {
+        let mut mock = Builder::new().write_number(10).build();
+        assert_eq!(
+            Err(Error::unexpected_write_slice(OperationType::WriteNumber)),
+            mock.write_slice(b"").await
+        );
+    }
+
+    // Writing a different slice than expected yields `UnexpectedSlice` carrying the
+    // written bytes.
+    #[tokio::test]
+    async fn write_slice_unexpected_slice() {
+        let mut mock = Builder::new().write_slice(b"").build();
+        assert_eq!(
+            Err(Error::UnexpectedSlice(b"bad slice".to_vec())),
+            mock.write_slice(b"bad slice").await
+        );
+    }
+
+    // A slice written to a mock without expectations yields `ExtraWrite`.
+    #[tokio::test]
+    async fn extra_write_slice() {
+        let mut mock = Builder::new().build();
+        assert_eq!(
+            Err(Error::ExtraWrite(OperationType::WriteSlice)),
+            mock.write_slice(b"extra slice").await
+        );
+    }
+
+    // A display write that matches the expected string succeeds.
+    #[tokio::test]
+    async fn write_display() {
+        let mut mock = Builder::new().write_display("testing").build();
+        mock.write_display("testing").await.unwrap();
+    }
+
+    // A display expectation registered together with an error returns that error. The
+    // message text is arbitrary here; the assertion only compares it with itself.
+    #[tokio::test]
+    async fn write_display_error() {
+        let mut mock = Builder::new()
+            .write_display_error("testing", Error::custom("bad number"))
+            .build();
+        assert_eq!(
+            Err(Error::custom("bad number")),
+            mock.write_display("testing").await
+        );
+    }
+
+    // A display write while a number is expected is reported through
+    // `unexpected_write_display` with the expected operation type.
+    #[tokio::test]
+    async fn write_display_unexpected() {
+        let mut mock = Builder::new().write_number(10).build();
+        assert_eq!(
+            Err(Error::unexpected_write_display(OperationType::WriteNumber)),
+            mock.write_display("").await
+        );
+    }
+
+    // Writing a different string than expected yields `UnexpectedDisplay` carrying the
+    // written string.
+    #[tokio::test]
+    async fn write_display_unexpected_display() {
+        let mut mock = Builder::new().write_display("").build();
+        assert_eq!(
+            Err(Error::UnexpectedDisplay("bad display".to_string())),
+            mock.write_display("bad display").await
+        );
+    }
+
+    // A display write to a mock without expectations yields `ExtraWrite`.
+    #[tokio::test]
+    async fn extra_write_display() {
+        let mut mock = Builder::new().build();
+        assert_eq!(
+            Err(Error::ExtraWrite(OperationType::WriteDisplay)),
+            mock.write_display("extra slice").await
+        );
+    }
+
+    // The mock is dropped immediately with one expectation still pending; the test
+    // passes only if that panics.
+    #[test]
+    #[should_panic]
+    fn operations_left() {
+        let _ = Builder::new().write_number(10).build();
+    }
+
+    // Property test: register arbitrary operations followed by 0..3 extra-write
+    // operations on a builder, then replay all of them, in the same order, against the
+    // built mock with `prop_assert_operation`.
+    #[test]
+    fn proptest_mock() {
+        // The futures are driven by hand on a current-thread runtime from inside the
+        // synchronous proptest closure.
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+        proptest!(|(
+            operations in any::<Vec<Operation>>(),
+            extra_operations in proptest::collection::vec(arb_extra_write(), 0..3)
+        )| {
+            rt.block_on(async {
+                let mut builder = Builder::new();
+                for op in operations.iter() {
+                    builder.write_operation(op);
+                }
+                for op in extra_operations.iter() {
+                    builder.write_operation(op);
+                }
+                let mut mock = builder.build();
+                for op in operations {
+                    mock.prop_assert_operation(op).await?;
+                }
+                for op in extra_operations {
+                    mock.prop_assert_operation(op).await?;
+                }
+                Ok(()) as Result<(), TestCaseError>
+            })?;
+        });
+    }
+}
diff --git a/tvix/nix-compat/src/wire/ser/mod.rs b/tvix/nix-compat/src/wire/ser/mod.rs
new file mode 100644
index 000000000000..5860226f39eb
--- /dev/null
+++ b/tvix/nix-compat/src/wire/ser/mod.rs
@@ -0,0 +1,124 @@
+use std::error::Error as StdError;
+use std::future::Future;
+use std::{fmt, io};
+
+use super::ProtocolVersion;
+
+mod bytes;
+mod collections;
+#[cfg(feature = "nix-compat-derive")]
+mod display;
+mod int;
+#[cfg(any(test, feature = "test"))]
+pub mod mock;
+mod writer;
+
+pub use writer::{NixWriter, NixWriterBuilder};
+
+/// Error type produced by [`NixWrite`] implementations.
+///
+/// Only `custom` has to be provided; every other constructor has a default that
+/// forwards to it.
+pub trait Error: Sized + StdError {
+    /// Build an error from a displayable message.
+    fn custom<T: fmt::Display>(msg: T) -> Self;
+
+    /// Build an error from an I/O error. The default formats the I/O error into a
+    /// `custom` message.
+    fn io_error(err: std::io::Error) -> Self {
+        Self::custom(format_args!("There was an I/O error {}", err))
+    }
+
+    /// Build an "unsupported data" error. Defaults to `custom(msg)`.
+    fn unsupported_data<T: fmt::Display>(msg: T) -> Self {
+        Self::custom(msg)
+    }
+
+    /// Build an "invalid enum" error. Defaults to `custom(msg)`.
+    fn invalid_enum<T: fmt::Display>(msg: T) -> Self {
+        Self::custom(msg)
+    }
+}
+
+/// `io::Error` as a serializer error: custom messages use `ErrorKind::Other`, I/O errors
+/// are passed through unchanged and unsupported data uses `ErrorKind::InvalidData`.
+/// `invalid_enum` is not overridden, so it goes through `custom`.
+impl Error for io::Error {
+    fn custom<T: fmt::Display>(msg: T) -> Self {
+        io::Error::new(io::ErrorKind::Other, msg.to_string())
+    }
+
+    fn io_error(err: std::io::Error) -> Self {
+        err
+    }
+
+    fn unsupported_data<T: fmt::Display>(msg: T) -> Self {
+        io::Error::new(io::ErrorKind::InvalidData, msg.to_string())
+    }
+}
+
+/// Asynchronous writer for the primitives of the wire protocol: numbers, byte slices
+/// and displayable values.
+pub trait NixWrite: Send {
+    /// Error returned by every write method.
+    type Error: Error;
+
+    /// Some types are serialized differently depending on the version
+    /// of the protocol and so this can be used for implementing that.
+    fn version(&self) -> ProtocolVersion;
+
+    /// Write a single u64 to the protocol.
+    fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send;
+
+    /// Write a slice of bytes to the protocol.
+    fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send;
+
+    /// Write a value that implements `std::fmt::Display` to the protocol.
+    /// The protocol uses many small string formats and instead of allocating
+    /// a `String` each time we want to write one an implementation of `NixWrite`
+    /// can instead use `Display` to dump these formats to a reusable buffer.
+    ///
+    /// The default implementation does allocate: it formats `msg` into a `String` and
+    /// passes its bytes to `write_slice`.
+    fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        D: fmt::Display + Send,
+        Self: Sized,
+    {
+        async move {
+            let s = msg.to_string();
+            self.write_slice(s.as_bytes()).await
+        }
+    }
+
+    /// Write a value to the protocol.
+    /// Uses `NixSerialize::serialize` to write the value.
+    fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        V: NixSerialize + Send + ?Sized,
+        Self: Sized,
+    {
+        value.serialize(self)
+    }
+}
+
+/// Forwarding implementation: a `&mut T` is itself a `NixWrite` and hands every call,
+/// including the two provided methods, to the underlying `T`.
+impl<T: NixWrite> NixWrite for &mut T {
+    type Error = T::Error;
+
+    fn version(&self) -> ProtocolVersion {
+        (**self).version()
+    }
+
+    fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send {
+        (**self).write_number(value)
+    }
+
+    fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send {
+        (**self).write_slice(buf)
+    }
+
+    // Forwarded explicitly so that `T`'s own `write_display` is used rather than the
+    // trait's default body.
+    fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        D: fmt::Display + Send,
+        Self: Sized,
+    {
+        (**self).write_display(msg)
+    }
+
+    fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        V: NixSerialize + Send + ?Sized,
+        Self: Sized,
+    {
+        (**self).write_value(value)
+    }
+}
+
+/// A value that can write itself to a [`NixWrite`].
+pub trait NixSerialize {
+    /// Write a value to the writer.
+    fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send
+    where
+        W: NixWrite;
+}
diff --git a/tvix/nix-compat/src/wire/ser/writer.rs b/tvix/nix-compat/src/wire/ser/writer.rs
new file mode 100644
index 000000000000..da1c2b18c5e2
--- /dev/null
+++ b/tvix/nix-compat/src/wire/ser/writer.rs
@@ -0,0 +1,306 @@
+use std::fmt::{self, Write as _};
+use std::future::poll_fn;
+use std::io::{self, Cursor};
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+
+use bytes::{Buf, BufMut, BytesMut};
+use pin_project_lite::pin_project;
+use tokio::io::{AsyncWrite, AsyncWriteExt};
+
+use crate::wire::{padding_len, ProtocolVersion, EMPTY_BYTES};
+
+use super::{Error, NixWrite};
+
+/// Builder for [`NixWriter`]: an optional caller-supplied buffer, two buffer-size
+/// settings and the protocol version.
+pub struct NixWriterBuilder {
+    // Caller-supplied buffer; when `None`, `build` allocates one itself.
+    buf: Option<BytesMut>,
+    // Free space `write_display` wants to see before it formats into the buffer.
+    reserved_buf_size: usize,
+    // Capacity of the buffer that `build` allocates when none was supplied.
+    max_buf_size: usize,
+    version: ProtocolVersion,
+}
+
+/// Defaults: no pre-supplied buffer, 8192 for both size settings and the default
+/// `ProtocolVersion`.
+impl Default for NixWriterBuilder {
+    fn default() -> Self {
+        Self {
+            buf: Default::default(),
+            reserved_buf_size: 8192,
+            max_buf_size: 8192,
+            version: Default::default(),
+        }
+    }
+}
+
+impl NixWriterBuilder {
+    /// Use `buf` as the writer's internal buffer instead of allocating a new one.
+    pub fn set_buffer(mut self, buf: BytesMut) -> Self {
+        self.buf = Some(buf);
+        self
+    }
+
+    /// Set the amount of free buffer space `write_display` checks for before formatting.
+    pub fn set_reserved_buf_size(mut self, size: usize) -> Self {
+        self.reserved_buf_size = size;
+        self
+    }
+
+    /// Set the capacity used when `build` has to allocate the buffer.
+    pub fn set_max_buf_size(mut self, size: usize) -> Self {
+        self.max_buf_size = size;
+        self
+    }
+
+    /// Set the protocol version reported by the writer.
+    pub fn set_version(mut self, version: ProtocolVersion) -> Self {
+        self.version = version;
+        self
+    }
+
+    /// Create the writer around `writer`.
+    ///
+    /// A buffer given through `set_buffer` is used as it is; otherwise a new one with
+    /// capacity `max_buf_size` is allocated.
+    pub fn build<W>(self, writer: W) -> NixWriter<W> {
+        let buf = self
+            .buf
+            .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size));
+        NixWriter {
+            buf,
+            inner: writer,
+            reserved_buf_size: self.reserved_buf_size,
+            max_buf_size: self.max_buf_size,
+            version: self.version,
+        }
+    }
+}
+
+pin_project! {
+    /// Buffering writer that implements both `AsyncWrite` and `NixWrite` on top of
+    /// an inner `AsyncWrite`.
+    pub struct NixWriter<W> {
+        #[pin]
+        inner: W,
+        // Bytes accepted from callers but not yet written to `inner`.
+        buf: BytesMut,
+        reserved_buf_size: usize,
+        // Stored from the builder; no method in this file reads it after `build`.
+        max_buf_size: usize,
+        version: ProtocolVersion,
+    }
+}
+
+// `builder` lives on one concrete instantiation so that `NixWriter::builder()` can be
+// called without naming a writer type; the builder itself never touches the `Cursor`.
+// NOTE(review): presumably that is the only reason `Cursor<Vec<u8>>` was chosen — confirm.
+impl NixWriter<Cursor<Vec<u8>>> {
+    /// Start building a writer with the default settings.
+    pub fn builder() -> NixWriterBuilder {
+        NixWriterBuilder::default()
+    }
+}
+
+impl<W> NixWriter<W>
+where
+    W: AsyncWriteExt,
+{
+    /// Create a writer around `writer` using the builder defaults.
+    pub fn new(writer: W) -> NixWriter<W> {
+        NixWriter::builder().build(writer)
+    }
+
+    /// Bytes currently held in the internal buffer, i.e. not yet written to the inner
+    /// writer.
+    pub fn buffer(&self) -> &[u8] {
+        &self.buf[..]
+    }
+
+    /// Change the protocol version reported by `version()`.
+    pub fn set_version(&mut self, version: ProtocolVersion) {
+        self.version = version;
+    }
+
+    /// Remaining capacity in internal buffer
+    pub fn remaining_mut(&self) -> usize {
+        self.buf.capacity() - self.buf.len()
+    }
+
+    /// Write the buffered bytes to the inner writer until the buffer is empty.
+    ///
+    /// Propagates `Pending` and errors from the inner `poll_write`, and fails with
+    /// `WriteZero` when the inner writer reports that it wrote zero bytes. The inner
+    /// writer itself is not flushed here.
+    fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        let mut this = self.project();
+        while !this.buf.is_empty() {
+            let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?;
+            if n == 0 {
+                return Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::WriteZero,
+                    "failed to write the buffer",
+                )));
+            }
+            // Drop the bytes that were written; a partial write leaves the rest for the
+            // next loop iteration.
+            // NOTE(review): `Buf::advance` on a `BytesMut` is believed to shrink
+            // `capacity()` as well as `len()`; if so, `remaining_mut()` does not return
+            // to its initial value after a flush — confirm against the `bytes` docs.
+            this.buf.advance(n);
+        }
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl<W> NixWriter<W>
+where
+    W: AsyncWriteExt + Unpin,
+{
+    /// `async` wrapper around `poll_flush_buf`.
+    async fn flush_buf(&mut self) -> Result<(), io::Error> {
+        let mut s = Pin::new(self);
+        poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await
+    }
+}
+
+impl<W> AsyncWrite for NixWriter<W>
+where
+    W: AsyncWrite,
+{
+    /// Buffer `buf`, or pass it straight to the inner writer when it is larger than
+    /// the buffer's capacity.
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, io::Error>> {
+        // Flush
+        if self.remaining_mut() < buf.len() {
+            ready!(self.as_mut().poll_flush_buf(cx))?;
+        }
+        let this = self.project();
+        if buf.len() > this.buf.capacity() {
+            // Too large to buffer: hand it to the inner writer and return its result.
+            this.inner.poll_write(cx, buf)
+        } else {
+            // Copy into the buffer and report the whole slice as written.
+            this.buf.put_slice(buf);
+            Poll::Ready(Ok(buf.len()))
+        }
+    }
+
+    /// Write out the buffered bytes, then flush the inner writer.
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        ready!(self.as_mut().poll_flush_buf(cx))?;
+        self.project().inner.poll_flush(cx)
+    }
+
+    /// Write out the buffered bytes, then shut down the inner writer.
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        ready!(self.as_mut().poll_flush_buf(cx))?;
+        self.project().inner.poll_shutdown(cx)
+    }
+}
+
+impl<W> NixWrite for NixWriter<W>
+where
+    W: AsyncWrite + Send + Unpin,
+{
+    type Error = io::Error;
+
+    fn version(&self) -> ProtocolVersion {
+        self.version
+    }
+
+    /// Write `value` as eight little-endian bytes.
+    async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> {
+        let mut buf = [0u8; 8];
+        BufMut::put_u64_le(&mut &mut buf[..], value);
+        self.write_all(&buf).await
+    }
+
+    /// Write the length of `buf`, then its bytes, then `padding_len(len)` bytes taken
+    /// from `EMPTY_BYTES`. In the test vectors below the payload is zero-padded to a
+    /// multiple of eight bytes.
+    async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
+        let padding = padding_len(buf.len() as u64) as usize;
+        // The length goes through `write_value`, i.e. the `NixSerialize` impl for
+        // `usize`, which is defined outside this file.
+        self.write_value(&buf.len()).await?;
+        self.write_all(buf).await?;
+        if padding > 0 {
+            self.write_all(&EMPTY_BYTES[..padding]).await
+        } else {
+            Ok(())
+        }
+    }
+
+    /// Format `msg` directly into the internal buffer, framed like `write_slice`
+    /// output: an eight-byte length, the formatted text, then padding.
+    async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error>
+    where
+        D: fmt::Display + Send,
+        Self: Sized,
+    {
+        // Ensure that buffer has space for at least reserved_buf_size bytes
+        if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() {
+            self.flush_buf().await?;
+        }
+        // The length is not known before formatting, so write a zero placeholder and
+        // remember where it is.
+        let offset = self.buf.len();
+        self.buf.put_u64_le(0);
+        if let Err(err) = write!(self.buf, "{}", msg) {
+            // Remove the placeholder and any partial output before reporting the error.
+            self.buf.truncate(offset);
+            return Err(Self::Error::unsupported_data(err));
+        }
+        // Everything after the placeholder is the formatted text; patch its length
+        // into the placeholder.
+        let len = self.buf.len() - offset - 8;
+        BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64);
+        let padding = padding_len(len as u64) as usize;
+        self.write_all(&EMPTY_BYTES[..padding]).await
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::time::Duration;
+
+    use hex_literal::hex;
+    use rstest::rstest;
+    use tokio::io::AsyncWriteExt as _;
+    use tokio_test::io::Builder;
+
+    use crate::wire::ser::NixWrite;
+
+    use super::NixWriter;
+
+    // The encoded number sits in the writer's buffer after the write and reaches the
+    // mock only on `flush`, after which the buffer is empty.
+    #[rstest]
+    #[case(1, &hex!("0100 0000 0000 0000"))]
+    #[case::evil(666, &hex!("9A02 0000 0000 0000"))]
+    #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))]
+    #[tokio::test]
+    async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) {
+        let mock = Builder::new().write(buf).build();
+        let mut writer = NixWriter::new(mock);
+
+        writer.write_number(number).await.unwrap();
+        assert_eq!(writer.buffer(), buf);
+        writer.flush().await.unwrap();
+        assert_eq!(writer.buffer(), b"");
+    }
+
+    // Each case is run for every combination of mock chunk size and writer buffer
+    // size. The mock expects the encoded bytes in pieces of `chunks_size`, each
+    // followed by a zero-length wait.
+    #[rstest]
+    #[case::empty(b"", &hex!("0000 0000 0000 0000"))]
+    #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
+    #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
+    #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
+    #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
+    #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
+    #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
+    #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
+    #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
+    #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
+    #[tokio::test]
+    async fn test_write_slice(
+        #[case] value: &[u8],
+        #[case] buf: &[u8],
+        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
+        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize,
+    ) {
+        let mut builder = Builder::new();
+        for chunk in buf.chunks(chunks_size) {
+            builder.write(chunk);
+            builder.wait(Duration::ZERO);
+        }
+        let mock = builder.build();
+        let mut writer = NixWriter::builder().set_max_buf_size(buf_size).build(mock);
+
+        writer.write_slice(value).await.unwrap();
+        writer.flush().await.unwrap();
+        assert_eq!(writer.buffer(), b"");
+    }
+
+    // Same vectors as the slice test, written through `write_display` with the
+    // default buffer size. The whole encoding must be in the buffer before `flush`
+    // and the buffer must be empty afterwards.
+    #[rstest]
+    #[case::empty("", &hex!("0000 0000 0000 0000"))]
+    #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
+    #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
+    #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
+    #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
+    #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
+    #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
+    #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
+    #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
+    #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
+    #[tokio::test]
+    async fn test_write_display(
+        #[case] value: &str,
+        #[case] buf: &[u8],
+        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
+    ) {
+        let mut builder = Builder::new();
+        for chunk in buf.chunks(chunks_size) {
+            builder.write(chunk);
+            builder.wait(Duration::ZERO);
+        }
+        let mock = builder.build();
+        let mut writer = NixWriter::builder().build(mock);
+
+        writer.write_display(value).await.unwrap();
+        assert_eq!(writer.buffer(), buf);
+        writer.flush().await.unwrap();
+        assert_eq!(writer.buffer(), b"");
+    }
+} |