diff options
author | Brian Olsen <brian@maven-group.org> | 2024-07-20T09·25+0200 |
---|---|---|
committer | clbot <clbot@tvl.fyi> | 2024-08-25T15·05+0000 |
commit | 9af69204787d47cfe551f524d01b1a726971f06e (patch) | |
tree | ab8f2b90b2624c0f73262899019a3f1098fb0bdb /tvix/nix-compat/src/nix_daemon/de/mock.rs | |
parent | a774cb8c10ea976bcdff2e296e3cefc6adbc21d3 (diff) |
feat(nix-compat): Add NixDeserialize and NixRead traits r/8585
Add a trait for deserializing a type from a daemon worker connection. This adds the NixDeserialize trait which is kind of like the serde Deserialize trait in that individual types are meant to implement it and it can potentially be derived in the future. The NixDeserialize trait takes something that implements NixRead as input so that you can among other things mock the reader. Change-Id: Ibb59e3562dfc822652f7d18039f00a1c0d422997 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11990 Autosubmit: Brian Olsen <me@griff.name> Reviewed-by: flokli <flokli@flokli.de> Tested-by: BuildkiteCI
Diffstat (limited to 'tvix/nix-compat/src/nix_daemon/de/mock.rs')
-rw-r--r-- | tvix/nix-compat/src/nix_daemon/de/mock.rs | 261 |
1 files changed, 261 insertions, 0 deletions
diff --git a/tvix/nix-compat/src/nix_daemon/de/mock.rs b/tvix/nix-compat/src/nix_daemon/de/mock.rs new file mode 100644 index 000000000000..31cc3a4897ba --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/de/mock.rs @@ -0,0 +1,261 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +use bytes::Bytes; +use thiserror::Error; + +use crate::nix_daemon::ProtocolVersion; + +use super::NixRead; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("invalid data '{0}'")] + InvalidData(String), + #[error("missing data '{0}'")] + MissingData(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong read: expected {0} got {1}")] + WrongRead(OperationType, OperationType), +} + +impl Error { + pub fn expected_read_number() -> Error { + Error::WrongRead(OperationType::ReadNumber, OperationType::ReadBytes) + } + + pub fn expected_read_bytes() -> Error { + Error::WrongRead(OperationType::ReadBytes, OperationType::ReadNumber) + } +} + +impl super::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Self::Custom(msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + Self::IO(err.kind(), err.to_string()) + } + + fn invalid_data<T: fmt::Display>(msg: T) -> Self { + Self::InvalidData(msg.to_string()) + } + + fn missing_data<T: fmt::Display>(msg: T) -> Self { + Self::MissingData(msg.to_string()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OperationType { + ReadNumber, + ReadBytes, +} + +impl fmt::Display for OperationType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ReadNumber => write!(f, "read_number"), + Self::ReadBytes => write!(f, "read_bytess"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum Operation { + ReadNumber(Result<u64, Error>), + ReadBytes(Result<Bytes, Error>), +} + +impl From<Operation> for OperationType { + fn from(value: Operation) -> Self { + match value { + Operation::ReadNumber(_) => OperationType::ReadNumber, + Operation::ReadBytes(_) => OperationType::ReadBytes, + } + } +} + +pub struct Builder { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + version: Default::default(), + ops: VecDeque::new(), + } + } + + pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { + self.version = version.into(); + self + } + + pub fn read_number(&mut self, value: u64) -> &mut Self { + self.ops.push_back(Operation::ReadNumber(Ok(value))); + self + } + + pub fn read_number_error(&mut self, err: Error) -> &mut Self { + self.ops.push_back(Operation::ReadNumber(Err(err))); + self + } + + pub fn read_bytes(&mut self, value: Bytes) -> &mut Self { + self.ops.push_back(Operation::ReadBytes(Ok(value))); + self + } + + pub fn read_slice(&mut self, data: &[u8]) -> &mut Self { + let value = Bytes::copy_from_slice(data); + self.ops.push_back(Operation::ReadBytes(Ok(value))); + self + } + + pub fn read_bytes_error(&mut self, err: Error) -> &mut Self { + self.ops.push_back(Operation::ReadBytes(Err(err))); + self + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct Mock { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl NixRead for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { + match self.ops.pop_front() { + Some(Operation::ReadNumber(ret)) => ret.map(Some), + Some(Operation::ReadBytes(_)) => Err(Error::expected_read_bytes()), + None => Ok(None), + } + } + + async fn try_read_bytes_limited( + &mut self, + _limit: std::ops::RangeInclusive<usize>, + ) -> Result<Option<Bytes>, Self::Error> { + match self.ops.pop_front() { + Some(Operation::ReadBytes(ret)) => ret.map(Some), + Some(Operation::ReadNumber(_)) => Err(Error::expected_read_number()), + None => Ok(None), + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + // No need to panic again + if thread::panicking() { + return; + } + if let Some(op) = self.ops.front() { + panic!("reader dropped with {op:?} operation still unread") + } + } +} + +#[cfg(test)] +mod test { + use bytes::Bytes; + use hex_literal::hex; + + use crate::nix_daemon::de::NixRead; + + use super::{Builder, Error}; + + #[tokio::test] + async fn read_slice() { + let mut mock = Builder::new() + .read_number(10) + .read_slice(&[]) + .read_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); + assert_eq!( + &hex!("0000 1234 5678 9ABC DEFF"), + &mock.read_bytes().await.unwrap()[..] + ); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn read_bytes() { + let mut mock = Builder::new() + .read_number(10) + .read_bytes(Bytes::from_static(&[])) + .read_bytes(Bytes::from_static(&hex!("0000 1234 5678 9ABC DEFF"))) + .build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]); + assert_eq!( + &hex!("0000 1234 5678 9ABC DEFF"), + &mock.read_bytes().await.unwrap()[..] + ); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn read_number() { + let mut mock = Builder::new().read_number(10).build(); + assert_eq!(10, mock.read_number().await.unwrap()); + assert_eq!(None, mock.try_read_number().await.unwrap()); + assert_eq!(None, mock.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn expect_number() { + let mut mock = Builder::new().read_number(10).build(); + assert_eq!( + Error::expected_read_number(), + mock.read_bytes().await.unwrap_err() + ); + } + + #[tokio::test] + async fn expect_bytes() { + let mut mock = Builder::new().read_slice(&[]).build(); + assert_eq!( + Error::expected_read_bytes(), + mock.read_number().await.unwrap_err() + ); + } + + #[test] + #[should_panic] + fn operations_left() { + let _ = Builder::new().read_number(10).build(); + } +} |