diff options
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon')
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/bytes.rs | 70 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/collections.rs | 105 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/int.rs | 100 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/mock.rs | 261 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/mod.rs | 225 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/reader.rs | 527 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/mod.rs | 6 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/protocol_version.rs | 139 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/ser/bytes.rs | 89 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/ser/collections.rs | 94 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/ser/display.rs | 8 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/ser/int.rs | 108 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/ser/mock.rs | 672 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/ser/mod.rs | 124 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/ser/writer.rs | 308 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/worker_protocol.rs | 2 |
16 files changed, 1 insertions, 2837 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/de/bytes.rs b/tvix/nix-compat/src/nix_daemon/de/bytes.rs deleted file mode 100644 index 7daced54eef7..000000000000 --- a/tvix/nix-compat/src/nix_daemon/de/bytes.rs +++ /dev/null @@ -1,70 +0,0 @@ -use bytes::Bytes; - -use super::{Error, NixDeserialize, NixRead}; - -impl NixDeserialize for Bytes { - async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> - where - R: ?Sized + NixRead + Send, - { - reader.try_read_bytes().await - } -} - -impl NixDeserialize for String { - async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> - where - R: ?Sized + NixRead + Send, - { - if let Some(buf) = reader.try_read_bytes().await? { - String::from_utf8(buf.to_vec()) - .map_err(R::Error::invalid_data) - .map(Some) - } else { - Ok(None) - } - } -} - -#[cfg(test)] -mod test { - use std::io; - - use hex_literal::hex; - use rstest::rstest; - use tokio_test::io::Builder; - - use crate::nix_daemon::de::{NixRead, NixReader}; - - #[rstest] - #[case::empty("", &hex!("0000 0000 0000 0000"))] - #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] - #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] - #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] - #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] - #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] - #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] - #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] - #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] - #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] - #[tokio::test] - async fn test_read_string(#[case] expected: &str, #[case] data: &[u8]) { - let mock = Builder::new().read(data).build(); - let mut reader = NixReader::new(mock); - let actual: String = reader.read_value().await.unwrap(); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_read_string_invalid() { - let mock = Builder::new() - .read(&hex!("0300 0000 0000 0000 EDA0 8000 0000 0000")) - .build(); - let mut reader = NixReader::new(mock); - assert_eq!( - io::ErrorKind::InvalidData, - reader.read_value::<String>().await.unwrap_err().kind() - ); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/de/collections.rs b/tvix/nix-compat/src/nix_daemon/de/collections.rs deleted file mode 100644 index cf79f584506a..000000000000 --- a/tvix/nix-compat/src/nix_daemon/de/collections.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::{collections::BTreeMap, future::Future}; - -use super::{NixDeserialize, NixRead}; - -#[allow(clippy::manual_async_fn)] -impl<T> NixDeserialize for Vec<T> -where - T: NixDeserialize + Send, -{ - fn try_deserialize<R>( - reader: &mut R, - ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ - where - R: ?Sized + NixRead + Send, - { - async move { - if let Some(len) = reader.try_read_value::<usize>().await? { - let mut ret = Vec::with_capacity(len); - for _ in 0..len { - ret.push(reader.read_value().await?); - } - Ok(Some(ret)) - } else { - Ok(None) - } - } - } -} - -#[allow(clippy::manual_async_fn)] -impl<K, V> NixDeserialize for BTreeMap<K, V> -where - K: NixDeserialize + Ord + Send, - V: NixDeserialize + Send, -{ - fn try_deserialize<R>( - reader: &mut R, - ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ - where - R: ?Sized + NixRead + Send, - { - async move { - if let Some(len) = reader.try_read_value::<usize>().await? { - let mut ret = BTreeMap::new(); - for _ in 0..len { - let key = reader.read_value().await?; - let value = reader.read_value().await?; - ret.insert(key, value); - } - Ok(Some(ret)) - } else { - Ok(None) - } - } - } -} - -#[cfg(test)] -mod test { - use std::collections::BTreeMap; - use std::fmt; - - use hex_literal::hex; - use rstest::rstest; - use tokio_test::io::Builder; - - use crate::nix_daemon::de::{NixDeserialize, NixRead, NixReader}; - - #[rstest] - #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] - #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] - #[tokio::test] - async fn test_read_small_vec(#[case] expected: Vec<usize>, #[case] data: &[u8]) { - let mock = Builder::new().read(data).build(); - let mut reader = NixReader::new(mock); - let actual: Vec<usize> = reader.read_value().await.unwrap(); - assert_eq!(actual, expected); - } - - fn empty_map() -> BTreeMap<usize, u64> { - BTreeMap::new() - } - macro_rules! map { - ($($key:expr => $value:expr),*) => {{ - let mut ret = BTreeMap::new(); - $(ret.insert($key, $value);)* - ret - }}; - } - - #[rstest] - #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] - #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] - #[tokio::test] - async fn test_read_small_btree_map<E>(#[case] expected: E, #[case] data: &[u8]) - where - E: NixDeserialize + PartialEq + fmt::Debug, - { - let mock = Builder::new().read(data).build(); - let mut reader = NixReader::new(mock); - let actual: E = reader.read_value().await.unwrap(); - assert_eq!(actual, expected); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/de/int.rs b/tvix/nix-compat/src/nix_daemon/de/int.rs deleted file mode 100644 index eecf641cfe99..000000000000 --- a/tvix/nix-compat/src/nix_daemon/de/int.rs +++ /dev/null @@ -1,100 +0,0 @@ -use super::{Error, NixDeserialize, NixRead}; - -impl NixDeserialize for u64 { - async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> - where - R: ?Sized + NixRead + Send, - { - reader.try_read_number().await - } -} - -impl NixDeserialize for usize { - async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> - where - R: ?Sized + NixRead + Send, - { - if let Some(value) = reader.try_read_number().await? { - value.try_into().map_err(R::Error::invalid_data).map(Some) - } else { - Ok(None) - } - } -} - -impl NixDeserialize for bool { - async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> - where - R: ?Sized + NixRead + Send, - { - Ok(reader.try_read_number().await?.map(|v| v != 0)) - } -} -impl NixDeserialize for i64 { - async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> - where - R: ?Sized + NixRead + Send, - { - Ok(reader.try_read_number().await?.map(|v| v as i64)) - } -} - -#[cfg(test)] -mod test { - use hex_literal::hex; - use rstest::rstest; - use tokio_test::io::Builder; - - use crate::nix_daemon::de::{NixRead, NixReader}; - - #[rstest] - #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] - #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] - #[case::other_true(true, &hex!("1234 5600 0000 0000"))] - #[case::max_true(true, &hex!("FFFF FFFF FFFF FFFF"))] - #[tokio::test] - async fn test_read_bool(#[case] expected: bool, #[case] data: &[u8]) { - let mock = Builder::new().read(data).build(); - let mut reader = NixReader::new(mock); - let actual: bool = reader.read_value().await.unwrap(); - assert_eq!(actual, expected); - } - - #[rstest] - #[case::zero(0, &hex!("0000 0000 0000 0000"))] - #[case::one(1, &hex!("0100 0000 0000 0000"))] - #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] - #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] - #[tokio::test] - async fn test_read_u64(#[case] expected: u64, #[case] data: &[u8]) { - let mock = Builder::new().read(data).build(); - let mut reader = NixReader::new(mock); - let actual: u64 = reader.read_value().await.unwrap(); - assert_eq!(actual, expected); - } - - #[rstest] - #[case::zero(0, &hex!("0000 0000 0000 0000"))] - #[case::one(1, &hex!("0100 0000 0000 0000"))] - #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] - #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] - #[tokio::test] - async fn test_read_usize(#[case] expected: usize, #[case] data: &[u8]) { - let mock = Builder::new().read(data).build(); - let mut reader = NixReader::new(mock); - let actual: usize = reader.read_value().await.unwrap(); - assert_eq!(actual, expected); - } - - // FUTUREWORK: Test this on supported hardware - #[tokio::test] - #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] - async fn test_read_usize_overflow() { - let mock = Builder::new().read(&u64::MAX.to_le_bytes()).build(); - let mut reader = NixReader::new(mock); - assert_eq!( - std::io::ErrorKind::InvalidData, - reader.read_value::<usize>().await.unwrap_err().kind() - ); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/de/mock.rs b/tvix/nix-compat/src/nix_daemon/de/mock.rs deleted file mode 100644 index 31cc3a4897ba..000000000000 --- a/tvix/nix-compat/src/nix_daemon/de/mock.rs +++ /dev/null @@ -1,261 +0,0 @@ -use std::collections::VecDeque; -use std::fmt; -use std::io; -use std::thread; - -use bytes::Bytes; -use thiserror::Error; - -use crate::nix_daemon::ProtocolVersion; - -use super::NixRead; - -#[derive(Debug, Error, PartialEq, Eq, Clone)] -pub enum Error { - #[error("custom error '{0}'")] - Custom(String), - #[error("invalid data '{0}'")] - InvalidData(String), - #[error("missing data '{0}'")] - MissingData(String), - #[error("IO error {0} '{1}'")] - IO(io::ErrorKind, String), - #[error("wrong read: expected {0} got {1}")] - WrongRead(OperationType, OperationType), -} - -impl Error { - pub fn expected_read_number() -> Error { - Error::WrongRead(OperationType::ReadNumber, OperationType::ReadBytes) - } - - pub fn expected_read_bytes() -> Error { - Error::WrongRead(OperationType::ReadBytes, OperationType::ReadNumber) - } -} - -impl super::Error for Error { - fn custom<T: fmt::Display>(msg: T) -> Self { - Self::Custom(msg.to_string()) - } - - fn io_error(err: std::io::Error) -> Self { - Self::IO(err.kind(), err.to_string()) - } - - fn invalid_data<T: fmt::Display>(msg: T) -> Self { - Self::InvalidData(msg.to_string()) - } - - fn missing_data<T: fmt::Display>(msg: T) -> Self { - Self::MissingData(msg.to_string()) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum OperationType { - ReadNumber, - ReadBytes, -} - -impl fmt::Display for OperationType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::ReadNumber => write!(f, "read_number"), - Self::ReadBytes => write!(f, "read_bytess"), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -enum Operation { - ReadNumber(Result<u64, Error>), - ReadBytes(Result<Bytes, Error>), -} - -impl From<Operation> for OperationType { - fn from(value: Operation) -> Self { - match value { - Operation::ReadNumber(_) => OperationType::ReadNumber, - Operation::ReadBytes(_) => OperationType::ReadBytes, - } - } -} - -pub struct Builder { - version: ProtocolVersion, - ops: VecDeque<Operation>, -} - -impl Builder { - pub fn new() -> Builder { - Builder { - version: Default::default(), - ops: VecDeque::new(), - } - } - - pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { - self.version = version.into(); - self - } - - pub fn read_number(&mut self, value: u64) -> &mut Self { - self.ops.push_back(Operation::ReadNumber(Ok(value))); - self - } - - pub fn read_number_error(&mut self, err: Error) -> &mut Self { - self.ops.push_back(Operation::ReadNumber(Err(err))); - self - } - - pub fn read_bytes(&mut self, value: Bytes) -> &mut Self { - self.ops.push_back(Operation::ReadBytes(Ok(value))); - self - } - - pub fn read_slice(&mut self, data: &[u8]) -> &mut Self { - let value = Bytes::copy_from_slice(data); - self.ops.push_back(Operation::ReadBytes(Ok(value))); - self - } - - pub fn read_bytes_error(&mut self, err: Error) -> &mut Self { - self.ops.push_back(Operation::ReadBytes(Err(err))); - self - } - - pub fn build(&mut self) -> Mock { - Mock { - version: self.version, - ops: self.ops.clone(), - } - } -} - -impl Default for Builder { - fn default() -> Self { - Self::new() - } -} - -pub struct Mock { - version: ProtocolVersion, - ops: VecDeque<Operation>, -} - -impl NixRead for Mock { - type Error = Error; - - fn version(&self) -> ProtocolVersion { - self.version - } - - async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { - match self.ops.pop_front() { - Some(Operation::ReadNumber(ret)) => ret.map(Some), - Some(Operation::ReadBytes(_)) => Err(Error::expected_read_bytes()), - None => Ok(None), - } - } - - async fn try_read_bytes_limited( - &mut self, - _limit: std::ops::RangeInclusive<usize>, - ) -> Result<Option<Bytes>, Self::Error> { - match self.ops.pop_front() { - Some(Operation::ReadBytes(ret)) => ret.map(Some), - Some(Operation::ReadNumber(_)) => Err(Error::expected_read_number()), - None => Ok(None), - } - } -} - -impl Drop for Mock { - fn drop(&mut self) { - // No need to panic again - if thread::panicking() { - return; - } - if let Some(op) = self.ops.front() { - panic!("reader dropped with {op:?} operation still unread") - } - } -} - -#[cfg(test)] -mod test { - use bytes::Bytes; - use hex_literal::hex; - - use crate::nix_daemon::de::NixRead; - - use super::{Builder, Error}; - - #[tokio::test] - async fn read_slice() { - let mut mock = Builder::new() - .read_number(10) - .read_slice(&[]) - .read_slice(&hex!("0000 1234 5678 9ABC DEFF")) - .build(); - assert_eq!(10, mock.read_number().await.unwrap()); - assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); - assert_eq!( - &hex!("0000 1234 5678 9ABC DEFF"), - &mock.read_bytes().await.unwrap()[..] - ); - assert_eq!(None, mock.try_read_number().await.unwrap()); - assert_eq!(None, mock.try_read_bytes().await.unwrap()); - } - - #[tokio::test] - async fn read_bytes() { - let mut mock = Builder::new() - .read_number(10) - .read_bytes(Bytes::from_static(&[])) - .read_bytes(Bytes::from_static(&hex!("0000 1234 5678 9ABC DEFF"))) - .build(); - assert_eq!(10, mock.read_number().await.unwrap()); - assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); - assert_eq!( - &hex!("0000 1234 5678 9ABC DEFF"), - &mock.read_bytes().await.unwrap()[..] - ); - assert_eq!(None, mock.try_read_number().await.unwrap()); - assert_eq!(None, mock.try_read_bytes().await.unwrap()); - } - - #[tokio::test] - async fn read_number() { - let mut mock = Builder::new().read_number(10).build(); - assert_eq!(10, mock.read_number().await.unwrap()); - assert_eq!(None, mock.try_read_number().await.unwrap()); - assert_eq!(None, mock.try_read_bytes().await.unwrap()); - } - - #[tokio::test] - async fn expect_number() { - let mut mock = Builder::new().read_number(10).build(); - assert_eq!( - Error::expected_read_number(), - mock.read_bytes().await.unwrap_err() - ); - } - - #[tokio::test] - async fn expect_bytes() { - let mut mock = Builder::new().read_slice(&[]).build(); - assert_eq!( - Error::expected_read_bytes(), - mock.read_number().await.unwrap_err() - ); - } - - #[test] - #[should_panic] - fn operations_left() { - let _ = Builder::new().read_number(10).build(); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/de/mod.rs b/tvix/nix-compat/src/nix_daemon/de/mod.rs deleted file mode 100644 index f85ccd8fea0e..000000000000 --- a/tvix/nix-compat/src/nix_daemon/de/mod.rs +++ /dev/null @@ -1,225 +0,0 @@ -use std::error::Error as StdError; -use std::future::Future; -use std::ops::RangeInclusive; -use std::{fmt, io}; - -use ::bytes::Bytes; - -use super::ProtocolVersion; - -mod bytes; -mod collections; -mod int; -#[cfg(any(test, feature = "test"))] -pub mod mock; -mod reader; - -pub use reader::{NixReader, NixReaderBuilder}; - -/// Like serde the `Error` trait allows `NixRead` implementations to add -/// custom error handling for `NixDeserialize`. -pub trait Error: Sized + StdError { - /// A totally custom non-specific error. - fn custom<T: fmt::Display>(msg: T) -> Self; - - /// Some kind of std::io::Error occured. - fn io_error(err: std::io::Error) -> Self { - Self::custom(format_args!("There was an I/O error {}", err)) - } - - /// The data read from `NixRead` is invalid. - /// This could be that some bytes were supposed to be valid UFT-8 but weren't. - fn invalid_data<T: fmt::Display>(msg: T) -> Self { - Self::custom(msg) - } - - /// Required data is missing. This is mostly like an EOF - fn missing_data<T: fmt::Display>(msg: T) -> Self { - Self::custom(msg) - } -} - -impl Error for io::Error { - fn custom<T: fmt::Display>(msg: T) -> Self { - io::Error::new(io::ErrorKind::Other, msg.to_string()) - } - - fn io_error(err: std::io::Error) -> Self { - err - } - - fn invalid_data<T: fmt::Display>(msg: T) -> Self { - io::Error::new(io::ErrorKind::InvalidData, msg.to_string()) - } - - fn missing_data<T: fmt::Display>(msg: T) -> Self { - io::Error::new(io::ErrorKind::UnexpectedEof, msg.to_string()) - } -} - -/// A reader of data from the Nix daemon protocol. -/// Basically there are two basic types in the Nix daemon protocol -/// u64 and a bytes buffer. Everything else is more or less built on -/// top of these two types. -pub trait NixRead: Send { - type Error: Error + Send; - - /// Some types are serialized differently depending on the version - /// of the protocol and so this can be used for implementing that. - fn version(&self) -> ProtocolVersion; - - /// Read a single u64 from the protocol. - /// This returns an Option to support graceful shutdown. - fn try_read_number( - &mut self, - ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_; - - /// Read bytes from the protocol. - /// A size limit on the returned bytes has to be specified. - /// This returns an Option to support graceful shutdown. - fn try_read_bytes_limited( - &mut self, - limit: RangeInclusive<usize>, - ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_; - - /// Read bytes from the protocol without a limit. - /// The default implementation just calls `try_read_bytes_limited` with a - /// limit of `0..=usize::MAX` but other implementations are free to have a - /// reader wide limit. - /// This returns an Option to support graceful shutdown. - fn try_read_bytes( - &mut self, - ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { - self.try_read_bytes_limited(0..=usize::MAX) - } - - /// Read a single u64 from the protocol. - /// This will return an error if the number could not be read. - fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ { - async move { - match self.try_read_number().await? { - Some(v) => Ok(v), - None => Err(Self::Error::missing_data("unexpected end-of-file")), - } - } - } - - /// Read bytes from the protocol. - /// A size limit on the returned bytes has to be specified. - /// This will return an error if the number could not be read. - fn read_bytes_limited( - &mut self, - limit: RangeInclusive<usize>, - ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { - async move { - match self.try_read_bytes_limited(limit).await? { - Some(v) => Ok(v), - None => Err(Self::Error::missing_data("unexpected end-of-file")), - } - } - } - - /// Read bytes from the protocol. - /// The default implementation just calls `read_bytes_limited` with a - /// limit of `0..=usize::MAX` but other implementations are free to have a - /// reader wide limit. - /// This will return an error if the bytes could not be read. - fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { - self.read_bytes_limited(0..=usize::MAX) - } - - /// Read a value from the protocol. - /// Uses `NixDeserialize::deserialize` to read a value. - fn read_value<V: NixDeserialize>( - &mut self, - ) -> impl Future<Output = Result<V, Self::Error>> + Send + '_ { - V::deserialize(self) - } - - /// Read a value from the protocol. - /// Uses `NixDeserialize::try_deserialize` to read a value. - /// This returns an Option to support graceful shutdown. - fn try_read_value<V: NixDeserialize>( - &mut self, - ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ { - V::try_deserialize(self) - } -} - -impl<T: ?Sized + NixRead> NixRead for &mut T { - type Error = T::Error; - - fn version(&self) -> ProtocolVersion { - (**self).version() - } - - fn try_read_number( - &mut self, - ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_ { - (**self).try_read_number() - } - - fn try_read_bytes_limited( - &mut self, - limit: RangeInclusive<usize>, - ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { - (**self).try_read_bytes_limited(limit) - } - - fn try_read_bytes( - &mut self, - ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { - (**self).try_read_bytes() - } - - fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ { - (**self).read_number() - } - - fn read_bytes_limited( - &mut self, - limit: RangeInclusive<usize>, - ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { - (**self).read_bytes_limited(limit) - } - - fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { - (**self).read_bytes() - } - - fn try_read_value<V: NixDeserialize>( - &mut self, - ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ { - (**self).try_read_value() - } - - fn read_value<V: NixDeserialize>( - &mut self, - ) -> impl Future<Output = Result<V, Self::Error>> + Send + '_ { - (**self).read_value() - } -} - -/// A data structure that can be deserialized from the Nix daemon -/// worker protocol. -pub trait NixDeserialize: Sized { - /// Read a value from the reader. - /// This returns an Option to support gracefull shutdown. - fn try_deserialize<R>( - reader: &mut R, - ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ - where - R: ?Sized + NixRead + Send; - - fn deserialize<R>(reader: &mut R) -> impl Future<Output = Result<Self, R::Error>> + Send + '_ - where - R: ?Sized + NixRead + Send, - { - async move { - match Self::try_deserialize(reader).await? { - Some(v) => Ok(v), - None => Err(R::Error::missing_data("unexpected end-of-file")), - } - } - } -} diff --git a/tvix/nix-compat/src/nix_daemon/de/reader.rs b/tvix/nix-compat/src/nix_daemon/de/reader.rs deleted file mode 100644 index 87c623b2220c..000000000000 --- a/tvix/nix-compat/src/nix_daemon/de/reader.rs +++ /dev/null @@ -1,527 +0,0 @@ -use std::future::poll_fn; -use std::io::{self, Cursor}; -use std::ops::RangeInclusive; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use pin_project_lite::pin_project; -use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf}; - -use crate::nix_daemon::ProtocolVersion; -use crate::wire::EMPTY_BYTES; - -use super::{Error, NixRead}; - -pub struct NixReaderBuilder { - buf: Option<BytesMut>, - reserved_buf_size: usize, - max_buf_size: usize, - version: ProtocolVersion, -} - -impl Default for NixReaderBuilder { - fn default() -> Self { - Self { - buf: Default::default(), - reserved_buf_size: 8192, - max_buf_size: 8192, - version: Default::default(), - } - } -} - -impl NixReaderBuilder { - pub fn set_buffer(mut self, buf: BytesMut) -> Self { - self.buf = Some(buf); - self - } - - pub fn set_reserved_buf_size(mut self, size: usize) -> Self { - self.reserved_buf_size = size; - self - } - - pub fn set_max_buf_size(mut self, size: usize) -> Self { - self.max_buf_size = size; - self - } - - pub fn set_version(mut self, version: ProtocolVersion) -> Self { - self.version = version; - self - } - - pub fn build<R>(self, reader: R) -> NixReader<R> { - let buf = self.buf.unwrap_or_else(|| BytesMut::with_capacity(0)); - NixReader { - buf, - inner: reader, - reserved_buf_size: self.reserved_buf_size, - max_buf_size: self.max_buf_size, - version: self.version, - } - } -} - -pin_project! { - pub struct NixReader<R> { - #[pin] - inner: R, - buf: BytesMut, - reserved_buf_size: usize, - max_buf_size: usize, - version: ProtocolVersion, - } -} - -impl NixReader<Cursor<Vec<u8>>> { - pub fn builder() -> NixReaderBuilder { - NixReaderBuilder::default() - } -} - -impl<R> NixReader<R> -where - R: AsyncReadExt, -{ - pub fn new(reader: R) -> NixReader<R> { - NixReader::builder().build(reader) - } - - pub fn buffer(&self) -> &[u8] { - &self.buf[..] - } - - #[cfg(test)] - pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut { - &mut self.buf - } - - /// Remaining capacity in internal buffer - pub fn remaining_mut(&self) -> usize { - self.buf.capacity() - self.buf.len() - } - - fn poll_force_fill_buf( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<io::Result<usize>> { - // Ensure that buffer has space for at least reserved_buf_size bytes - if self.remaining_mut() < self.reserved_buf_size { - let me = self.as_mut().project(); - me.buf.reserve(*me.reserved_buf_size); - } - let me = self.project(); - let n = { - let dst = me.buf.spare_capacity_mut(); - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(me.inner.poll_read(cx, &mut buf)?); - - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; - - // SAFETY: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - me.buf.advance_mut(n); - } - Poll::Ready(Ok(n)) - } -} - -impl<R> NixReader<R> -where - R: AsyncReadExt + Unpin, -{ - async fn force_fill(&mut self) -> io::Result<usize> { - let mut p = Pin::new(self); - let read = poll_fn(|cx| p.as_mut().poll_force_fill_buf(cx)).await?; - Ok(read) - } -} - -impl<R> NixRead for NixReader<R> -where - R: AsyncReadExt + Send + Unpin, -{ - type Error = io::Error; - - fn version(&self) -> ProtocolVersion { - self.version - } - - async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { - let mut buf = [0u8; 8]; - let read = self.read_buf(&mut &mut buf[..]).await?; - if read == 0 { - return Ok(None); - } - if read < 8 { - self.read_exact(&mut buf[read..]).await?; - } - let num = Buf::get_u64_le(&mut &buf[..]); - Ok(Some(num)) - } - - async fn try_read_bytes_limited( - &mut self, - limit: RangeInclusive<usize>, - ) -> Result<Option<Bytes>, Self::Error> { - assert!( - *limit.end() <= self.max_buf_size, - "The limit must be smaller than {}", - self.max_buf_size - ); - match self.try_read_number().await? { - Some(raw_len) => { - // Check that length is in range and convert to usize - let len = raw_len - .try_into() - .ok() - .filter(|v| limit.contains(v)) - .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?; - - // Calculate 64bit aligned length and convert to usize - let aligned: usize = raw_len - .checked_add(7) - .map(|v| v & !7) - .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))? - .try_into() - .map_err(Self::Error::invalid_data)?; - - // Ensure that there is enough space in buffer for contents - if self.buf.len() + self.remaining_mut() < aligned { - self.buf.reserve(aligned - self.buf.len()); - } - while self.buf.len() < aligned { - if self.force_fill().await? == 0 { - return Err(Self::Error::missing_data( - "unexpected end-of-file reading bytes", - )); - } - } - let mut contents = self.buf.split_to(aligned); - - let padding = aligned - len; - // Ensure padding is all zeros - if contents[len..] != EMPTY_BYTES[..padding] { - return Err(Self::Error::invalid_data("non-zero padding")); - } - - contents.truncate(len); - Ok(Some(contents.freeze())) - } - None => Ok(None), - } - } - - fn try_read_bytes( - &mut self, - ) -> impl std::future::Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { - self.try_read_bytes_limited(0..=self.max_buf_size) - } - - fn read_bytes( - &mut self, - ) -> impl std::future::Future<Output = Result<Bytes, Self::Error>> + Send + '_ { - self.read_bytes_limited(0..=self.max_buf_size) - } -} - -impl<R: AsyncRead> AsyncRead for NixReader<R> { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll<io::Result<()>> { - let rem = ready!(self.as_mut().poll_fill_buf(cx))?; - let amt = std::cmp::min(rem.len(), buf.remaining()); - buf.put_slice(&rem[0..amt]); - self.consume(amt); - Poll::Ready(Ok(())) - } -} - -impl<R: AsyncRead> AsyncBufRead for NixReader<R> { - fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { - if self.as_ref().project_ref().buf.is_empty() { - ready!(self.as_mut().poll_force_fill_buf(cx))?; - } - let me = self.project(); - Poll::Ready(Ok(&me.buf[..])) - } - - fn consume(self: Pin<&mut Self>, amt: usize) { - let me = self.project(); - me.buf.advance(amt) - } -} - -#[cfg(test)] -mod test { - use std::time::Duration; - - use hex_literal::hex; - use rstest::rstest; - use tokio_test::io::Builder; - - use super::*; - use crate::nix_daemon::de::NixRead; - - #[tokio::test] - async fn test_read_u64() { - let mock = Builder::new().read(&hex!("0100 0000 0000 0000")).build(); - let mut reader = NixReader::new(mock); - - assert_eq!(1, reader.read_number().await.unwrap()); - assert_eq!(hex!(""), reader.buffer()); - - let mut buf = Vec::new(); - reader.read_to_end(&mut buf).await.unwrap(); - assert_eq!(hex!(""), &buf[..]); - } - - #[tokio::test] - async fn test_read_u64_rest() { - let mock = Builder::new() - .read(&hex!("0100 0000 0000 0000 0123 4567 89AB CDEF")) - .build(); - let mut reader = NixReader::new(mock); - - assert_eq!(1, reader.read_number().await.unwrap()); - assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); - - let mut buf = Vec::new(); - reader.read_to_end(&mut buf).await.unwrap(); - assert_eq!(hex!("0123 4567 89AB CDEF"), &buf[..]); - } - - #[tokio::test] - async fn test_read_u64_partial() { - let mock = Builder::new() - .read(&hex!("0100 0000")) - .wait(Duration::ZERO) - .read(&hex!("0000 0000 0123 4567 89AB CDEF")) - .wait(Duration::ZERO) - .read(&hex!("0100 0000")) - .build(); - let mut reader = NixReader::new(mock); - - assert_eq!(1, reader.read_number().await.unwrap()); - assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); - - let mut buf = Vec::new(); - reader.read_to_end(&mut buf).await.unwrap(); - assert_eq!(hex!("0123 4567 89AB CDEF 0100 0000"), &buf[..]); - } - - #[tokio::test] - async fn test_read_u64_eof() { - let mock = Builder::new().build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::UnexpectedEof, - reader.read_number().await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_try_read_u64_none() { - let mock = Builder::new().build(); - let mut reader = NixReader::new(mock); - - assert_eq!(None, reader.try_read_number().await.unwrap()); - } - - #[tokio::test] - async fn test_try_read_u64_eof() { - let mock = Builder::new().read(&hex!("0100 0000 0000")).build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::UnexpectedEof, - reader.try_read_number().await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_try_read_u64_eof2() { - let mock = Builder::new() - .read(&hex!("0100")) - .wait(Duration::ZERO) - .read(&hex!("0000 0000")) - .build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::UnexpectedEof, - reader.try_read_number().await.unwrap_err().kind() - ); - } - - #[rstest] - #[case::empty(b"", &hex!("0000 0000 0000 0000"))] - #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] - #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] - #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] - #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] - #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] - #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] - #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] - #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] - #[tokio::test] - async fn test_read_bytes(#[case] expected: &[u8], #[case] data: &[u8]) { - let mock = Builder::new().read(data).build(); - let mut reader = NixReader::new(mock); - let actual = reader.read_bytes().await.unwrap(); - assert_eq!(&actual[..], expected); - } - - #[tokio::test] - async fn test_read_bytes_empty() { - let mock = Builder::new().build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::UnexpectedEof, - reader.read_bytes().await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_try_read_bytes_none() { - let mock = Builder::new().build(); - let mut reader = NixReader::new(mock); - - assert_eq!(None, reader.try_read_bytes().await.unwrap()); - } - - #[tokio::test] - async fn test_try_read_bytes_missing_data() { - let mock = Builder::new() - .read(&hex!("0500")) - .wait(Duration::ZERO) - .read(&hex!("0000 0000")) - .build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::UnexpectedEof, - reader.try_read_bytes().await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_try_read_bytes_missing_padding() { - let mock = Builder::new() - .read(&hex!("0200 0000 0000 0000")) - .wait(Duration::ZERO) - .read(&hex!("1234")) - .build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::UnexpectedEof, - reader.try_read_bytes().await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_read_bytes_bad_padding() { - let mock = Builder::new() - .read(&hex!("0200 0000 0000 0000")) - .wait(Duration::ZERO) - .read(&hex!("1234 0100 0000 0000")) - .build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::InvalidData, - reader.read_bytes().await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_read_bytes_limited_out_of_range() { - let mock = Builder::new().read(&hex!("FFFF 0000 0000 0000")).build(); - let mut reader = NixReader::new(mock); - - assert_eq!( - io::ErrorKind::InvalidData, - reader.read_bytes_limited(0..=50).await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_read_bytes_length_overflow() { - let mock = Builder::new().read(&hex!("F9FF FFFF FFFF FFFF")).build(); - let mut reader = NixReader::builder() - .set_max_buf_size(usize::MAX) - .build(mock); - - assert_eq!( - io::ErrorKind::InvalidData, - reader - .read_bytes_limited(0..=usize::MAX) - .await - .unwrap_err() - .kind() - ); - } - - // FUTUREWORK: Test this on supported hardware - #[tokio::test] - #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] - async fn test_bytes_length_conversion_overflow() { - let len = (usize::MAX as u64) + 1; - let mock = Builder::new().read(&len.to_le_bytes()).build(); - let mut reader = NixReader::new(mock); - assert_eq!( - std::io::ErrorKind::InvalidData, - reader.read_value::<usize>().await.unwrap_err().kind() - ); - } - - // FUTUREWORK: Test this on supported hardware - #[tokio::test] - #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] - async fn test_bytes_aligned_length_conversion_overflow() { - let len = (usize::MAX - 6) as u64; - let mock = Builder::new().read(&len.to_le_bytes()).build(); - let mut reader = NixReader::new(mock); - assert_eq!( - std::io::ErrorKind::InvalidData, - reader.read_value::<usize>().await.unwrap_err().kind() - ); - } - - #[tokio::test] - async fn test_buffer_resize() { - let mock = Builder::new() - .read(&hex!("0100")) - .read(&hex!("0000 0000 0000")) - .build(); - let mut reader = NixReader::builder().set_reserved_buf_size(8).build(mock); - // buffer has no capacity initially - assert_eq!(0, reader.buffer_mut().capacity()); - - assert_eq!(2, reader.force_fill().await.unwrap()); - - // After first read buffer should have capacity we chose - assert_eq!(8, reader.buffer_mut().capacity()); - - // Because there was only 6 bytes remaining in buffer, - // which is enough to read the last 6 bytes, but we require - // capacity for 8 bytes, it doubled the capacity - assert_eq!(6, reader.force_fill().await.unwrap()); - assert_eq!(16, reader.buffer_mut().capacity()); - - assert_eq!(1, reader.read_number().await.unwrap()); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs index a943b279f891..633fdbebd47c 100644 --- a/tvix/nix-compat/src/nix_daemon/mod.rs +++ b/tvix/nix-compat/src/nix_daemon/mod.rs @@ -1,7 +1 @@ pub mod worker_protocol; - -mod protocol_version; -pub use protocol_version::ProtocolVersion; - -pub mod de; -pub mod ser; diff --git a/tvix/nix-compat/src/nix_daemon/protocol_version.rs b/tvix/nix-compat/src/nix_daemon/protocol_version.rs deleted file mode 100644 index 19da28d484dd..000000000000 --- a/tvix/nix-compat/src/nix_daemon/protocol_version.rs +++ /dev/null @@ -1,139 +0,0 @@ -/// The latest version that is currently supported by nix-compat. -static DEFAULT_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37); - -/// Protocol versions are represented as a u16. -/// The upper 8 bits are the major version, the lower bits the minor. -/// This is not aware of any endianness, use [crate::wire::read_u64] to get an -/// u64 first, and the try_from() impl from here if you're receiving over the -/// Nix Worker protocol. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct ProtocolVersion(u16); - -impl ProtocolVersion { - pub const fn from_parts(major: u8, minor: u8) -> Self { - Self(((major as u16) << 8) | minor as u16) - } - - pub fn major(&self) -> u8 { - ((self.0 & 0xff00) >> 8) as u8 - } - - pub fn minor(&self) -> u8 { - (self.0 & 0x00ff) as u8 - } -} - -impl Default for ProtocolVersion { - fn default() -> Self { - DEFAULT_PROTOCOL_VERSION - } -} - -impl PartialOrd for ProtocolVersion { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - Some(self.cmp(other)) - } -} - -impl Ord for ProtocolVersion { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match self.major().cmp(&other.major()) { - std::cmp::Ordering::Less => std::cmp::Ordering::Less, - std::cmp::Ordering::Greater => std::cmp::Ordering::Greater, - std::cmp::Ordering::Equal => { - // same major, compare minor - self.minor().cmp(&other.minor()) - } - } - } -} - -impl From<u16> for ProtocolVersion { - fn from(value: u16) -> Self { - Self::from_parts(((value & 0xff00) >> 8) as u8, (value & 0x00ff) as u8) - } -} - -#[cfg(any(test, feature = "test"))] -impl From<(u8, u8)> for ProtocolVersion { - fn from((major, minor): (u8, u8)) -> Self { - Self::from_parts(major, minor) - } -} - -impl TryFrom<u64> for ProtocolVersion { - type Error = &'static str; - - fn try_from(value: u64) -> Result<Self, Self::Error> { - if value & !0xffff != 0 { - return Err("only two least significant bits might be populated"); - } - - Ok((value as u16).into()) - } -} - -impl From<ProtocolVersion> for u16 { - fn from(value: ProtocolVersion) -> Self { - value.0 - } -} - -impl From<ProtocolVersion> for u64 { - fn from(value: ProtocolVersion) -> Self { - value.0 as u64 - } -} - -impl std::fmt::Display for ProtocolVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}", self.major(), self.minor()) - } -} - -#[cfg(test)] -mod tests { - use super::ProtocolVersion; - - #[test] - fn from_parts() { - let version = ProtocolVersion::from_parts(1, 37); - assert_eq!(version.major(), 1, "correct major"); - assert_eq!(version.minor(), 37, "correct minor"); - assert_eq!("1.37", &version.to_string(), "to_string"); - - assert_eq!(0x0125, Into::<u16>::into(version)); - assert_eq!(0x0125, Into::<u64>::into(version)); - } - - #[test] - fn from_u16() { - let version = ProtocolVersion::from(0x0125_u16); - assert_eq!("1.37", &version.to_string()); - } - - #[test] - fn from_u64() { - let version = ProtocolVersion::try_from(0x0125_u64).expect("must succeed"); - assert_eq!("1.37", &version.to_string()); - } - - /// This contains data in higher bits, which should fail. - #[test] - fn from_u64_fail() { - ProtocolVersion::try_from(0xaa0125_u64).expect_err("must fail"); - } - - #[test] - fn ord() { - let v0_37 = ProtocolVersion::from_parts(0, 37); - let v1_37 = ProtocolVersion::from_parts(1, 37); - let v1_40 = ProtocolVersion::from_parts(1, 40); - - assert!(v0_37 < v1_37); - assert!(v1_37 > v0_37); - assert!(v1_37 < v1_40); - assert!(v1_40 > v1_37); - assert!(v1_40 <= v1_40); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/ser/bytes.rs b/tvix/nix-compat/src/nix_daemon/ser/bytes.rs deleted file mode 100644 index 19494934ff32..000000000000 --- a/tvix/nix-compat/src/nix_daemon/ser/bytes.rs +++ /dev/null @@ -1,89 +0,0 @@ -use bytes::Bytes; - -use super::{NixSerialize, NixWrite}; - -impl NixSerialize for Bytes { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - writer.write_slice(self).await - } -} - -impl<'a> NixSerialize for &'a [u8] { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - writer.write_slice(self).await - } -} - -impl NixSerialize for String { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - writer.write_slice(self.as_bytes()).await - } -} - -impl NixSerialize for str { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - writer.write_slice(self.as_bytes()).await - } -} - -#[cfg(test)] -mod test { - use hex_literal::hex; - use rstest::rstest; - use tokio::io::AsyncWriteExt as _; - use tokio_test::io::Builder; - - use crate::nix_daemon::ser::{NixWrite, NixWriter}; - - #[rstest] - #[case::empty("", &hex!("0000 0000 0000 0000"))] - #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] - #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] - #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] - #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] - #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] - #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] - #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] - #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] - #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] - #[tokio::test] - async fn test_write_str(#[case] value: &str, #[case] data: &[u8]) { - let mock = Builder::new().write(data).build(); - let mut writer = NixWriter::new(mock); - writer.write_value(value).await.unwrap(); - writer.flush().await.unwrap(); - } - - #[rstest] - #[case::empty("", &hex!("0000 0000 0000 0000"))] - #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] - #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] - #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] - #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] - #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] - #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] - #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] - #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] - #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] - #[tokio::test] - async fn test_write_string(#[case] value: &str, #[case] data: &[u8]) { - let mock = Builder::new().write(data).build(); - let mut writer = NixWriter::new(mock); - writer.write_value(&value.to_string()).await.unwrap(); - writer.flush().await.unwrap(); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/ser/collections.rs b/tvix/nix-compat/src/nix_daemon/ser/collections.rs deleted file mode 100644 index 70c32e1c79ac..000000000000 --- a/tvix/nix-compat/src/nix_daemon/ser/collections.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::collections::BTreeMap; -use std::future::Future; - -use super::{NixSerialize, NixWrite}; - -impl<T> NixSerialize for Vec<T> -where - T: NixSerialize + Send + Sync, -{ - #[allow(clippy::manual_async_fn)] - fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send - where - W: NixWrite, - { - async move { - writer.write_value(&self.len()).await?; - for value in self.iter() { - writer.write_value(value).await?; - } - Ok(()) - } - } -} - -impl<K, V> NixSerialize for BTreeMap<K, V> -where - K: NixSerialize + Ord + Send + Sync, - V: NixSerialize + Send + Sync, -{ - #[allow(clippy::manual_async_fn)] - fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send - where - W: NixWrite, - { - async move { - writer.write_value(&self.len()).await?; - for (key, value) in self.iter() { - writer.write_value(key).await?; - writer.write_value(value).await?; - } - Ok(()) - } - } -} - -#[cfg(test)] -mod test { - use std::collections::BTreeMap; - use std::fmt; - - use hex_literal::hex; - use rstest::rstest; - use tokio::io::AsyncWriteExt as _; - use tokio_test::io::Builder; - - use crate::nix_daemon::ser::{NixSerialize, NixWrite, NixWriter}; - - #[rstest] - #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] - #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] - #[tokio::test] - async fn test_write_small_vec(#[case] value: Vec<usize>, #[case] data: &[u8]) { - let mock = Builder::new().write(data).build(); - let mut writer = NixWriter::new(mock); - writer.write_value(&value).await.unwrap(); - writer.flush().await.unwrap(); - } - - fn empty_map() -> BTreeMap<usize, u64> { - BTreeMap::new() - } - macro_rules! map { - ($($key:expr => $value:expr),*) => {{ - let mut ret = BTreeMap::new(); - $(ret.insert($key, $value);)* - ret - }}; - } - - #[rstest] - #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] - #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] - #[tokio::test] - async fn test_write_small_btree_map<E>(#[case] value: E, #[case] data: &[u8]) - where - E: NixSerialize + Send + PartialEq + fmt::Debug, - { - let mock = Builder::new().write(data).build(); - let mut writer = NixWriter::new(mock); - writer.write_value(&value).await.unwrap(); - writer.flush().await.unwrap(); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/ser/display.rs b/tvix/nix-compat/src/nix_daemon/ser/display.rs deleted file mode 100644 index a3438d50d8ff..000000000000 --- a/tvix/nix-compat/src/nix_daemon/ser/display.rs +++ /dev/null @@ -1,8 +0,0 @@ -use nix_compat_derive::nix_serialize_remote; - -use crate::nixhash; - -nix_serialize_remote!( - #[nix(display)] - nixhash::HashAlgo -); diff --git a/tvix/nix-compat/src/nix_daemon/ser/int.rs b/tvix/nix-compat/src/nix_daemon/ser/int.rs deleted file mode 100644 index 1be06442e322..000000000000 --- a/tvix/nix-compat/src/nix_daemon/ser/int.rs +++ /dev/null @@ -1,108 +0,0 @@ -#[cfg(feature = "nix-compat-derive")] -use nix_compat_derive::nix_serialize_remote; - -use super::{Error, NixSerialize, NixWrite}; - -impl NixSerialize for u64 { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - writer.write_number(*self).await - } -} - -impl NixSerialize for usize { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - let v = (*self).try_into().map_err(W::Error::unsupported_data)?; - writer.write_number(v).await - } -} - -#[cfg(feature = "nix-compat-derive")] -nix_serialize_remote!( - #[nix(into = "u64")] - u8 -); -#[cfg(feature = "nix-compat-derive")] -nix_serialize_remote!( - #[nix(into = "u64")] - u16 -); -#[cfg(feature = "nix-compat-derive")] -nix_serialize_remote!( - #[nix(into = "u64")] - u32 -); - -impl NixSerialize for bool { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - if *self { - writer.write_number(1).await - } else { - writer.write_number(0).await - } - } -} - -impl NixSerialize for i64 { - async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> - where - W: NixWrite, - { - writer.write_number(*self as u64).await - } -} - -#[cfg(test)] -mod test { - use hex_literal::hex; - use rstest::rstest; - use tokio::io::AsyncWriteExt as _; - use tokio_test::io::Builder; - - use crate::nix_daemon::ser::{NixWrite, NixWriter}; - - #[rstest] - #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] - #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] - #[tokio::test] - async fn test_write_bool(#[case] value: bool, #[case] expected: &[u8]) { - let mock = Builder::new().write(expected).build(); - let mut writer = NixWriter::new(mock); - writer.write_value(&value).await.unwrap(); - writer.flush().await.unwrap(); - } - - #[rstest] - #[case::zero(0, &hex!("0000 0000 0000 0000"))] - #[case::one(1, &hex!("0100 0000 0000 0000"))] - #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] - #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] - #[tokio::test] - async fn test_write_u64(#[case] value: u64, #[case] expected: &[u8]) { - let mock = Builder::new().write(expected).build(); - let mut writer = NixWriter::new(mock); - writer.write_value(&value).await.unwrap(); - writer.flush().await.unwrap(); - } - - #[rstest] - #[case::zero(0, &hex!("0000 0000 0000 0000"))] - #[case::one(1, &hex!("0100 0000 0000 0000"))] - #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] - #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] - #[tokio::test] - async fn test_write_usize(#[case] value: usize, #[case] expected: &[u8]) { - let mock = Builder::new().write(expected).build(); - let mut writer = NixWriter::new(mock); - writer.write_value(&value).await.unwrap(); - writer.flush().await.unwrap(); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/ser/mock.rs b/tvix/nix-compat/src/nix_daemon/ser/mock.rs deleted file mode 100644 index 1319a8da3228..000000000000 --- a/tvix/nix-compat/src/nix_daemon/ser/mock.rs +++ /dev/null @@ -1,672 +0,0 @@ -use std::collections::VecDeque; -use std::fmt; -use std::io; -use std::thread; - -#[cfg(test)] -use ::proptest::prelude::TestCaseError; -use thiserror::Error; - -use crate::nix_daemon::ProtocolVersion; - -use super::NixWrite; - -#[derive(Debug, Error, PartialEq, Eq, Clone)] -pub enum Error { - #[error("custom error '{0}'")] - Custom(String), - #[error("unsupported data error '{0}'")] - UnsupportedData(String), - #[error("Invalid enum: {0}")] - InvalidEnum(String), - #[error("IO error {0} '{1}'")] - IO(io::ErrorKind, String), - #[error("wrong write: expected {0} got {1}")] - WrongWrite(OperationType, OperationType), - #[error("unexpected write: got an extra {0}")] - ExtraWrite(OperationType), - #[error("got an unexpected number {0} in write_number")] - UnexpectedNumber(u64), - #[error("got an unexpected slice '{0:?}' in write_slice")] - UnexpectedSlice(Vec<u8>), - #[error("got an unexpected display '{0:?}' in write_slice")] - UnexpectedDisplay(String), -} - -impl Error { - pub fn unexpected_write_number(expected: OperationType) -> Error { - Error::WrongWrite(expected, OperationType::WriteNumber) - } - - pub fn extra_write_number() -> Error { - Error::ExtraWrite(OperationType::WriteNumber) - } - - pub fn unexpected_write_slice(expected: OperationType) -> Error { - Error::WrongWrite(expected, OperationType::WriteSlice) - } - - pub fn unexpected_write_display(expected: OperationType) -> Error { - Error::WrongWrite(expected, OperationType::WriteDisplay) - } -} - -impl super::Error for Error { - fn custom<T: fmt::Display>(msg: T) -> Self { - Self::Custom(msg.to_string()) - } - - fn io_error(err: std::io::Error) -> Self { - Self::IO(err.kind(), err.to_string()) - } - - fn unsupported_data<T: fmt::Display>(msg: T) -> Self { - Self::UnsupportedData(msg.to_string()) - } - - fn invalid_enum<T: fmt::Display>(msg: T) -> Self { - Self::InvalidEnum(msg.to_string()) - } -} - -#[allow(clippy::enum_variant_names)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum OperationType { - WriteNumber, - WriteSlice, - WriteDisplay, -} - -impl fmt::Display for OperationType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::WriteNumber => write!(f, "write_number"), - Self::WriteSlice => write!(f, "write_slice"), - Self::WriteDisplay => write!(f, "write_display"), - } - } -} - -#[allow(clippy::enum_variant_names)] -#[derive(Debug, Clone, PartialEq, Eq)] -enum Operation { - WriteNumber(u64, Result<(), Error>), - WriteSlice(Vec<u8>, Result<(), Error>), - WriteDisplay(String, Result<(), Error>), -} - -impl From<Operation> for OperationType { - fn from(value: Operation) -> Self { - match value { - Operation::WriteNumber(_, _) => OperationType::WriteNumber, - Operation::WriteSlice(_, _) => OperationType::WriteSlice, - Operation::WriteDisplay(_, _) => OperationType::WriteDisplay, - } - } -} - -pub struct Builder { - version: ProtocolVersion, - ops: VecDeque<Operation>, -} - -impl Builder { - pub fn new() -> Builder { - Builder { - version: Default::default(), - ops: VecDeque::new(), - } - } - - pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { - self.version = version.into(); - self - } - - pub fn write_number(&mut self, value: u64) -> &mut Self { - self.ops.push_back(Operation::WriteNumber(value, Ok(()))); - self - } - - pub fn write_number_error(&mut self, value: u64, err: Error) -> &mut Self { - self.ops.push_back(Operation::WriteNumber(value, Err(err))); - self - } - - pub fn write_slice(&mut self, value: &[u8]) -> &mut Self { - self.ops - .push_back(Operation::WriteSlice(value.to_vec(), Ok(()))); - self - } - - pub fn write_slice_error(&mut self, value: &[u8], err: Error) -> &mut Self { - self.ops - .push_back(Operation::WriteSlice(value.to_vec(), Err(err))); - self - } - - pub fn write_display<D>(&mut self, value: D) -> &mut Self - where - D: fmt::Display, - { - let msg = value.to_string(); - self.ops.push_back(Operation::WriteDisplay(msg, Ok(()))); - self - } - - pub fn write_display_error<D>(&mut self, value: D, err: Error) -> &mut Self - where - D: fmt::Display, - { - let msg = value.to_string(); - self.ops.push_back(Operation::WriteDisplay(msg, Err(err))); - self - } - - #[cfg(test)] - fn write_operation_type(&mut self, op: OperationType) -> &mut Self { - match op { - OperationType::WriteNumber => self.write_number(10), - OperationType::WriteSlice => self.write_slice(b"testing"), - OperationType::WriteDisplay => self.write_display("testing"), - } - } - - #[cfg(test)] - fn write_operation(&mut self, op: &Operation) -> &mut Self { - match op { - Operation::WriteNumber(value, Ok(_)) => self.write_number(*value), - Operation::WriteNumber(value, Err(Error::UnexpectedNumber(_))) => { - self.write_number(*value) - } - Operation::WriteNumber(_, Err(Error::ExtraWrite(OperationType::WriteNumber))) => self, - Operation::WriteNumber(_, Err(Error::WrongWrite(op, OperationType::WriteNumber))) => { - self.write_operation_type(*op) - } - Operation::WriteNumber(value, Err(Error::Custom(msg))) => { - self.write_number_error(*value, Error::Custom(msg.clone())) - } - Operation::WriteNumber(value, Err(Error::IO(kind, msg))) => { - self.write_number_error(*value, Error::IO(*kind, msg.clone())) - } - Operation::WriteSlice(value, Ok(_)) => self.write_slice(value), - Operation::WriteSlice(value, Err(Error::UnexpectedSlice(_))) => self.write_slice(value), - Operation::WriteSlice(_, Err(Error::ExtraWrite(OperationType::WriteSlice))) => self, - Operation::WriteSlice(_, Err(Error::WrongWrite(op, OperationType::WriteSlice))) => { - self.write_operation_type(*op) - } - Operation::WriteSlice(value, Err(Error::Custom(msg))) => { - self.write_slice_error(value, Error::Custom(msg.clone())) - } - Operation::WriteSlice(value, Err(Error::IO(kind, msg))) => { - self.write_slice_error(value, Error::IO(*kind, msg.clone())) - } - Operation::WriteDisplay(value, Ok(_)) => self.write_display(value), - Operation::WriteDisplay(value, Err(Error::Custom(msg))) => { - self.write_display_error(value, Error::Custom(msg.clone())) - } - Operation::WriteDisplay(value, Err(Error::IO(kind, msg))) => { - self.write_display_error(value, Error::IO(*kind, msg.clone())) - } - Operation::WriteDisplay(value, Err(Error::UnexpectedDisplay(_))) => { - self.write_display(value) - } - Operation::WriteDisplay(_, Err(Error::ExtraWrite(OperationType::WriteDisplay))) => self, - Operation::WriteDisplay(_, Err(Error::WrongWrite(op, OperationType::WriteDisplay))) => { - self.write_operation_type(*op) - } - s => panic!("Invalid operation {:?}", s), - } - } - - pub fn build(&mut self) -> Mock { - Mock { - version: self.version, - ops: self.ops.clone(), - } - } -} - -impl Default for Builder { - fn default() -> Self { - Self::new() - } -} - -pub struct Mock { - version: ProtocolVersion, - ops: VecDeque<Operation>, -} - -impl Mock { - #[cfg(test)] - #[allow(dead_code)] - async fn assert_operation(&mut self, op: Operation) { - match op { - Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { - assert_eq!( - self.write_number(value).await, - Err(Error::UnexpectedNumber(value)) - ); - } - Operation::WriteNumber(value, res) => { - assert_eq!(self.write_number(value).await, res); - } - Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { - assert_eq!(self.write_slice(value).await, res.clone()); - } - Operation::WriteSlice(value, res) => { - assert_eq!(self.write_slice(&value).await, res); - } - Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { - assert_eq!(self.write_display(value).await, res.clone()); - } - Operation::WriteDisplay(value, res) => { - assert_eq!(self.write_display(value).await, res); - } - } - } - - #[cfg(test)] - async fn prop_assert_operation(&mut self, op: Operation) -> Result<(), TestCaseError> { - use ::proptest::prop_assert_eq; - - match op { - Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { - prop_assert_eq!( - self.write_number(value).await, - Err(Error::UnexpectedNumber(value)) - ); - } - Operation::WriteNumber(value, res) => { - prop_assert_eq!(self.write_number(value).await, res); - } - Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { - prop_assert_eq!(self.write_slice(value).await, res.clone()); - } - Operation::WriteSlice(value, res) => { - prop_assert_eq!(self.write_slice(&value).await, res); - } - Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { - prop_assert_eq!(self.write_display(&value).await, res.clone()); - } - Operation::WriteDisplay(value, res) => { - prop_assert_eq!(self.write_display(&value).await, res); - } - } - Ok(()) - } -} - -impl NixWrite for Mock { - type Error = Error; - - fn version(&self) -> ProtocolVersion { - self.version - } - - async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { - match self.ops.pop_front() { - Some(Operation::WriteNumber(expected, ret)) => { - if value != expected { - return Err(Error::UnexpectedNumber(value)); - } - ret - } - Some(op) => Err(Error::unexpected_write_number(op.into())), - _ => Err(Error::ExtraWrite(OperationType::WriteNumber)), - } - } - - async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { - match self.ops.pop_front() { - Some(Operation::WriteSlice(expected, ret)) => { - if buf != expected { - return Err(Error::UnexpectedSlice(buf.to_vec())); - } - ret - } - Some(op) => Err(Error::unexpected_write_slice(op.into())), - _ => Err(Error::ExtraWrite(OperationType::WriteSlice)), - } - } - - async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> - where - D: fmt::Display + Send, - Self: Sized, - { - let value = msg.to_string(); - match self.ops.pop_front() { - Some(Operation::WriteDisplay(expected, ret)) => { - if value != expected { - return Err(Error::UnexpectedDisplay(value)); - } - ret - } - Some(op) => Err(Error::unexpected_write_display(op.into())), - _ => Err(Error::ExtraWrite(OperationType::WriteDisplay)), - } - } -} - -impl Drop for Mock { - fn drop(&mut self) { - // No need to panic again - if thread::panicking() { - return; - } - if let Some(op) = self.ops.front() { - panic!("reader dropped with {op:?} operation still unread") - } - } -} - -#[cfg(test)] -mod proptest { - use std::io; - - use proptest::{ - prelude::{any, Arbitrary, BoxedStrategy, Just, Strategy}, - prop_oneof, - }; - - use super::{Error, Operation, OperationType}; - - pub fn arb_write_number_operation() -> impl Strategy<Value = Operation> { - ( - any::<u64>(), - prop_oneof![ - Just(Ok(())), - any::<u64>().prop_map(|v| Err(Error::UnexpectedNumber(v))), - Just(Err(Error::WrongWrite( - OperationType::WriteSlice, - OperationType::WriteNumber - ))), - Just(Err(Error::WrongWrite( - OperationType::WriteDisplay, - OperationType::WriteNumber - ))), - any::<String>().prop_map(|s| Err(Error::Custom(s))), - (any::<io::ErrorKind>(), any::<String>()) - .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), - ], - ) - .prop_filter("same number", |(v, res)| match res { - Err(Error::UnexpectedNumber(exp_v)) => v != exp_v, - _ => true, - }) - .prop_map(|(v, res)| Operation::WriteNumber(v, res)) - } - - pub fn arb_write_slice_operation() -> impl Strategy<Value = Operation> { - ( - any::<Vec<u8>>(), - prop_oneof![ - Just(Ok(())), - any::<Vec<u8>>().prop_map(|v| Err(Error::UnexpectedSlice(v))), - Just(Err(Error::WrongWrite( - OperationType::WriteNumber, - OperationType::WriteSlice - ))), - Just(Err(Error::WrongWrite( - OperationType::WriteDisplay, - OperationType::WriteSlice - ))), - any::<String>().prop_map(|s| Err(Error::Custom(s))), - (any::<io::ErrorKind>(), any::<String>()) - .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), - ], - ) - .prop_filter("same slice", |(v, res)| match res { - Err(Error::UnexpectedSlice(exp_v)) => v != exp_v, - _ => true, - }) - .prop_map(|(v, res)| Operation::WriteSlice(v, res)) - } - - #[allow(dead_code)] - pub fn arb_extra_write() -> impl Strategy<Value = Operation> { - prop_oneof![ - any::<u64>().prop_map(|msg| { - Operation::WriteNumber(msg, Err(Error::ExtraWrite(OperationType::WriteNumber))) - }), - any::<Vec<u8>>().prop_map(|msg| { - Operation::WriteSlice(msg, Err(Error::ExtraWrite(OperationType::WriteSlice))) - }), - any::<String>().prop_map(|msg| { - Operation::WriteDisplay(msg, Err(Error::ExtraWrite(OperationType::WriteDisplay))) - }), - ] - } - - pub fn arb_write_display_operation() -> impl Strategy<Value = Operation> { - ( - any::<String>(), - prop_oneof![ - Just(Ok(())), - any::<String>().prop_map(|v| Err(Error::UnexpectedDisplay(v))), - Just(Err(Error::WrongWrite( - OperationType::WriteNumber, - OperationType::WriteDisplay - ))), - Just(Err(Error::WrongWrite( - OperationType::WriteSlice, - OperationType::WriteDisplay - ))), - any::<String>().prop_map(|s| Err(Error::Custom(s))), - (any::<io::ErrorKind>(), any::<String>()) - .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), - ], - ) - .prop_filter("same string", |(v, res)| match res { - Err(Error::UnexpectedDisplay(exp_v)) => v != exp_v, - _ => true, - }) - .prop_map(|(v, res)| Operation::WriteDisplay(v, res)) - } - - pub fn arb_operation() -> impl Strategy<Value = Operation> { - prop_oneof![ - arb_write_number_operation(), - arb_write_slice_operation(), - arb_write_display_operation(), - ] - } - - impl Arbitrary for Operation { - type Parameters = (); - type Strategy = BoxedStrategy<Operation>; - - fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { - arb_operation().boxed() - } - } -} - -#[cfg(test)] -mod test { - use hex_literal::hex; - use proptest::prelude::any; - use proptest::prelude::TestCaseError; - use proptest::proptest; - - use crate::nix_daemon::ser::mock::proptest::arb_extra_write; - use crate::nix_daemon::ser::mock::Operation; - use crate::nix_daemon::ser::mock::OperationType; - use crate::nix_daemon::ser::Error as _; - use crate::nix_daemon::ser::NixWrite; - - use super::{Builder, Error}; - - #[tokio::test] - async fn write_number() { - let mut mock = Builder::new().write_number(10).build(); - mock.write_number(10).await.unwrap(); - } - - #[tokio::test] - async fn write_number_error() { - let mut mock = Builder::new() - .write_number_error(10, Error::custom("bad number")) - .build(); - assert_eq!( - Err(Error::custom("bad number")), - mock.write_number(10).await - ); - } - - #[tokio::test] - async fn write_number_unexpected() { - let mut mock = Builder::new().write_slice(b"").build(); - assert_eq!( - Err(Error::unexpected_write_number(OperationType::WriteSlice)), - mock.write_number(11).await - ); - } - - #[tokio::test] - async fn write_number_unexpected_number() { - let mut mock = Builder::new().write_number(10).build(); - assert_eq!( - Err(Error::UnexpectedNumber(11)), - mock.write_number(11).await - ); - } - - #[tokio::test] - async fn extra_write_number() { - let mut mock = Builder::new().build(); - assert_eq!( - Err(Error::ExtraWrite(OperationType::WriteNumber)), - mock.write_number(11).await - ); - } - - #[tokio::test] - async fn write_slice() { - let mut mock = Builder::new() - .write_slice(&[]) - .write_slice(&hex!("0000 1234 5678 9ABC DEFF")) - .build(); - mock.write_slice(&[]).await.expect("write_slice empty"); - mock.write_slice(&hex!("0000 1234 5678 9ABC DEFF")) - .await - .expect("write_slice"); - } - - #[tokio::test] - async fn write_slice_error() { - let mut mock = Builder::new() - .write_slice_error(&[], Error::custom("bad slice")) - .build(); - assert_eq!(Err(Error::custom("bad slice")), mock.write_slice(&[]).await); - } - - #[tokio::test] - async fn write_slice_unexpected() { - let mut mock = Builder::new().write_number(10).build(); - assert_eq!( - Err(Error::unexpected_write_slice(OperationType::WriteNumber)), - mock.write_slice(b"").await - ); - } - - #[tokio::test] - async fn write_slice_unexpected_slice() { - let mut mock = Builder::new().write_slice(b"").build(); - assert_eq!( - Err(Error::UnexpectedSlice(b"bad slice".to_vec())), - mock.write_slice(b"bad slice").await - ); - } - - #[tokio::test] - async fn extra_write_slice() { - let mut mock = Builder::new().build(); - assert_eq!( - Err(Error::ExtraWrite(OperationType::WriteSlice)), - mock.write_slice(b"extra slice").await - ); - } - - #[tokio::test] - async fn write_display() { - let mut mock = Builder::new().write_display("testing").build(); - mock.write_display("testing").await.unwrap(); - } - - #[tokio::test] - async fn write_display_error() { - let mut mock = Builder::new() - .write_display_error("testing", Error::custom("bad number")) - .build(); - assert_eq!( - Err(Error::custom("bad number")), - mock.write_display("testing").await - ); - } - - #[tokio::test] - async fn write_display_unexpected() { - let mut mock = Builder::new().write_number(10).build(); - assert_eq!( - Err(Error::unexpected_write_display(OperationType::WriteNumber)), - mock.write_display("").await - ); - } - - #[tokio::test] - async fn write_display_unexpected_display() { - let mut mock = Builder::new().write_display("").build(); - assert_eq!( - Err(Error::UnexpectedDisplay("bad display".to_string())), - mock.write_display("bad display").await - ); - } - - #[tokio::test] - async fn extra_write_display() { - let mut mock = Builder::new().build(); - assert_eq!( - Err(Error::ExtraWrite(OperationType::WriteDisplay)), - mock.write_display("extra slice").await - ); - } - - #[test] - #[should_panic] - fn operations_left() { - let _ = Builder::new().write_number(10).build(); - } - - #[test] - fn proptest_mock() { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - proptest!(|( - operations in any::<Vec<Operation>>(), - extra_operations in proptest::collection::vec(arb_extra_write(), 0..3) - )| { - rt.block_on(async { - let mut builder = Builder::new(); - for op in operations.iter() { - builder.write_operation(op); - } - for op in extra_operations.iter() { - builder.write_operation(op); - } - let mut mock = builder.build(); - for op in operations { - mock.prop_assert_operation(op).await?; - } - for op in extra_operations { - mock.prop_assert_operation(op).await?; - } - Ok(()) as Result<(), TestCaseError> - })?; - }); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/ser/mod.rs b/tvix/nix-compat/src/nix_daemon/ser/mod.rs deleted file mode 100644 index 5860226f39eb..000000000000 --- a/tvix/nix-compat/src/nix_daemon/ser/mod.rs +++ /dev/null @@ -1,124 +0,0 @@ -use std::error::Error as StdError; -use std::future::Future; -use std::{fmt, io}; - -use super::ProtocolVersion; - -mod bytes; -mod collections; -#[cfg(feature = "nix-compat-derive")] -mod display; -mod int; -#[cfg(any(test, feature = "test"))] -pub mod mock; -mod writer; - -pub use writer::{NixWriter, NixWriterBuilder}; - -pub trait Error: Sized + StdError { - fn custom<T: fmt::Display>(msg: T) -> Self; - - fn io_error(err: std::io::Error) -> Self { - Self::custom(format_args!("There was an I/O error {}", err)) - } - - fn unsupported_data<T: fmt::Display>(msg: T) -> Self { - Self::custom(msg) - } - - fn invalid_enum<T: fmt::Display>(msg: T) -> Self { - Self::custom(msg) - } -} - -impl Error for io::Error { - fn custom<T: fmt::Display>(msg: T) -> Self { - io::Error::new(io::ErrorKind::Other, msg.to_string()) - } - - fn io_error(err: std::io::Error) -> Self { - err - } - - fn unsupported_data<T: fmt::Display>(msg: T) -> Self { - io::Error::new(io::ErrorKind::InvalidData, msg.to_string()) - } -} - -pub trait NixWrite: Send { - type Error: Error; - - /// Some types are serialized differently depending on the version - /// of the protocol and so this can be used for implementing that. - fn version(&self) -> ProtocolVersion; - - /// Write a single u64 to the protocol. - fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send; - - /// Write a slice of bytes to the protocol. - fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send; - - /// Write a value that implements `std::fmt::Display` to the protocol. - /// The protocol uses many small string formats and instead of allocating - /// a `String` each time we want to write one an implementation of `NixWrite` - /// can instead use `Display` to dump these formats to a reusable buffer. - fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send - where - D: fmt::Display + Send, - Self: Sized, - { - async move { - let s = msg.to_string(); - self.write_slice(s.as_bytes()).await - } - } - - /// Write a value to the protocol. - /// Uses `NixSerialize::serialize` to write the value. - fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send - where - V: NixSerialize + Send + ?Sized, - Self: Sized, - { - value.serialize(self) - } -} - -impl<T: NixWrite> NixWrite for &mut T { - type Error = T::Error; - - fn version(&self) -> ProtocolVersion { - (**self).version() - } - - fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send { - (**self).write_number(value) - } - - fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send { - (**self).write_slice(buf) - } - - fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send - where - D: fmt::Display + Send, - Self: Sized, - { - (**self).write_display(msg) - } - - fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send - where - V: NixSerialize + Send + ?Sized, - Self: Sized, - { - (**self).write_value(value) - } -} - -pub trait NixSerialize { - /// Write a value to the writer. - fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send - where - W: NixWrite; -} diff --git a/tvix/nix-compat/src/nix_daemon/ser/writer.rs b/tvix/nix-compat/src/nix_daemon/ser/writer.rs deleted file mode 100644 index 87e30580af34..000000000000 --- a/tvix/nix-compat/src/nix_daemon/ser/writer.rs +++ /dev/null @@ -1,308 +0,0 @@ -use std::fmt::{self, Write as _}; -use std::future::poll_fn; -use std::io::{self, Cursor}; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; - -use bytes::{Buf, BufMut, BytesMut}; -use pin_project_lite::pin_project; -use tokio::io::{AsyncWrite, AsyncWriteExt}; - -use crate::nix_daemon::ProtocolVersion; -use crate::wire::padding_len; -use crate::wire::EMPTY_BYTES; - -use super::{Error, NixWrite}; - -pub struct NixWriterBuilder { - buf: Option<BytesMut>, - reserved_buf_size: usize, - max_buf_size: usize, - version: ProtocolVersion, -} - -impl Default for NixWriterBuilder { - fn default() -> Self { - Self { - buf: Default::default(), - reserved_buf_size: 8192, - max_buf_size: 8192, - version: Default::default(), - } - } -} - -impl NixWriterBuilder { - pub fn set_buffer(mut self, buf: BytesMut) -> Self { - self.buf = Some(buf); - self - } - - pub fn set_reserved_buf_size(mut self, size: usize) -> Self { - self.reserved_buf_size = size; - self - } - - pub fn set_max_buf_size(mut self, size: usize) -> Self { - self.max_buf_size = size; - self - } - - pub fn set_version(mut self, version: ProtocolVersion) -> Self { - self.version = version; - self - } - - pub fn build<W>(self, writer: W) -> NixWriter<W> { - let buf = self - .buf - .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size)); - NixWriter { - buf, - inner: writer, - reserved_buf_size: self.reserved_buf_size, - max_buf_size: self.max_buf_size, - version: self.version, - } - } -} - -pin_project! { - pub struct NixWriter<W> { - #[pin] - inner: W, - buf: BytesMut, - reserved_buf_size: usize, - max_buf_size: usize, - version: ProtocolVersion, - } -} - -impl NixWriter<Cursor<Vec<u8>>> { - pub fn builder() -> NixWriterBuilder { - NixWriterBuilder::default() - } -} - -impl<W> NixWriter<W> -where - W: AsyncWriteExt, -{ - pub fn new(writer: W) -> NixWriter<W> { - NixWriter::builder().build(writer) - } - - pub fn buffer(&self) -> &[u8] { - &self.buf[..] - } - - pub fn set_version(&mut self, version: ProtocolVersion) { - self.version = version; - } - - /// Remaining capacity in internal buffer - pub fn remaining_mut(&self) -> usize { - self.buf.capacity() - self.buf.len() - } - - fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { - let mut this = self.project(); - while !this.buf.is_empty() { - let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?; - if n == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write the buffer", - ))); - } - this.buf.advance(n); - } - Poll::Ready(Ok(())) - } -} - -impl<W> NixWriter<W> -where - W: AsyncWriteExt + Unpin, -{ - async fn flush_buf(&mut self) -> Result<(), io::Error> { - let mut s = Pin::new(self); - poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await - } -} - -impl<W> AsyncWrite for NixWriter<W> -where - W: AsyncWrite, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll<Result<usize, io::Error>> { - // Flush - if self.remaining_mut() < buf.len() { - ready!(self.as_mut().poll_flush_buf(cx))?; - } - let this = self.project(); - if buf.len() > this.buf.capacity() { - this.inner.poll_write(cx, buf) - } else { - this.buf.put_slice(buf); - Poll::Ready(Ok(buf.len())) - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { - ready!(self.as_mut().poll_flush_buf(cx))?; - self.project().inner.poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Result<(), io::Error>> { - ready!(self.as_mut().poll_flush_buf(cx))?; - self.project().inner.poll_shutdown(cx) - } -} - -impl<W> NixWrite for NixWriter<W> -where - W: AsyncWrite + Send + Unpin, -{ - type Error = io::Error; - - fn version(&self) -> ProtocolVersion { - self.version - } - - async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { - let mut buf = [0u8; 8]; - BufMut::put_u64_le(&mut &mut buf[..], value); - self.write_all(&buf).await - } - - async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { - let padding = padding_len(buf.len() as u64) as usize; - self.write_value(&buf.len()).await?; - self.write_all(buf).await?; - if padding > 0 { - self.write_all(&EMPTY_BYTES[..padding]).await - } else { - Ok(()) - } - } - - async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> - where - D: fmt::Display + Send, - Self: Sized, - { - // Ensure that buffer has space for at least reserved_buf_size bytes - if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() { - self.flush_buf().await?; - } - let offset = self.buf.len(); - self.buf.put_u64_le(0); - if let Err(err) = write!(self.buf, "{}", msg) { - self.buf.truncate(offset); - return Err(Self::Error::unsupported_data(err)); - } - let len = self.buf.len() - offset - 8; - BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64); - let padding = padding_len(len as u64) as usize; - self.write_all(&EMPTY_BYTES[..padding]).await - } -} - -#[cfg(test)] -mod test { - use std::time::Duration; - - use hex_literal::hex; - use rstest::rstest; - use tokio::io::AsyncWriteExt as _; - use tokio_test::io::Builder; - - use crate::nix_daemon::ser::NixWrite; - - use super::NixWriter; - - #[rstest] - #[case(1, &hex!("0100 0000 0000 0000"))] - #[case::evil(666, &hex!("9A02 0000 0000 0000"))] - #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] - #[tokio::test] - async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) { - let mock = Builder::new().write(buf).build(); - let mut writer = NixWriter::new(mock); - - writer.write_number(number).await.unwrap(); - assert_eq!(writer.buffer(), buf); - writer.flush().await.unwrap(); - assert_eq!(writer.buffer(), b""); - } - - #[rstest] - #[case::empty(b"", &hex!("0000 0000 0000 0000"))] - #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] - #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] - #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] - #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] - #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] - #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] - #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] - #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] - #[tokio::test] - async fn test_write_slice( - #[case] value: &[u8], - #[case] buf: &[u8], - #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, - #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize, - ) { - let mut builder = Builder::new(); - for chunk in buf.chunks(chunks_size) { - builder.write(chunk); - builder.wait(Duration::ZERO); - } - let mock = builder.build(); - let mut writer = NixWriter::builder().set_max_buf_size(buf_size).build(mock); - - writer.write_slice(value).await.unwrap(); - writer.flush().await.unwrap(); - assert_eq!(writer.buffer(), b""); - } - - #[rstest] - #[case::empty("", &hex!("0000 0000 0000 0000"))] - #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] - #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] - #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] - #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] - #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] - #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] - #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] - #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] - #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] - #[tokio::test] - async fn test_write_display( - #[case] value: &str, - #[case] buf: &[u8], - #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, - ) { - let mut builder = Builder::new(); - for chunk in buf.chunks(chunks_size) { - builder.write(chunk); - builder.wait(Duration::ZERO); - } - let mock = builder.build(); - let mut writer = NixWriter::builder().build(mock); - - writer.write_display(value).await.unwrap(); - assert_eq!(writer.buffer(), buf); - writer.flush().await.unwrap(); - assert_eq!(writer.buffer(), b""); - } -} diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs index cc99bdb54fab..92259a0633a0 100644 --- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs +++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs @@ -10,7 +10,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::wire; -use super::ProtocolVersion; +use crate::wire::ProtocolVersion; static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc" static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio" |