diff options
author | Brian Olsen <brian@maven-group.org> | 2024-07-20T09·25+0200 |
---|---|---|
committer | clbot <clbot@tvl.fyi> | 2024-08-25T15·05+0000 |
commit | 9af69204787d47cfe551f524d01b1a726971f06e (patch) | |
tree | ab8f2b90b2624c0f73262899019a3f1098fb0bdb /tvix/nix-compat | |
parent | a774cb8c10ea976bcdff2e296e3cefc6adbc21d3 (diff) |
feat(nix-compat): Add NixDeserialize and NixRead traits r/8585
Add a trait for deserializing a type from a daemon worker connection. This adds the NixDeserialize trait, which is similar to the serde Deserialize trait in that individual types are meant to implement it, and it could potentially be derived in the future. The NixDeserialize trait takes something that implements NixRead as input so that you can, among other things, mock the reader. Change-Id: Ibb59e3562dfc822652f7d18039f00a1c0d422997 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11990 Autosubmit: Brian Olsen <me@griff.name> Reviewed-by: flokli <flokli@flokli.de> Tested-by: BuildkiteCI
Diffstat (limited to 'tvix/nix-compat')
-rw-r--r-- | tvix/nix-compat/Cargo.toml | 6 | ||||
-rw-r--r-- | tvix/nix-compat/src/lib.rs | 2 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/bytes.rs | 70 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/collections.rs | 105 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/int.rs | 100 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/mock.rs | 261 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/mod.rs | 225 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/reader.rs | 527 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/mod.rs | 2 | ||||
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/protocol_version.rs | 9 | ||||
-rw-r--r-- | tvix/nix-compat/src/wire/bytes/mod.rs | 2 |
11 files changed, 1306 insertions, 3 deletions
diff --git a/tvix/nix-compat/Cargo.toml b/tvix/nix-compat/Cargo.toml index 22325ad12bfe..9f43bb24efcc 100644 --- a/tvix/nix-compat/Cargo.toml +++ b/tvix/nix-compat/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" # async NAR writer. Also needs the `wire` feature. async = ["tokio"] # code emitting low-level packets used in the daemon protocol. -wire = ["tokio", "pin-project-lite"] +wire = ["tokio", "pin-project-lite", "bytes"] # Enable all features by default. default = ["async", "wire"] @@ -29,6 +29,10 @@ sha2 = "0.10.6" thiserror = "1.0.38" tracing = "0.1.37" +[dependencies.bytes] +optional = true +version = "1.6.1" + [dependencies.tokio] optional = true version = "1.32.0" diff --git a/tvix/nix-compat/src/lib.rs b/tvix/nix-compat/src/lib.rs index 1410a8264240..6eec4b8d03a8 100644 --- a/tvix/nix-compat/src/lib.rs +++ b/tvix/nix-compat/src/lib.rs @@ -13,7 +13,7 @@ pub mod store_path; pub mod wire; #[cfg(feature = "wire")] -mod nix_daemon; +pub mod nix_daemon; #[cfg(feature = "wire")] pub use nix_daemon::worker_protocol; #[cfg(feature = "wire")] diff --git a/tvix/nix-compat/src/nix_daemon/de/bytes.rs b/tvix/nix-compat/src/nix_daemon/de/bytes.rs new file mode 100644 index 000000000000..7daced54eef7 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/de/bytes.rs @@ -0,0 +1,70 @@ +use bytes::Bytes; + +use super::{Error, NixDeserialize, NixRead}; + +impl NixDeserialize for Bytes { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + reader.try_read_bytes().await + } +} + +impl NixDeserialize for String { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + if let Some(buf) = reader.try_read_bytes().await? 
{ + String::from_utf8(buf.to_vec()) + .map_err(R::Error::invalid_data) + .map(Some) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod test { + use std::io; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::nix_daemon::de::{NixRead, NixReader}; + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_read_string(#[case] expected: &str, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: String = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_read_string_invalid() { + let mock = Builder::new() + .read(&hex!("0300 0000 0000 0000 EDA0 8000 0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + assert_eq!( + io::ErrorKind::InvalidData, + reader.read_value::<String>().await.unwrap_err().kind() + ); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/de/collections.rs b/tvix/nix-compat/src/nix_daemon/de/collections.rs new file mode 100644 index 
000000000000..cf79f584506a --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/de/collections.rs @@ -0,0 +1,105 @@ +use std::{collections::BTreeMap, future::Future}; + +use super::{NixDeserialize, NixRead}; + +#[allow(clippy::manual_async_fn)] +impl<T> NixDeserialize for Vec<T> +where + T: NixDeserialize + Send, +{ + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + if let Some(len) = reader.try_read_value::<usize>().await? { + let mut ret = Vec::with_capacity(len); + for _ in 0..len { + ret.push(reader.read_value().await?); + } + Ok(Some(ret)) + } else { + Ok(None) + } + } + } +} + +#[allow(clippy::manual_async_fn)] +impl<K, V> NixDeserialize for BTreeMap<K, V> +where + K: NixDeserialize + Ord + Send, + V: NixDeserialize + Send, +{ + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + if let Some(len) = reader.try_read_value::<usize>().await? 
{ + let mut ret = BTreeMap::new(); + for _ in 0..len { + let key = reader.read_value().await?; + let value = reader.read_value().await?; + ret.insert(key, value); + } + Ok(Some(ret)) + } else { + Ok(None) + } + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::nix_daemon::de::{NixDeserialize, NixRead, NixReader}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_read_small_vec(#[case] expected: Vec<usize>, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: Vec<usize> = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + fn empty_map() -> BTreeMap<usize, u64> { + BTreeMap::new() + } + macro_rules! 
map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_read_small_btree_map<E>(#[case] expected: E, #[case] data: &[u8]) + where + E: NixDeserialize + PartialEq + fmt::Debug, + { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: E = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/de/int.rs b/tvix/nix-compat/src/nix_daemon/de/int.rs new file mode 100644 index 000000000000..eecf641cfe99 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/de/int.rs @@ -0,0 +1,100 @@ +use super::{Error, NixDeserialize, NixRead}; + +impl NixDeserialize for u64 { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + reader.try_read_number().await + } +} + +impl NixDeserialize for usize { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + if let Some(value) = reader.try_read_number().await? 
{ + value.try_into().map_err(R::Error::invalid_data).map(Some) + } else { + Ok(None) + } + } +} + +impl NixDeserialize for bool { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + Ok(reader.try_read_number().await?.map(|v| v != 0)) + } +} +impl NixDeserialize for i64 { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + Ok(reader.try_read_number().await?.map(|v| v as i64)) + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::nix_daemon::de::{NixRead, NixReader}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[case::other_true(true, &hex!("1234 5600 0000 0000"))] + #[case::max_true(true, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_read_bool(#[case] expected: bool, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: bool = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_read_u64(#[case] expected: u64, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: u64 = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] + #[tokio::test] + async fn test_read_usize(#[case] expected: 
usize, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: usize = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + // FUTUREWORK: Test this on supported hardware + #[tokio::test] + #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] + async fn test_read_usize_overflow() { + let mock = Builder::new().read(&u64::MAX.to_le_bytes()).build(); + let mut reader = NixReader::new(mock); + assert_eq!( + std::io::ErrorKind::InvalidData, + reader.read_value::<usize>().await.unwrap_err().kind() + ); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/de/mock.rs b/tvix/nix-compat/src/nix_daemon/de/mock.rs new file mode 100644 index 000000000000..31cc3a4897ba --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/de/mock.rs @@ -0,0 +1,261 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +use bytes::Bytes; +use thiserror::Error; + +use crate::nix_daemon::ProtocolVersion; + +use super::NixRead; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("invalid data '{0}'")] + InvalidData(String), + #[error("missing data '{0}'")] + MissingData(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong read: expected {0} got {1}")] + WrongRead(OperationType, OperationType), +} + +impl Error { + pub fn expected_read_number() -> Error { + Error::WrongRead(OperationType::ReadNumber, OperationType::ReadBytes) + } + + pub fn expected_read_bytes() -> Error { + Error::WrongRead(OperationType::ReadBytes, OperationType::ReadNumber) + } +} + +impl super::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Self::Custom(msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + Self::IO(err.kind(), err.to_string()) + } + + fn invalid_data<T: fmt::Display>(msg: T) -> Self { + Self::InvalidData(msg.to_string()) + } + + fn 
missing_data<T: fmt::Display>(msg: T) -> Self { + Self::MissingData(msg.to_string()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OperationType { + ReadNumber, + ReadBytes, +} + +impl fmt::Display for OperationType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ReadNumber => write!(f, "read_number"), + Self::ReadBytes => write!(f, "read_bytess"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum Operation { + ReadNumber(Result<u64, Error>), + ReadBytes(Result<Bytes, Error>), +} + +impl From<Operation> for OperationType { + fn from(value: Operation) -> Self { + match value { + Operation::ReadNumber(_) => OperationType::ReadNumber, + Operation::ReadBytes(_) => OperationType::ReadBytes, + } + } +} + +pub struct Builder { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + version: Default::default(), + ops: VecDeque::new(), + } + } + + pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { + self.version = version.into(); + self + } + + pub fn read_number(&mut self, value: u64) -> &mut Self { + self.ops.push_back(Operation::ReadNumber(Ok(value))); + self + } + + pub fn read_number_error(&mut self, err: Error) -> &mut Self { + self.ops.push_back(Operation::ReadNumber(Err(err))); + self + } + + pub fn read_bytes(&mut self, value: Bytes) -> &mut Self { + self.ops.push_back(Operation::ReadBytes(Ok(value))); + self + } + + pub fn read_slice(&mut self, data: &[u8]) -> &mut Self { + let value = Bytes::copy_from_slice(data); + self.ops.push_back(Operation::ReadBytes(Ok(value))); + self + } + + pub fn read_bytes_error(&mut self, err: Error) -> &mut Self { + self.ops.push_back(Operation::ReadBytes(Err(err))); + self + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct 
Mock { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl NixRead for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { + match self.ops.pop_front() { + Some(Operation::ReadNumber(ret)) => ret.map(Some), + Some(Operation::ReadBytes(_)) => Err(Error::expected_read_bytes()), + None => Ok(None), + } + } + + async fn try_read_bytes_limited( + &mut self, + _limit: std::ops::RangeInclusive<usize>, + ) -> Result<Option<Bytes>, Self::Error> { + match self.ops.pop_front() { + Some(Operation::ReadBytes(ret)) => ret.map(Some), + Some(Operation::ReadNumber(_)) => Err(Error::expected_read_number()), + None => Ok(None), + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + // No need to panic again + if thread::panicking() { + return; + } + if let Some(op) = self.ops.front() { + panic!("reader dropped with {op:?} operation still unread") + } + } +} + +#[cfg(test)] +mod test { + use bytes::Bytes; + use hex_literal::hex; + + use crate::nix_daemon::de::NixRead; + + use super::{Builder, Error}; + + #[tokio::test] + async fn read_slice() { + let mut mock = Builder::new() + .read_number(10) + .read_slice(&[]) + .read_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); + assert_eq!( + &hex!("0000 1234 5678 9ABC DEFF"), + &mock.read_bytes().await.unwrap()[..] 
+ ); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn read_bytes() { + let mut mock = Builder::new() + .read_number(10) + .read_bytes(Bytes::from_static(&[])) + .read_bytes(Bytes::from_static(&hex!("0000 1234 5678 9ABC DEFF"))) + .build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); + assert_eq!( + &hex!("0000 1234 5678 9ABC DEFF"), + &mock.read_bytes().await.unwrap()[..] + ); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn read_number() { + let mut mock = Builder::new().read_number(10).build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn expect_number() { + let mut mock = Builder::new().read_number(10).build(); + assert_eq!( + Error::expected_read_number(), + mock.read_bytes().await.unwrap_err() + ); + } + + #[tokio::test] + async fn expect_bytes() { + let mut mock = Builder::new().read_slice(&[]).build(); + assert_eq!( + Error::expected_read_bytes(), + mock.read_number().await.unwrap_err() + ); + } + + #[test] + #[should_panic] + fn operations_left() { + let _ = Builder::new().read_number(10).build(); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/de/mod.rs b/tvix/nix-compat/src/nix_daemon/de/mod.rs new file mode 100644 index 000000000000..f85ccd8fea0e --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/de/mod.rs @@ -0,0 +1,225 @@ +use std::error::Error as StdError; +use std::future::Future; +use std::ops::RangeInclusive; +use std::{fmt, io}; + +use ::bytes::Bytes; + +use super::ProtocolVersion; + +mod bytes; +mod collections; +mod int; +#[cfg(any(test, feature = "test"))] +pub mod mock; +mod reader; + +pub use 
reader::{NixReader, NixReaderBuilder}; + +/// Like serde the `Error` trait allows `NixRead` implementations to add +/// custom error handling for `NixDeserialize`. +pub trait Error: Sized + StdError { + /// A totally custom non-specific error. + fn custom<T: fmt::Display>(msg: T) -> Self; + + /// Some kind of std::io::Error occured. + fn io_error(err: std::io::Error) -> Self { + Self::custom(format_args!("There was an I/O error {}", err)) + } + + /// The data read from `NixRead` is invalid. + /// This could be that some bytes were supposed to be valid UFT-8 but weren't. + fn invalid_data<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } + + /// Required data is missing. This is mostly like an EOF + fn missing_data<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } +} + +impl Error for io::Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::Other, msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + err + } + + fn invalid_data<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::InvalidData, msg.to_string()) + } + + fn missing_data<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::UnexpectedEof, msg.to_string()) + } +} + +/// A reader of data from the Nix daemon protocol. +/// Basically there are two basic types in the Nix daemon protocol +/// u64 and a bytes buffer. Everything else is more or less built on +/// top of these two types. +pub trait NixRead: Send { + type Error: Error + Send; + + /// Some types are serialized differently depending on the version + /// of the protocol and so this can be used for implementing that. + fn version(&self) -> ProtocolVersion; + + /// Read a single u64 from the protocol. + /// This returns an Option to support graceful shutdown. + fn try_read_number( + &mut self, + ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_; + + /// Read bytes from the protocol. 
+ /// A size limit on the returned bytes has to be specified. + /// This returns an Option to support graceful shutdown. + fn try_read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_; + + /// Read bytes from the protocol without a limit. + /// The default implementation just calls `try_read_bytes_limited` with a + /// limit of `0..=usize::MAX` but other implementations are free to have a + /// reader wide limit. + /// This returns an Option to support graceful shutdown. + fn try_read_bytes( + &mut self, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + self.try_read_bytes_limited(0..=usize::MAX) + } + + /// Read a single u64 from the protocol. + /// This will return an error if the number could not be read. + fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ { + async move { + match self.try_read_number().await? { + Some(v) => Ok(v), + None => Err(Self::Error::missing_data("unexpected end-of-file")), + } + } + } + + /// Read bytes from the protocol. + /// A size limit on the returned bytes has to be specified. + /// This will return an error if the number could not be read. + fn read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + async move { + match self.try_read_bytes_limited(limit).await? { + Some(v) => Ok(v), + None => Err(Self::Error::missing_data("unexpected end-of-file")), + } + } + } + + /// Read bytes from the protocol. + /// The default implementation just calls `read_bytes_limited` with a + /// limit of `0..=usize::MAX` but other implementations are free to have a + /// reader wide limit. + /// This will return an error if the bytes could not be read. 
+ fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + self.read_bytes_limited(0..=usize::MAX) + } + + /// Read a value from the protocol. + /// Uses `NixDeserialize::deserialize` to read a value. + fn read_value<V: NixDeserialize>( + &mut self, + ) -> impl Future<Output = Result<V, Self::Error>> + Send + '_ { + V::deserialize(self) + } + + /// Read a value from the protocol. + /// Uses `NixDeserialize::try_deserialize` to read a value. + /// This returns an Option to support graceful shutdown. + fn try_read_value<V: NixDeserialize>( + &mut self, + ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ { + V::try_deserialize(self) + } +} + +impl<T: ?Sized + NixRead> NixRead for &mut T { + type Error = T::Error; + + fn version(&self) -> ProtocolVersion { + (**self).version() + } + + fn try_read_number( + &mut self, + ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_ { + (**self).try_read_number() + } + + fn try_read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + (**self).try_read_bytes_limited(limit) + } + + fn try_read_bytes( + &mut self, + ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + (**self).try_read_bytes() + } + + fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ { + (**self).read_number() + } + + fn read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + (**self).read_bytes_limited(limit) + } + + fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + (**self).read_bytes() + } + + fn try_read_value<V: NixDeserialize>( + &mut self, + ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ { + (**self).try_read_value() + } + + fn read_value<V: NixDeserialize>( + &mut self, + ) -> impl 
Future<Output = Result<V, Self::Error>> + Send + '_ { + (**self).read_value() + } +} + +/// A data structure that can be deserialized from the Nix daemon +/// worker protocol. +pub trait NixDeserialize: Sized { + /// Read a value from the reader. + /// This returns an Option to support gracefull shutdown. + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send; + + fn deserialize<R>(reader: &mut R) -> impl Future<Output = Result<Self, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + match Self::try_deserialize(reader).await? { + Some(v) => Ok(v), + None => Err(R::Error::missing_data("unexpected end-of-file")), + } + } + } +} diff --git a/tvix/nix-compat/src/nix_daemon/de/reader.rs b/tvix/nix-compat/src/nix_daemon/de/reader.rs new file mode 100644 index 000000000000..87c623b2220c --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/de/reader.rs @@ -0,0 +1,527 @@ +use std::future::poll_fn; +use std::io::{self, Cursor}; +use std::ops::RangeInclusive; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf}; + +use crate::nix_daemon::ProtocolVersion; +use crate::wire::EMPTY_BYTES; + +use super::{Error, NixRead}; + +pub struct NixReaderBuilder { + buf: Option<BytesMut>, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, +} + +impl Default for NixReaderBuilder { + fn default() -> Self { + Self { + buf: Default::default(), + reserved_buf_size: 8192, + max_buf_size: 8192, + version: Default::default(), + } + } +} + +impl NixReaderBuilder { + pub fn set_buffer(mut self, buf: BytesMut) -> Self { + self.buf = Some(buf); + self + } + + pub fn set_reserved_buf_size(mut self, size: usize) -> Self { + self.reserved_buf_size = size; + self + } + + pub fn set_max_buf_size(mut self, 
size: usize) -> Self { + self.max_buf_size = size; + self + } + + pub fn set_version(mut self, version: ProtocolVersion) -> Self { + self.version = version; + self + } + + pub fn build<R>(self, reader: R) -> NixReader<R> { + let buf = self.buf.unwrap_or_else(|| BytesMut::with_capacity(0)); + NixReader { + buf, + inner: reader, + reserved_buf_size: self.reserved_buf_size, + max_buf_size: self.max_buf_size, + version: self.version, + } + } +} + +pin_project! { + pub struct NixReader<R> { + #[pin] + inner: R, + buf: BytesMut, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, + } +} + +impl NixReader<Cursor<Vec<u8>>> { + pub fn builder() -> NixReaderBuilder { + NixReaderBuilder::default() + } +} + +impl<R> NixReader<R> +where + R: AsyncReadExt, +{ + pub fn new(reader: R) -> NixReader<R> { + NixReader::builder().build(reader) + } + + pub fn buffer(&self) -> &[u8] { + &self.buf[..] + } + + #[cfg(test)] + pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut { + &mut self.buf + } + + /// Remaining capacity in internal buffer + pub fn remaining_mut(&self) -> usize { + self.buf.capacity() - self.buf.len() + } + + fn poll_force_fill_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<usize>> { + // Ensure that buffer has space for at least reserved_buf_size bytes + if self.remaining_mut() < self.reserved_buf_size { + let me = self.as_mut().project(); + me.buf.reserve(*me.reserved_buf_size); + } + let me = self.project(); + let n = { + let dst = me.buf.spare_capacity_mut(); + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(me.inner.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // SAFETY: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. 
+ unsafe { + me.buf.advance_mut(n); + } + Poll::Ready(Ok(n)) + } +} + +impl<R> NixReader<R> +where + R: AsyncReadExt + Unpin, +{ + async fn force_fill(&mut self) -> io::Result<usize> { + let mut p = Pin::new(self); + let read = poll_fn(|cx| p.as_mut().poll_force_fill_buf(cx)).await?; + Ok(read) + } +} + +impl<R> NixRead for NixReader<R> +where + R: AsyncReadExt + Send + Unpin, +{ + type Error = io::Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { + let mut buf = [0u8; 8]; + let read = self.read_buf(&mut &mut buf[..]).await?; + if read == 0 { + return Ok(None); + } + if read < 8 { + self.read_exact(&mut buf[read..]).await?; + } + let num = Buf::get_u64_le(&mut &buf[..]); + Ok(Some(num)) + } + + async fn try_read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> Result<Option<Bytes>, Self::Error> { + assert!( + *limit.end() <= self.max_buf_size, + "The limit must be smaller than {}", + self.max_buf_size + ); + match self.try_read_number().await? { + Some(raw_len) => { + // Check that length is in range and convert to usize + let len = raw_len + .try_into() + .ok() + .filter(|v| limit.contains(v)) + .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?; + + // Calculate 64bit aligned length and convert to usize + let aligned: usize = raw_len + .checked_add(7) + .map(|v| v & !7) + .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))? + .try_into() + .map_err(Self::Error::invalid_data)?; + + // Ensure that there is enough space in buffer for contents + if self.buf.len() + self.remaining_mut() < aligned { + self.buf.reserve(aligned - self.buf.len()); + } + while self.buf.len() < aligned { + if self.force_fill().await? 
== 0 { + return Err(Self::Error::missing_data( + "unexpected end-of-file reading bytes", + )); + } + } + let mut contents = self.buf.split_to(aligned); + + let padding = aligned - len; + // Ensure padding is all zeros + if contents[len..] != EMPTY_BYTES[..padding] { + return Err(Self::Error::invalid_data("non-zero padding")); + } + + contents.truncate(len); + Ok(Some(contents.freeze())) + } + None => Ok(None), + } + } + + fn try_read_bytes( + &mut self, + ) -> impl std::future::Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + self.try_read_bytes_limited(0..=self.max_buf_size) + } + + fn read_bytes( + &mut self, + ) -> impl std::future::Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + self.read_bytes_limited(0..=self.max_buf_size) + } +} + +impl<R: AsyncRead> AsyncRead for NixReader<R> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[0..amt]); + self.consume(amt); + Poll::Ready(Ok(())) + } +} + +impl<R: AsyncRead> AsyncBufRead for NixReader<R> { + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + if self.as_ref().project_ref().buf.is_empty() { + ready!(self.as_mut().poll_force_fill_buf(cx))?; + } + let me = self.project(); + Poll::Ready(Ok(&me.buf[..])) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let me = self.project(); + me.buf.advance(amt) + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use super::*; + use crate::nix_daemon::de::NixRead; + + #[tokio::test] + async fn test_read_u64() { + let mock = Builder::new().read(&hex!("0100 0000 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!(""), 
reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!(""), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_rest() { + let mock = Builder::new() + .read(&hex!("0100 0000 0000 0000 0123 4567 89AB CDEF")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!("0123 4567 89AB CDEF"), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_partial() { + let mock = Builder::new() + .read(&hex!("0100 0000")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000 0123 4567 89AB CDEF")) + .wait(Duration::ZERO) + .read(&hex!("0100 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!("0123 4567 89AB CDEF 0100 0000"), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_eof() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.read_number().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_u64_none() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!(None, reader.try_read_number().await.unwrap()); + } + + #[tokio::test] + async fn test_try_read_u64_eof() { + let mock = Builder::new().read(&hex!("0100 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_number().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_u64_eof2() { + let mock = Builder::new() + .read(&hex!("0100")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000")) + 
.build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_number().await.unwrap_err().kind() + ); + } + + #[rstest] + #[case::empty(b"", &hex!("0000 0000 0000 0000"))] + #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_read_bytes(#[case] expected: &[u8], #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual = reader.read_bytes().await.unwrap(); + assert_eq!(&actual[..], expected); + } + + #[tokio::test] + async fn test_read_bytes_empty() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.read_bytes().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_bytes_none() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!(None, reader.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn test_try_read_bytes_missing_data() { + let mock = Builder::new() + .read(&hex!("0500")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_bytes().await.unwrap_err().kind() + ); + } + + 
#[tokio::test] + async fn test_try_read_bytes_missing_padding() { + let mock = Builder::new() + .read(&hex!("0200 0000 0000 0000")) + .wait(Duration::ZERO) + .read(&hex!("1234")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_bytes().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_read_bytes_bad_padding() { + let mock = Builder::new() + .read(&hex!("0200 0000 0000 0000")) + .wait(Duration::ZERO) + .read(&hex!("1234 0100 0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::InvalidData, + reader.read_bytes().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_read_bytes_limited_out_of_range() { + let mock = Builder::new().read(&hex!("FFFF 0000 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::InvalidData, + reader.read_bytes_limited(0..=50).await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_read_bytes_length_overflow() { + let mock = Builder::new().read(&hex!("F9FF FFFF FFFF FFFF")).build(); + let mut reader = NixReader::builder() + .set_max_buf_size(usize::MAX) + .build(mock); + + assert_eq!( + io::ErrorKind::InvalidData, + reader + .read_bytes_limited(0..=usize::MAX) + .await + .unwrap_err() + .kind() + ); + } + + // FUTUREWORK: Test this on supported hardware + #[tokio::test] + #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] + async fn test_bytes_length_conversion_overflow() { + let len = (usize::MAX as u64) + 1; + let mock = Builder::new().read(&len.to_le_bytes()).build(); + let mut reader = NixReader::new(mock); + assert_eq!( + std::io::ErrorKind::InvalidData, + reader.read_value::<usize>().await.unwrap_err().kind() + ); + } + + // FUTUREWORK: Test this on supported hardware + #[tokio::test] + #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] + async fn test_bytes_aligned_length_conversion_overflow() { 
+ let len = (usize::MAX - 6) as u64; + let mock = Builder::new().read(&len.to_le_bytes()).build(); + let mut reader = NixReader::new(mock); + assert_eq!( + std::io::ErrorKind::InvalidData, + reader.read_value::<usize>().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_buffer_resize() { + let mock = Builder::new() + .read(&hex!("0100")) + .read(&hex!("0000 0000 0000")) + .build(); + let mut reader = NixReader::builder().set_reserved_buf_size(8).build(mock); + // buffer has no capacity initially + assert_eq!(0, reader.buffer_mut().capacity()); + + assert_eq!(2, reader.force_fill().await.unwrap()); + + // After first read buffer should have capacity we chose + assert_eq!(8, reader.buffer_mut().capacity()); + + // Because there was only 6 bytes remaining in buffer, + // which is enough to read the last 6 bytes, but we require + // capacity for 8 bytes, it doubled the capacity + assert_eq!(6, reader.force_fill().await.unwrap()); + assert_eq!(16, reader.buffer_mut().capacity()); + + assert_eq!(1, reader.read_number().await.unwrap()); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs index fe652377d1b4..11413e85fd1b 100644 --- a/tvix/nix-compat/src/nix_daemon/mod.rs +++ b/tvix/nix-compat/src/nix_daemon/mod.rs @@ -2,3 +2,5 @@ pub mod worker_protocol; mod protocol_version; pub use protocol_version::ProtocolVersion; + +pub mod de; diff --git a/tvix/nix-compat/src/nix_daemon/protocol_version.rs b/tvix/nix-compat/src/nix_daemon/protocol_version.rs index 8fd2b085c962..3c8fe663e867 100644 --- a/tvix/nix-compat/src/nix_daemon/protocol_version.rs +++ b/tvix/nix-compat/src/nix_daemon/protocol_version.rs @@ -1,3 +1,6 @@ +/// The latest version that is currently supported by nix-compat. +static DEFAULT_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37); + /// Protocol versions are represented as a u16. /// The upper 8 bits are the major version, the lower bits the minor. 
/// This is not aware of any endianness, use [crate::wire::read_u64] to get an @@ -20,6 +23,12 @@ impl ProtocolVersion { } } +impl Default for ProtocolVersion { + fn default() -> Self { + DEFAULT_PROTOCOL_VERSION + } +} + impl PartialOrd for ProtocolVersion { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(other)) diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index 47bfb5eabacf..33c5d7d171b8 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -11,7 +11,7 @@ mod writer; pub use writer::BytesWriter; /// 8 null bytes, used to write out padding. -const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; +pub(crate) const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; /// The length of the size field, in bytes is always 8. const LEN_SIZE: usize = 8; |