diff options
Diffstat (limited to 'tvix/nix-compat')
34 files changed, 2481 insertions, 374 deletions
diff --git a/tvix/nix-compat/Cargo.toml b/tvix/nix-compat/Cargo.toml index 58137e4de2e1..160eb2c20c16 100644 --- a/tvix/nix-compat/Cargo.toml +++ b/tvix/nix-compat/Cargo.toml @@ -8,10 +8,13 @@ edition = "2021" async = ["tokio"] # code emitting low-level packets used in the daemon protocol. wire = ["tokio", "pin-project-lite", "bytes"] + +# nix-daemon protocol handling +daemon = ["tokio", "nix-compat-derive", "futures"] test = [] # Enable all features by default. -default = ["async", "wire", "nix-compat-derive"] +default = ["async", "daemon", "wire", "nix-compat-derive"] [dependencies] bitflags = { workspace = true } @@ -20,6 +23,7 @@ data-encoding = { workspace = true } ed25519 = { workspace = true } ed25519-dalek = { workspace = true } enum-primitive-derive = { workspace = true } +futures = { workspace = true, optional = true } glob = { workspace = true } mimalloc = { workspace = true } nom = { workspace = true } @@ -30,8 +34,9 @@ sha2 = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } bytes = { workspace = true, optional = true } -tokio = { workspace = true, features = ["io-util", "macros"], optional = true } +tokio = { workspace = true, features = ["io-util", "macros", "sync"], optional = true } pin-project-lite = { workspace = true, optional = true } +num_enum = "0.7.3" [dependencies.nix-compat-derive] path = "../nix-compat-derive" @@ -41,9 +46,9 @@ optional = true criterion = { workspace = true, features = ["html_reports"] } futures = { workspace = true } hex-literal = { workspace = true } -lazy_static = { workspace = true } mimalloc = { workspace = true } pretty_assertions = { workspace = true } +proptest = { workspace = true, features = ["std", "alloc", "tempfile"] } rstest = { workspace = true } serde_json = { workspace = true } smol_str = { workspace = true } diff --git a/tvix/nix-compat/benches/narinfo_parse.rs b/tvix/nix-compat/benches/narinfo_parse.rs index f35ba8468a88..ee2665a310d3 100644 --- a/tvix/nix-compat/benches/narinfo_parse.rs +++ b/tvix/nix-compat/benches/narinfo_parse.rs @@ -1,8 +1,9 @@ +use std::sync::LazyLock; +use std::{io, str}; + use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; -use lazy_static::lazy_static; use mimalloc::MiMalloc; use nix_compat::narinfo::NarInfo; -use std::{io, str}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; @@ -19,18 +20,16 @@ Deriver: 2ch8jx910qk6721mp4yqsmvdfgj5c8ir-banking-0.3.0.drv Sig: cache.nixos.org-1:xcL67rBZPcdVZudDLpLeddkBa0KaFTw5A0udnaa0axysjrQ6Nvd9p3BLZ4rhKgl52/cKiU3c6aq60L8+IcE5Dw== "#; -lazy_static! { - static ref CASES: &'static [&'static str] = { - let data = - zstd::decode_all(io::Cursor::new(include_bytes!("../testdata/narinfo.zst"))).unwrap(); - let data = str::from_utf8(Vec::leak(data)).unwrap(); - Vec::leak( - data.split_inclusive("\n\n") - .map(|s| s.strip_suffix('\n').unwrap()) - .collect::<Vec<_>>(), - ) - }; -} +static CASES: LazyLock<&'static [&'static str]> = LazyLock::new(|| { + let data = + zstd::decode_all(io::Cursor::new(include_bytes!("../testdata/narinfo.zst"))).unwrap(); + let data = str::from_utf8(Vec::leak(data)).unwrap(); + Vec::leak( + data.split_inclusive("\n\n") + .map(|s| s.strip_suffix('\n').unwrap()) + .collect::<Vec<_>>(), + ) +}); pub fn parse(c: &mut Criterion) { let mut g = c.benchmark_group("parse"); diff --git a/tvix/nix-compat/src/derivation/mod.rs b/tvix/nix-compat/src/derivation/mod.rs index 6baeaba38299..445e9cb43143 100644 --- a/tvix/nix-compat/src/derivation/mod.rs +++ b/tvix/nix-compat/src/derivation/mod.rs @@ -257,8 +257,8 @@ impl Derivation { // For fixed output derivation we use [build_ca_path], otherwise we // use [build_output_path] with [hash_derivation_modulo]. - let abs_store_path = if let Some(ref hwm) = output.ca_hash { - build_ca_path(&path_name, hwm, Vec::<String>::new(), false).map_err(|e| { + let store_path = if let Some(ref hwm) = output.ca_hash { + build_ca_path(&path_name, hwm, Vec::<&str>::new(), false).map_err(|e| { DerivationError::InvalidOutputDerivationPath(output_name.to_string(), e) })? } else { @@ -270,11 +270,11 @@ impl Derivation { })? }; - output.path = Some(abs_store_path.to_owned()); self.environment.insert( output_name.to_string(), - abs_store_path.to_absolute_path().into(), + store_path.to_absolute_path().into(), ); + output.path = Some(store_path); } Ok(()) diff --git a/tvix/nix-compat/src/derivation/parser.rs b/tvix/nix-compat/src/derivation/parser.rs index 4fff7181ba40..a94ed2281a86 100644 --- a/tvix/nix-compat/src/derivation/parser.rs +++ b/tvix/nix-compat/src/derivation/parser.rs @@ -203,11 +203,7 @@ fn string_to_store_path<'a, 'i, S>( path_str: &'a str, ) -> Result<StorePath<S>, nom::Err<NomError<&'i [u8]>>> where - S: std::cmp::Eq - + std::fmt::Display - + std::clone::Clone - + std::ops::Deref<Target = str> - + std::convert::From<&'a str>, + S: std::clone::Clone + AsRef<str> + std::convert::From<&'a str>, { let path = StorePath::from_absolute_path(path_str.as_bytes()).map_err(|e: store_path::Error| { @@ -333,6 +329,7 @@ where mod tests { use crate::store_path::StorePathRef; use std::collections::{BTreeMap, BTreeSet}; + use std::sync::LazyLock; use crate::{ derivation::{ @@ -342,49 +339,48 @@ mod tests { }; use bstr::{BString, ByteSlice}; use hex_literal::hex; - use lazy_static::lazy_static; use rstest::rstest; const DIGEST_SHA256: [u8; 32] = hex!("a5ce9c155ed09397614646c9717fc7cd94b1023d7b76b618d409e4fefd6e9d39"); - lazy_static! { - pub static ref NIXHASH_SHA256: NixHash = NixHash::Sha256(DIGEST_SHA256); - static ref EXP_MULTI_OUTPUTS: BTreeMap<String, Output> = { - let mut b = BTreeMap::new(); - b.insert( - "lib".to_string(), - Output { - path: Some( - StorePath::from_bytes( - b"2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib", - ) - .unwrap(), - ), - ca_hash: None, - }, - ); - b.insert( - "out".to_string(), - Output { - path: Some( - StorePath::from_bytes( - b"55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".as_bytes(), - ) + static NIXHASH_SHA256: NixHash = NixHash::Sha256(DIGEST_SHA256); + static EXP_MULTI_OUTPUTS: LazyLock<BTreeMap<String, Output>> = LazyLock::new(|| { + let mut b = BTreeMap::new(); + b.insert( + "lib".to_string(), + Output { + path: Some( + StorePath::from_bytes(b"2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib") .unwrap(), - ), - ca_hash: None, - }, - ); - b - }; - static ref EXP_AB_MAP: BTreeMap<String, BString> = { - let mut b = BTreeMap::new(); - b.insert("a".to_string(), b"1".into()); - b.insert("b".to_string(), b"2".into()); - b - }; - static ref EXP_INPUT_DERIVATIONS_SIMPLE: BTreeMap<StorePath<String>, BTreeSet<String>> = { + ), + ca_hash: None, + }, + ); + b.insert( + "out".to_string(), + Output { + path: Some( + StorePath::from_bytes( + b"55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".as_bytes(), + ) + .unwrap(), + ), + ca_hash: None, + }, + ); + b + }); + + static EXP_AB_MAP: LazyLock<BTreeMap<String, BString>> = LazyLock::new(|| { + let mut b = BTreeMap::new(); + b.insert("a".to_string(), b"1".into()); + b.insert("b".to_string(), b"2".into()); + b + }); + + static EXP_INPUT_DERIVATIONS_SIMPLE: LazyLock<BTreeMap<StorePath<String>, BTreeSet<String>>> = + LazyLock::new(|| { let mut b = BTreeMap::new(); b.insert( StorePath::from_bytes(b"8bjm87p310sb7r2r0sg4xrynlvg86j8k-hello-2.12.1.tar.gz.drv") @@ -406,21 +402,22 @@ mod tests { }, ); b - }; - static ref EXP_INPUT_DERIVATIONS_SIMPLE_ATERM: String = { - format!( - "[(\"{0}\",[\"out\"]),(\"{1}\",[\"out\",\"lib\"])]", - "/nix/store/8bjm87p310sb7r2r0sg4xrynlvg86j8k-hello-2.12.1.tar.gz.drv", - "/nix/store/p3jc8aw45dza6h52v81j7lk69khckmcj-bash-5.2-p15.drv" - ) - }; - static ref EXP_INPUT_SOURCES_SIMPLE: BTreeSet<String> = { - let mut b = BTreeSet::new(); - b.insert("/nix/store/55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".to_string()); - b.insert("/nix/store/2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib".to_string()); - b - }; - } + }); + + static EXP_INPUT_DERIVATIONS_SIMPLE_ATERM: LazyLock<String> = LazyLock::new(|| { + format!( + "[(\"{0}\",[\"out\"]),(\"{1}\",[\"out\",\"lib\"])]", + "/nix/store/8bjm87p310sb7r2r0sg4xrynlvg86j8k-hello-2.12.1.tar.gz.drv", + "/nix/store/p3jc8aw45dza6h52v81j7lk69khckmcj-bash-5.2-p15.drv" + ) + }); + + static EXP_INPUT_SOURCES_SIMPLE: LazyLock<BTreeSet<String>> = LazyLock::new(|| { + let mut b = BTreeSet::new(); + b.insert("/nix/store/55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".to_string()); + b.insert("/nix/store/2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib".to_string()); + b + }); /// Ensure parsing KVs works #[rstest] diff --git a/tvix/nix-compat/src/derivation/write.rs b/tvix/nix-compat/src/derivation/write.rs index 42dadcd76064..a8b43fad4cc6 100644 --- a/tvix/nix-compat/src/derivation/write.rs +++ b/tvix/nix-compat/src/derivation/write.rs @@ -36,14 +36,14 @@ pub(crate) trait AtermWriteable { impl<S> AtermWriteable for StorePath<S> where - S: std::cmp::Eq + std::ops::Deref<Target = str>, + S: AsRef<str>, { fn aterm_write(&self, writer: &mut impl Write) -> std::io::Result<()> { write_char(writer, QUOTE)?; writer.write_all(STORE_DIR_WITH_SLASH.as_bytes())?; writer.write_all(nixbase32::encode(self.digest()).as_bytes())?; write_char(writer, '-')?; - writer.write_all(self.name().as_bytes())?; + writer.write_all(self.name().as_ref().as_bytes())?; write_char(writer, QUOTE)?; Ok(()) } diff --git a/tvix/nix-compat/src/lib.rs b/tvix/nix-compat/src/lib.rs index f30c557889a8..4c327fa4569b 100644 --- a/tvix/nix-compat/src/lib.rs +++ b/tvix/nix-compat/src/lib.rs @@ -14,9 +14,7 @@ pub mod store_path; #[cfg(feature = "wire")] pub mod wire; -#[cfg(feature = "wire")] +#[cfg(feature = "daemon")] pub mod nix_daemon; -#[cfg(feature = "wire")] +#[cfg(feature = "daemon")] pub use nix_daemon::worker_protocol; -#[cfg(feature = "wire")] -pub use nix_daemon::ProtocolVersion; diff --git a/tvix/nix-compat/src/nar/wire/mod.rs b/tvix/nix-compat/src/nar/wire/mod.rs index ddf021bc1fa1..67654129ee1d 100644 --- a/tvix/nix-compat/src/nar/wire/mod.rs +++ b/tvix/nix-compat/src/nar/wire/mod.rs @@ -91,9 +91,11 @@ pub const TOK_ENT: [u8; 48] = *b"\x05\0\0\0\0\0\0\0entry\0\0\0\x01\0\0\0\0\0\0\0 pub const TOK_NOD: [u8; 48] = *b"\x04\0\0\0\0\0\0\0node\0\0\0\0\x01\0\0\0\0\0\0\0(\0\0\0\0\0\0\0\x04\0\0\0\0\0\0\0type\0\0\0\0"; pub const TOK_PAR: [u8; 16] = *b"\x01\0\0\0\0\0\0\0)\0\0\0\0\0\0\0"; #[cfg(feature = "async")] +#[allow(dead_code)] const TOK_PAD_PAR: [u8; 24] = *b"\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0)\0\0\0\0\0\0\0"; #[cfg(feature = "async")] +#[allow(dead_code)] #[derive(Debug)] pub(crate) enum PadPar {} diff --git a/tvix/nix-compat/src/narinfo/mod.rs b/tvix/nix-compat/src/narinfo/mod.rs index 21aecf80b5a2..35146a927b39 100644 --- a/tvix/nix-compat/src/narinfo/mod.rs +++ b/tvix/nix-compat/src/narinfo/mod.rs @@ -417,8 +417,8 @@ const DUMMY_VERIFYING_KEY: &str = #[cfg(test)] mod test { use hex_literal::hex; - use lazy_static::lazy_static; use pretty_assertions::assert_eq; + use std::sync::LazyLock; use std::{io, str}; use crate::{ @@ -428,20 +428,18 @@ mod test { use super::{Flags, NarInfo}; - lazy_static! { - static ref CASES: &'static [&'static str] = { - let data = zstd::decode_all(io::Cursor::new(include_bytes!( - "../../testdata/narinfo.zst" - ))) - .unwrap(); - let data = str::from_utf8(Vec::leak(data)).unwrap(); - Vec::leak( - data.split_inclusive("\n\n") - .map(|s| s.strip_suffix('\n').unwrap()) - .collect::<Vec<_>>(), - ) - }; - } + static CASES: LazyLock<&'static [&'static str]> = LazyLock::new(|| { + let data = zstd::decode_all(io::Cursor::new(include_bytes!( + "../../testdata/narinfo.zst" + ))) + .unwrap(); + let data = str::from_utf8(Vec::leak(data)).unwrap(); + Vec::leak( + data.split_inclusive("\n\n") + .map(|s| s.strip_suffix('\n').unwrap()) + .collect::<Vec<_>>(), + ) + }); #[test] fn roundtrip() { diff --git a/tvix/nix-compat/src/narinfo/signature.rs b/tvix/nix-compat/src/narinfo/signature.rs index 33c49128c2d5..2005a5cb60df 100644 --- a/tvix/nix-compat/src/narinfo/signature.rs +++ b/tvix/nix-compat/src/narinfo/signature.rs @@ -92,6 +92,12 @@ where bytes: self.bytes, } } + pub fn to_owned(&self) -> Signature<String> { + Signature { + name: self.name.to_string(), + bytes: self.bytes, + } + } } impl<'a, 'de, S> Deserialize<'de> for Signature<S> @@ -133,7 +139,17 @@ where } } -#[derive(Debug, thiserror::Error)] +impl<S> std::hash::Hash for Signature<S> +where + S: AsRef<str>, +{ + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + state.write(self.name.as_ref().as_bytes()); + state.write(&self.bytes); + } +} + +#[derive(Debug, thiserror::Error, PartialEq, Eq)] pub enum Error { #[error("Invalid name: {0}")] InvalidName(String), @@ -150,32 +166,24 @@ mod test { use data_encoding::BASE64; use ed25519_dalek::VerifyingKey; use hex_literal::hex; - use lazy_static::lazy_static; + use std::sync::LazyLock; use super::Signature; use rstest::rstest; const FINGERPRINT: &str = "1;/nix/store/syd87l2rxw8cbsxmxl853h0r6pdwhwjr-curl-7.82.0-bin;sha256:1b4sb93wp679q4zx9k1ignby1yna3z7c4c2ri3wphylbc2dwsys0;196040;/nix/store/0jqd0rlxzra1rs38rdxl43yh6rxchgc6-curl-7.82.0,/nix/store/6w8g7njm4mck5dmjxws0z1xnrxvl81xa-glibc-2.34-115,/nix/store/j5jxw3iy7bbz4a57fh9g2xm2gxmyal8h-zlib-1.2.12,/nix/store/yxvjs9drzsphm9pcf42a4byzj1kb9m7k-openssl-1.1.1n"; - // The signing key labelled as `cache.nixos.org-1`, - lazy_static! { - static ref PUB_CACHE_NIXOS_ORG_1: VerifyingKey = ed25519_dalek::VerifyingKey::from_bytes( + /// The signing key labelled as `cache.nixos.org-1`, + static PUB_CACHE_NIXOS_ORG_1: LazyLock<VerifyingKey> = LazyLock::new(|| { + ed25519_dalek::VerifyingKey::from_bytes( BASE64 .decode(b"6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY=") .unwrap()[..] .try_into() - .unwrap() - ) - .unwrap(); - static ref PUB_TEST_1: VerifyingKey = ed25519_dalek::VerifyingKey::from_bytes( - BASE64 - .decode(b"tLAEn+EeaBUJYqEpTd2yeerr7Ic6+0vWe+aXL/vYUpE=") - .unwrap()[..] - .try_into() - .unwrap() + .unwrap(), ) - .unwrap(); - } + .expect("embedded public key is valid") + }); #[rstest] #[case::valid_cache_nixos_org_1(&PUB_CACHE_NIXOS_ORG_1, &"cache.nixos.org-1:TsTTb3WGTZKphvYdBHXwo6weVILmTytUjLB+vcX89fOjjRicCHmKA4RCPMVLkj6TMJ4GMX3HPVWRdD1hkeKZBQ==", FINGERPRINT, true)] diff --git a/tvix/nix-compat/src/nix_daemon/handler.rs b/tvix/nix-compat/src/nix_daemon/handler.rs new file mode 100644 index 000000000000..65c5c2d60d08 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/handler.rs @@ -0,0 +1,294 @@ +use std::{future::Future, sync::Arc}; + +use bytes::Bytes; +use tokio::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + sync::Mutex, +}; +use tracing::{debug, warn}; + +use super::{ + types::QueryValidPaths, + worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST}, + NixDaemonIO, +}; + +use crate::{ + store_path::StorePath, + wire::{ + de::{NixRead, NixReader}, + ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder}, + ProtocolVersion, + }, +}; + +use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR}; + +/// Handles a single connection with a nix client. +/// +/// As part of its [`initialization`] it performs the handshake with the client +/// and determines the [ProtocolVersion] and [ClientSettings] to use for the remainder of the session. +/// +/// Once initialized, [`handle_client`] needs to be called to handle the rest of the session, +/// it delegates all operation handling to an instance of [NixDaemonIO]. +/// +/// [`initialization`]: NixDaemon::initialize +#[allow(dead_code)] +pub struct NixDaemon<IO, R, W> { + io: Arc<IO>, + protocol_version: ProtocolVersion, + client_settings: ClientSettings, + reader: NixReader<R>, + writer: Arc<Mutex<NixWriter<W>>>, +} + +impl<IO, R, W> NixDaemon<IO, R, W> +where + IO: NixDaemonIO + Sync + Send, +{ + pub fn new( + io: Arc<IO>, + protocol_version: ProtocolVersion, + client_settings: ClientSettings, + reader: NixReader<R>, + writer: NixWriter<W>, + ) -> Self { + Self { + io, + protocol_version, + client_settings, + reader, + writer: Arc::new(Mutex::new(writer)), + } + } +} + +impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>> +where + RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static, + IO: NixDaemonIO + Sync + Send, +{ + /// Async constructor for NixDaemon. + /// + /// Performs the initial handshake with the client and retrieves the client's preferred + /// settings. + /// + /// The resulting daemon can handle the client session by calling [NixDaemon::handle_client]. + pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error> + where + RW: AsyncReadExt + AsyncWriteExt + Send + Unpin, + { + let protocol_version = + server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?; + + connection.write_u64_le(STDERR_LAST).await?; + let (reader, writer) = split(connection); + let mut reader = NixReader::builder() + .set_version(protocol_version) + .build(reader); + let mut writer = NixWriterBuilder::default() + .set_version(protocol_version) + .build(writer); + + // The first op is always SetOptions + let operation: Operation = reader.read_value().await?; + if operation != Operation::SetOptions { + return Err(std::io::Error::other( + "Expected SetOptions operation, but got {operation}", + )); + } + let client_settings: ClientSettings = reader.read_value().await?; + writer.write_number(STDERR_LAST).await?; + writer.flush().await?; + + Ok(Self::new( + io, + protocol_version, + client_settings, + reader, + writer, + )) + } + + /// Main client connection loop, reads client's requests and responds to them accordingly. + pub async fn handle_client(&mut self) -> Result<(), std::io::Error> { + let io = self.io.clone(); + loop { + let op_code = self.reader.read_number().await?; + match TryInto::<Operation>::try_into(op_code) { + // Note: please keep operations sorted in ascending order of their numerical op number. + Ok(operation) => match operation { + Operation::IsValidPath => { + let path: StorePath<String> = self.reader.read_value().await?; + self.handle(io.is_valid_path(&path)).await? + } + // Note this operation does not currently delegate to NixDaemonIO, + // The general idea is that we will pass relevant ClientSettings + // into individual NixDaemonIO method calls if the need arises. + // For now we just store the settings in the NixDaemon for future use. + Operation::SetOptions => { + self.client_settings = self.reader.read_value().await?; + self.handle(async { Ok(()) }).await? + } + Operation::QueryPathInfo => { + let path: StorePath<String> = self.reader.read_value().await?; + self.handle(io.query_path_info(&path)).await? + } + Operation::QueryPathFromHashPart => { + let hash: Bytes = self.reader.read_value().await?; + self.handle(io.query_path_from_hash_part(&hash)).await? + } + Operation::QueryValidPaths => { + let query: QueryValidPaths = self.reader.read_value().await?; + self.handle(io.query_valid_paths(&query)).await? + } + Operation::QueryValidDerivers => { + let path: StorePath<String> = self.reader.read_value().await?; + self.handle(io.query_valid_derivers(&path)).await? + } + // FUTUREWORK: These are just stubs that return an empty list. + // It's important not to return an error for the local-overlay:// store + // to work properly. While it will not see certain referrers and realizations + // it will not fail on various operations like gc and optimize store. At the + // same time, returning an empty list here shouldn't break any of local-overlay store's + // invariants. + Operation::QueryReferrers | Operation::QueryRealisation => { + let _: String = self.reader.read_value().await?; + self.handle(async move { + warn!( + ?operation, + "This operation is not implemented. Returning empty result..." + ); + Ok(Vec::<StorePath<String>>::new()) + }) + .await? + } + _ => { + return Err(std::io::Error::other(format!( + "Operation {operation:?} is not implemented" + ))); + } + }, + _ => { + return Err(std::io::Error::other(format!( + "Unknown operation code received: {op_code}" + ))); + } + } + } + } + + /// Handles the operation and sends the response or error to the client. + /// + /// As per nix daemon protocol, after sending the request, the client expects zero or more + /// log lines/activities followed by either + /// * STDERR_LAST and the response bytes + /// * STDERR_ERROR and the error + /// + /// This is a helper method, awaiting on the passed in future and then + /// handling log lines/activities as described above. + async fn handle<T>( + &mut self, + future: impl Future<Output = std::io::Result<T>>, + ) -> Result<(), std::io::Error> + where + T: NixSerialize + Send, + { + let result = future.await; + let mut writer = self.writer.lock().await; + + match result { + Ok(r) => { + // the protocol requires that we first indicate that we are done sending logs + // by sending STDERR_LAST and then the response. + writer.write_number(STDERR_LAST).await?; + writer.write_value(&r).await?; + writer.flush().await + } + Err(e) => { + debug!(err = ?e, "IO error"); + writer.write_number(STDERR_ERROR).await?; + writer.write_value(&NixError::new(format!("{e:?}"))).await?; + writer.flush().await + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{io::Result, sync::Arc}; + + use tokio::io::AsyncWriteExt; + + use crate::{ + nix_daemon::types::UnkeyedValidPathInfo, + wire::ProtocolVersion, + worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2}, + }; + + struct MockDaemonIO {} + + impl NixDaemonIO for MockDaemonIO { + async fn query_path_info( + &self, + _path: &crate::store_path::StorePath<String>, + ) -> Result<Option<UnkeyedValidPathInfo>> { + Ok(None) + } + + async fn query_path_from_hash_part( + &self, + _hash: &[u8], + ) -> Result<Option<UnkeyedValidPathInfo>> { + Ok(None) + } + } + + #[tokio::test] + async fn test_daemon_initialization() { + let mut builder = tokio_test::io::Builder::new(); + let test_conn = builder + .read(&WORKER_MAGIC_1.to_le_bytes()) + .write(&WORKER_MAGIC_2.to_le_bytes()) + // Our version is 1.37 + .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // The client's versin is 1.35 + .read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // cpu affinity + .read(&[0; 8]) + // reservespace + .read(&[0; 8]) + // version (size) + .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // version (data == 2.18.2 + padding) + .write(&[50, 46, 49, 56, 46, 50, 0, 0]) + // Trusted (1 == client trusted) + .write(&[1, 0, 0, 0, 0, 0, 0, 0]) + // STDERR_LAST + .write(&[115, 116, 108, 97, 0, 0, 0, 0]); + + let mut bytes = Vec::new(); + let mut writer = NixWriter::new(&mut bytes); + writer + .write_value(&ClientSettings::default()) + .await + .unwrap(); + writer.flush().await.unwrap(); + + let test_conn = test_conn + // SetOptions op + .read(&[19, 0, 0, 0, 0, 0, 0, 0]) + .read(&bytes) + // STDERR_LAST + .write(&[115, 116, 108, 97, 0, 0, 0, 0]) + .build(); + + let daemon = NixDaemon::initialize(Arc::new(MockDaemonIO {}), test_conn) + .await + .unwrap(); + assert_eq!(daemon.client_settings, ClientSettings::default()); + assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35)); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs index 11413e85fd1b..e475263d2302 100644 --- a/tvix/nix-compat/src/nix_daemon/mod.rs +++ b/tvix/nix-compat/src/nix_daemon/mod.rs @@ -1,6 +1,211 @@ pub mod worker_protocol; -mod protocol_version; -pub use protocol_version::ProtocolVersion; +use std::io::Result; -pub mod de; +use futures::future::try_join_all; +use tracing::warn; +use types::{QueryValidPaths, UnkeyedValidPathInfo}; + +use crate::store_path::StorePath; + +pub mod handler; +pub mod types; + +/// Represents all possible operations over the nix-daemon protocol. +pub trait NixDaemonIO: Sync { + fn is_valid_path( + &self, + path: &StorePath<String>, + ) -> impl std::future::Future<Output = Result<bool>> + Send { + async move { Ok(self.query_path_info(path).await?.is_some()) } + } + + fn query_path_info( + &self, + path: &StorePath<String>, + ) -> impl std::future::Future<Output = Result<Option<UnkeyedValidPathInfo>>> + Send; + + fn query_path_from_hash_part( + &self, + hash: &[u8], + ) -> impl std::future::Future<Output = Result<Option<UnkeyedValidPathInfo>>> + Send; + + fn query_valid_paths( + &self, + request: &QueryValidPaths, + ) -> impl std::future::Future<Output = Result<Vec<UnkeyedValidPathInfo>>> + Send { + async move { + if request.substitute { + warn!("tvix does not yet support substitution, ignoring the 'substitute' flag..."); + } + // Using try_join_all here to avoid returning partial results to the client. + // The only reason query_path_info can fail is due to transient IO errors, + // so we return such errors to the client as opposed to only returning paths + // that succeeded. + let result = + try_join_all(request.paths.iter().map(|path| self.query_path_info(path))).await?; + let result: Vec<UnkeyedValidPathInfo> = result.into_iter().flatten().collect(); + Ok(result) + } + } + + fn query_valid_derivers( + &self, + path: &StorePath<String>, + ) -> impl std::future::Future<Output = Result<Vec<StorePath<String>>>> + Send { + async move { + let result = self.query_path_info(path).await?; + let result: Vec<_> = result.into_iter().filter_map(|info| info.deriver).collect(); + Ok(result) + } + } +} + +#[cfg(test)] +mod tests { + + use crate::{nix_daemon::types::QueryValidPaths, store_path::StorePath}; + + use super::{types::UnkeyedValidPathInfo, NixDaemonIO}; + + // Very simple mock + // Unable to use mockall as it does not support unboxed async traits. + pub struct MockNixDaemonIO { + query_path_info_result: Option<UnkeyedValidPathInfo>, + } + + impl NixDaemonIO for MockNixDaemonIO { + async fn query_path_info( + &self, + _path: &StorePath<String>, + ) -> std::io::Result<Option<UnkeyedValidPathInfo>> { + Ok(self.query_path_info_result.clone()) + } + + async fn query_path_from_hash_part( + &self, + _hash: &[u8], + ) -> std::io::Result<Option<UnkeyedValidPathInfo>> { + Ok(None) + } + } + + #[tokio::test] + async fn test_is_valid_path_returns_true() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: Some(UnkeyedValidPathInfo::default()), + }; + + let result = io + .is_valid_path(&path) + .await + .expect("expected to get a non-empty response"); + assert!(result, "expected to get true"); + } + + #[tokio::test] + async fn test_is_valid_path_returns_false() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: None, + }; + + let result = io + .is_valid_path(&path) + .await + .expect("expected to get a non-empty response"); + assert!(!result, "expected to get false"); + } + + #[tokio::test] + async fn test_query_valid_paths_returns_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: None, + }; + + let result = io + .query_valid_paths(&QueryValidPaths { + paths: vec![path], + substitute: false, + }) + .await + .expect("expected to get a non-empty response"); + assert_eq!(result, vec![], "expected to get empty response"); + } + + #[tokio::test] + async fn test_query_valid_paths_returns_non_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: Some(UnkeyedValidPathInfo::default()), + }; + + let result = io + .query_valid_paths(&QueryValidPaths { + paths: vec![path], + substitute: false, + }) + .await + .expect("expected to get a non-empty response"); + assert_eq!( + result, + vec![UnkeyedValidPathInfo::default()], + "expected to get non empty response" + ); + } + + #[tokio::test] + async fn test_query_valid_derivers_returns_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: None, + }; + + let result = io + .query_valid_derivers(&path) + .await + .expect("expected to get a non-empty response"); + assert_eq!(result, vec![], "expected to get empty response"); + } + + #[tokio::test] + async fn test_query_valid_derivers_returns_non_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let deriver = StorePath::<String>::from_bytes( + "z6r3bn5l51679pwkvh9nalp6c317z34m-hello.drv".as_bytes(), + ) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: Some(UnkeyedValidPathInfo { + deriver: Some(deriver.clone()), + nar_hash: "".to_owned(), + references: vec![], + registration_time: 0, + nar_size: 1, + ultimate: true, + signatures: vec![], + ca: None, + }), + }; + + let result = io + .query_valid_derivers(&path) + .await + .expect("expected to get a non-empty response"); + assert_eq!(result, vec![deriver], "expected to get non empty response"); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/types.rs b/tvix/nix-compat/src/nix_daemon/types.rs new file mode 100644 index 000000000000..bf7b1e6f6e58 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/types.rs @@ -0,0 +1,176 @@ +use crate::{ + narinfo::Signature, + nixhash::CAHash, + store_path::StorePath, + wire::{ + de::{NixDeserialize, NixRead}, + ser::{NixSerialize, NixWrite}, + }, +}; +use nix_compat_derive::{NixDeserialize, NixSerialize}; +use std::future::Future; + +/// Marker type that consumes/sends and ignores a u64. +#[derive(Clone, Debug, NixDeserialize, NixSerialize)] +#[nix(from = "u64", into = "u64")] +pub struct IgnoredZero; +impl From<u64> for IgnoredZero { + fn from(_: u64) -> Self { + IgnoredZero + } +} + +impl From<IgnoredZero> for u64 { + fn from(_: IgnoredZero) -> Self { + 0 + } +} + +#[derive(Debug, NixSerialize)] +pub struct TraceLine { + have_pos: IgnoredZero, + hint: String, +} + +/// Represents an error returned by the nix-daemon to its client. +/// +/// Adheres to the format described in serialization.md +#[derive(NixSerialize)] +pub struct NixError { + #[nix(version = "26..")] + type_: &'static str, + + #[nix(version = "26..")] + level: u64, + + #[nix(version = "26..")] + name: &'static str, + + msg: String, + #[nix(version = "26..")] + have_pos: IgnoredZero, + + #[nix(version = "26..")] + traces: Vec<TraceLine>, + + #[nix(version = "..=25")] + exit_status: u64, +} + +impl NixError { + pub fn new(msg: String) -> Self { + Self { + type_: "Error", + level: 0, // error + name: "Error", + msg, + have_pos: IgnoredZero {}, + traces: vec![], + exit_status: 1, + } + } +} + +nix_compat_derive::nix_serialize_remote!(#[nix(display)] Signature<String>); + +impl NixSerialize for CAHash { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_value(&self.to_nix_nixbase32_string()).await + } +} + +impl NixSerialize for Option<CAHash> { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + match self { + Some(value) => writer.write_value(value).await, + None => writer.write_value("").await, + } + } +} + +impl NixSerialize for Option<UnkeyedValidPathInfo> { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + match self { + Some(value) => { + writer.write_value(&true).await?; + writer.write_value(value).await + } + None => writer.write_value(&false).await, + } + } +} + +// Custom implementation since FromStr does not use from_absolute_path +impl NixDeserialize for StorePath<String> { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + use crate::wire::de::Error; + if let Some(buf) = reader.try_read_bytes().await? { + let result = StorePath::<String>::from_absolute_path(&buf); + result.map(Some).map_err(R::Error::invalid_data) + } else { + Ok(None) + } + } +} + +// Custom implementation since Display does not use absolute paths. +impl<S> NixSerialize for StorePath<S> +where + S: AsRef<str>, +{ + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + let sp = self.to_absolute_path(); + async move { writer.write_value(&sp).await } + } +} + +// Writes StorePath or an empty string. +impl NixSerialize for Option<StorePath<String>> { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + match self { + Some(value) => writer.write_value(value).await, + None => writer.write_value("").await, + } + } +} + +#[derive(NixSerialize, Debug, Clone, Default, PartialEq)] +pub struct UnkeyedValidPathInfo { + pub deriver: Option<StorePath<String>>, + pub nar_hash: String, + pub references: Vec<StorePath<String>>, + pub registration_time: u64, + pub nar_size: u64, + pub ultimate: bool, + pub signatures: Vec<Signature<String>>, + pub ca: Option<CAHash>, +} + +/// Request tupe for [super::worker_protocol::Operation::QueryValidPaths] +#[derive(NixDeserialize)] +pub struct QueryValidPaths { + // Paths to query + pub paths: Vec<StorePath<String>>, + + // Whether to try and substitute the paths. + #[nix(version = "27..")] + pub substitute: bool, +} diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs index 7e3adc0db2ff..1ef9b9ab02d7 100644 --- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs +++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs @@ -1,19 +1,21 @@ use std::{ - collections::HashMap, + cmp::min, + collections::BTreeMap, io::{Error, ErrorKind}, }; -use enum_primitive_derive::Primitive; -use num_traits::{FromPrimitive, ToPrimitive}; +use nix_compat_derive::{NixDeserialize, NixSerialize}; +use num_enum::{FromPrimitive, IntoPrimitive, TryFromPrimitive}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::wire; -use super::ProtocolVersion; +use crate::wire::ProtocolVersion; -static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc" -static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio" +pub(crate) static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc" +pub(crate) static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio" pub static STDERR_LAST: u64 = 0x616c7473; // "alts" +pub(crate) static STDERR_ERROR: u64 = 0x63787470; // "cxtp" /// | Nix version | Protocol | /// |-----------------|----------| @@ -54,7 +56,11 @@ pub static MAX_SETTING_SIZE: usize = 1024; /// Note: for now, we're using the Nix 2.20 operation description. The /// operations marked as obsolete are obsolete for Nix 2.20, not /// necessarily for Nix 2.3. We'll revisit this later on. -#[derive(Debug, PartialEq, Primitive)] +#[derive( + Clone, Debug, PartialEq, TryFromPrimitive, IntoPrimitive, NixDeserialize, NixSerialize, +)] +#[nix(try_from = "u64", into = "u64")] +#[repr(u64)] pub enum Operation { IsValidPath = 1, HasSubstitutes = 3, @@ -105,8 +111,13 @@ pub enum Operation { /// Log verbosity. In the Nix wire protocol, the client requests a /// verbosity level to the daemon, which in turns does not produce any /// log below this verbosity. -#[derive(Debug, PartialEq, Primitive)] +#[derive( + Debug, PartialEq, FromPrimitive, IntoPrimitive, NixDeserialize, NixSerialize, Default, Clone, +)] +#[nix(from = "u64", into = "u64")] +#[repr(u64)] pub enum Verbosity { + #[default] LvlError = 0, LvlWarn = 1, LvlNotice = 2, @@ -119,7 +130,7 @@ pub enum Verbosity { /// Settings requested by the client. These settings are applied to a /// connection to between the daemon and a client. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, NixDeserialize, NixSerialize, Default)] pub struct ClientSettings { pub keep_failed: bool, pub keep_going: bool, @@ -127,70 +138,21 @@ pub struct ClientSettings { pub verbosity: Verbosity, pub max_build_jobs: u64, pub max_silent_time: u64, - pub verbose_build: bool, + pub use_build_hook: bool, + pub verbose_build: u64, + pub log_type: u64, + pub print_build_trace: u64, pub build_cores: u64, pub use_substitutes: bool, + /// Key/Value dictionary in charge of overriding the settings set /// by the Nix config file. /// /// Some settings can be safely overidden, /// some other require the user running the Nix client to be part /// of the trusted users group. - pub overrides: HashMap<String, String>, -} - -/// Reads the client settings from the wire. -/// -/// Note: this function **only** reads the settings. It does not -/// manage the log state with the daemon. You'll have to do that on -/// your own. A minimal log implementation will consist in sending -/// back [STDERR_LAST] to the client after reading the client -/// settings. -/// -/// FUTUREWORK: write serialization. -pub async fn read_client_settings<R: AsyncReadExt + Unpin>( - r: &mut R, - client_version: ProtocolVersion, -) -> std::io::Result<ClientSettings> { - let keep_failed = r.read_u64_le().await? != 0; - let keep_going = r.read_u64_le().await? != 0; - let try_fallback = r.read_u64_le().await? != 0; - let verbosity_uint = r.read_u64_le().await?; - let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| { - Error::new( - ErrorKind::InvalidData, - format!("Can't convert integer {} to verbosity", verbosity_uint), - ) - })?; - let max_build_jobs = r.read_u64_le().await?; - let max_silent_time = r.read_u64_le().await?; - _ = r.read_u64_le().await?; // obsolete useBuildHook - let verbose_build = r.read_u64_le().await? != 0; - _ = r.read_u64_le().await?; // obsolete logType - _ = r.read_u64_le().await?; // obsolete printBuildTrace - let build_cores = r.read_u64_le().await?; - let use_substitutes = r.read_u64_le().await? != 0; - let mut overrides = HashMap::new(); - if client_version.minor() >= 12 { - let num_overrides = r.read_u64_le().await?; - for _ in 0..num_overrides { - let name = wire::read_string(r, 0..=MAX_SETTING_SIZE).await?; - let value = wire::read_string(r, 0..=MAX_SETTING_SIZE).await?; - overrides.insert(name, value); - } - } - Ok(ClientSettings { - keep_failed, - keep_going, - try_fallback, - verbosity, - max_build_jobs, - max_silent_time, - verbose_build, - build_cores, - use_substitutes, - overrides, - }) + #[nix(version = "12..")] + pub overrides: BTreeMap<String, String>, } /// Performs the initial handshake the server is sending to a connecting client. @@ -209,7 +171,7 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>( /// /// # Return /// -/// The protocol version of the client. +/// The protocol version to use for further comms, min(client_version, our_version). pub async fn server_handshake_client<'a, RW: 'a>( mut conn: &'a mut RW, nix_version: &str, @@ -239,46 +201,46 @@ where format!("The nix client version {} is too old", client_version), )); } - if client_version.minor() >= 14 { + let picked_version = min(PROTOCOL_VERSION, client_version); + if picked_version.minor() >= 14 { // Obsolete CPU affinity. let read_affinity = conn.read_u64_le().await?; if read_affinity != 0 { let _cpu_affinity = conn.read_u64_le().await?; }; } - if client_version.minor() >= 11 { + if picked_version.minor() >= 11 { // Obsolete reserveSpace let _reserve_space = conn.read_u64_le().await?; } - if client_version.minor() >= 33 { + if picked_version.minor() >= 33 { // Nix version. We're plain lying, we're not Nix, but ehโฆ // Setting it to the 2.3 lineage. Not 100% sure this is a // good idea. wire::write_bytes(&mut conn, nix_version).await?; conn.flush().await?; } - if client_version.minor() >= 35 { + if picked_version.minor() >= 35 { write_worker_trust_level(&mut conn, trusted).await?; } - Ok(client_version) + Ok(picked_version) } } /// Read a worker [Operation] from the wire. pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> { let op_number = r.read_u64_le().await?; - Operation::from_u64(op_number).ok_or(Error::new( - ErrorKind::InvalidData, - format!("Invalid OP number {}", op_number), - )) + Operation::try_from(op_number).map_err(|_| { + Error::new( + ErrorKind::InvalidData, + format!("Invalid OP number {}", op_number), + ) + }) } /// Write a worker [Operation] to the wire. -pub async fn write_op<W: AsyncWriteExt + Unpin>(w: &mut W, op: &Operation) -> std::io::Result<()> { - let op = Operation::to_u64(op).ok_or(Error::new( - ErrorKind::Other, - format!("Can't convert the OP {:?} to u64", op), - ))?; +pub async fn write_op<W: AsyncWriteExt + Unpin>(w: &mut W, op: Operation) -> std::io::Result<()> { + let op: u64 = op.into(); w.write_u64(op).await } @@ -307,8 +269,6 @@ where #[cfg(test)] mod tests { use super::*; - use hex_literal::hex; - use tokio_test::io::Builder; #[tokio::test] async fn test_init_hanshake() { @@ -330,105 +290,63 @@ mod tests { // Trusted (1 == client trusted .write(&[1, 0, 0, 0, 0, 0, 0, 0]) .build(); - let client_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) + let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) .await .unwrap(); - assert_eq!(client_version, PROTOCOL_VERSION) + assert_eq!(picked_version, PROTOCOL_VERSION) } #[tokio::test] - async fn test_read_client_settings_without_overrides() { - // Client settings bits captured from a Nix 2.3.17 run w/ sockdump (protocol version 21). - let wire_bits = hex!( - "00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 02 00 00 00 00 00 00 00 \ - 10 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 01 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 01 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00" - ); - let mut mock = Builder::new().read(&wire_bits).build(); - let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21)) + async fn test_init_hanshake_with_newer_client_should_use_older_version() { + let mut test_conn = tokio_test::io::Builder::new() + .read(&WORKER_MAGIC_1.to_le_bytes()) + .write(&WORKER_MAGIC_2.to_le_bytes()) + .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // Client is newer than us. + .read(&[38, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // cpu affinity + .read(&[0; 8]) + // reservespace + .read(&[0; 8]) + // version (size) + .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // version (data == 2.18.2 + padding) + .write(&[50, 46, 49, 56, 46, 50, 0, 0]) + // Trusted (1 == client trusted + .write(&[1, 0, 0, 0, 0, 0, 0, 0]) + .build(); + let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) .await - .expect("should parse"); - let expected = ClientSettings { - keep_failed: false, - keep_going: false, - try_fallback: false, - verbosity: Verbosity::LvlNotice, - max_build_jobs: 16, - max_silent_time: 0, - verbose_build: false, - build_cores: 0, - use_substitutes: true, - overrides: HashMap::new(), - }; - assert_eq!(settings, expected); + .unwrap(); + + assert_eq!(picked_version, PROTOCOL_VERSION) } #[tokio::test] - async fn test_read_client_settings_with_overrides() { - // Client settings bits captured from a Nix 2.3.17 run w/ sockdump (protocol version 21). - let wire_bits = hex!( - "00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 02 00 00 00 00 00 00 00 \ - 10 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 01 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 00 00 00 00 00 00 00 00 \ - 01 00 00 00 00 00 00 00 \ - 02 00 00 00 00 00 00 00 \ - 0c 00 00 00 00 00 00 00 \ - 61 6c 6c 6f 77 65 64 2d \ - 75 72 69 73 00 00 00 00 \ - 1e 00 00 00 00 00 00 00 \ - 68 74 74 70 73 3a 2f 2f \ - 62 6f 72 64 65 61 75 78 \ - 2e 67 75 69 78 2e 67 6e \ - 75 2e 6f 72 67 2f 00 00 \ - 0d 00 00 00 00 00 00 00 \ - 61 6c 6c 6f 77 65 64 2d \ - 75 73 65 72 73 00 00 00 \ - 0b 00 00 00 00 00 00 00 \ - 6a 65 61 6e 20 70 69 65 \ - 72 72 65 00 00 00 00 00" - ); - let mut mock = Builder::new().read(&wire_bits).build(); - let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21)) + async fn test_init_hanshake_with_older_client_should_use_older_version() { + let mut test_conn = tokio_test::io::Builder::new() + .read(&WORKER_MAGIC_1.to_le_bytes()) + .write(&WORKER_MAGIC_2.to_le_bytes()) + .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // Client is newer than us. + .read(&[24, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // cpu affinity + .read(&[0; 8]) + // reservespace + .read(&[0; 8]) + // NOTE: we are not writing version and trust since the client is too old. + // version (size) + //.write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // version (data == 2.18.2 + padding) + //.write(&[50, 46, 49, 56, 46, 50, 0, 0]) + // Trusted (1 == client trusted + //.write(&[1, 0, 0, 0, 0, 0, 0, 0]) + .build(); + let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) .await - .expect("should parse"); - let overrides = HashMap::from([ - ( - String::from("allowed-uris"), - String::from("https://bordeaux.guix.gnu.org/"), - ), - (String::from("allowed-users"), String::from("jean pierre")), - ]); - let expected = ClientSettings { - keep_failed: false, - keep_going: false, - try_fallback: false, - verbosity: Verbosity::LvlNotice, - max_build_jobs: 16, - max_silent_time: 0, - verbose_build: false, - build_cores: 0, - use_substitutes: true, - overrides, - }; - assert_eq!(settings, expected); + .unwrap(); + + assert_eq!(picked_version, ProtocolVersion::from_parts(1, 24)) } } diff --git a/tvix/nix-compat/src/path_info.rs b/tvix/nix-compat/src/path_info.rs index f289ebde338c..63512805fe09 100644 --- a/tvix/nix-compat/src/path_info.rs +++ b/tvix/nix-compat/src/path_info.rs @@ -1,4 +1,4 @@ -use crate::{nixbase32, nixhash::NixHash, store_path::StorePathRef}; +use crate::{narinfo::SignatureRef, nixbase32, nixhash::NixHash, store_path::StorePathRef}; use serde::{Deserialize, Serialize}; use std::collections::BTreeSet; @@ -15,7 +15,7 @@ pub struct ExportedPathInfo<'a> { #[serde( rename = "narHash", serialize_with = "to_nix_nixbase32_string", - deserialize_with = "from_nix_nixbase32_string" + deserialize_with = "from_nix_hash_string" )] pub nar_sha256: [u8; 32], @@ -25,11 +25,17 @@ pub struct ExportedPathInfo<'a> { #[serde(borrow)] pub path: StorePathRef<'a>, + #[serde(borrow)] + #[serde(skip_serializing_if = "Option::is_none")] + pub deriver: Option<StorePathRef<'a>>, + /// The list of other Store Paths this Store Path refers to. /// StorePathRef does Ord by the nixbase32-encoded string repr, so this is correct. pub references: BTreeSet<StorePathRef<'a>>, // more recent versions of Nix also have a `valid: true` field here, Nix 2.3 doesn't, // and nothing seems to use it. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub signatures: Vec<SignatureRef<'a>>, } /// ExportedPathInfo are ordered by their `path` field. @@ -56,18 +62,49 @@ where /// The length of a sha256 digest, nixbase32-encoded. const NIXBASE32_SHA256_ENCODE_LEN: usize = nixbase32::encode_len(32); -fn from_nix_nixbase32_string<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error> +fn from_nix_hash_string<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error> where D: serde::Deserializer<'de>, { let str: &'de str = Deserialize::deserialize(deserializer)?; + if let Some(digest_str) = str.strip_prefix("sha256:") { + return from_nix_nixbase32_string::<D>(digest_str); + } + if let Some(digest_str) = str.strip_prefix("sha256-") { + return from_sri_string::<D>(digest_str); + } + Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(str), + &"extected a valid nixbase32 or sri narHash", + )) +} - let digest_str = str.strip_prefix("sha256:").ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Str(str), &"sha256:โฆ") - })?; +fn from_sri_string<'de, D>(str: &str) -> Result<[u8; 32], D::Error> +where + D: serde::Deserializer<'de>, +{ + let digest: [u8; 32] = data_encoding::BASE64 + .decode(str.as_bytes()) + .map_err(|_| { + serde::de::Error::invalid_value( + serde::de::Unexpected::Str(str), + &"valid base64 encoded string", + ) + })? + .try_into() + .map_err(|_| { + serde::de::Error::invalid_value(serde::de::Unexpected::Str(str), &"valid digest len") + })?; + Ok(digest) +} + +fn from_nix_nixbase32_string<'de, D>(str: &str) -> Result<[u8; 32], D::Error> +where + D: serde::Deserializer<'de>, +{ let digest_str: [u8; NIXBASE32_SHA256_ENCODE_LEN] = - digest_str.as_bytes().try_into().map_err(|_| { + str.as_bytes().try_into().map_err(|_| { serde::de::Error::invalid_value(serde::de::Unexpected::Str(str), &"valid digest len") })?; @@ -110,10 +147,49 @@ mod tests { b"7n0mbqydcipkpbxm24fab066lxk68aqk-libunistring-1.1" ) .expect("must parse"), + deriver: None, references: BTreeSet::from_iter([StorePathRef::from_bytes( b"7n0mbqydcipkpbxm24fab066lxk68aqk-libunistring-1.1" ) .unwrap()]), + signatures: vec![], + }, + deserialized.first().unwrap() + ); + } + + /// Ensure we can parse output from `nix path-info --json`` + #[test] + fn serialize_deserialize_from_path_info() { + // JSON extracted from + // nix path-info /nix/store/z6r3bn5l51679pwkvh9nalp6c317z34m-libcxx-16.0.6-dev --json --closure-size + let pathinfos_str_json = r#"[{"closureSize":10756176,"deriver":"/nix/store/vs9976cyyxpykvdnlv7x85fpp3shn6ij-libcxx-16.0.6.drv","narHash":"sha256-E73Nt0NAKGxCnsyBFDUaCAbA+wiF5qjq1O9J7WrnT0E=","narSize":7020664,"path":"/nix/store/z6r3bn5l51679pwkvh9nalp6c317z34m-libcxx-16.0.6-dev","references":["/nix/store/lzzd5jgybnpfj86xkcpnd54xgwc4m457-libcxx-16.0.6"],"registrationTime":1730048276,"signatures":["cache.nixos.org-1:cTdhK6hnpPwtMXFX43CYb7v+CbpAusVI/MORZ3v5aHvpBYNg1MfBHVVeoexMBpNtHA8uFAn0aEsJaLXYIDhJDg=="],"valid":true}]"#; + + let deserialized: BTreeSet<ExportedPathInfo> = + serde_json::from_str(pathinfos_str_json).expect("must serialize"); + + assert_eq!( + &ExportedPathInfo { + closure_size: 10756176, + nar_sha256: hex!( + "13bdcdb74340286c429ecc8114351a0806c0fb0885e6a8ead4ef49ed6ae74f41" + ), + nar_size: 7020664, + path: StorePathRef::from_bytes( + b"z6r3bn5l51679pwkvh9nalp6c317z34m-libcxx-16.0.6-dev" + ) + .expect("must parse"), + deriver: Some( + StorePathRef::from_bytes( + b"vs9976cyyxpykvdnlv7x85fpp3shn6ij-libcxx-16.0.6.drv" + ) + .expect("must parse") + ), + references: BTreeSet::from_iter([StorePathRef::from_bytes( + b"lzzd5jgybnpfj86xkcpnd54xgwc4m457-libcxx-16.0.6" + ) + .unwrap()]), + signatures: vec![SignatureRef::parse("cache.nixos.org-1:cTdhK6hnpPwtMXFX43CYb7v+CbpAusVI/MORZ3v5aHvpBYNg1MfBHVVeoexMBpNtHA8uFAn0aEsJaLXYIDhJDg==").expect("must parse")], }, deserialized.first().unwrap() ); diff --git a/tvix/nix-compat/src/store_path/mod.rs b/tvix/nix-compat/src/store_path/mod.rs index 177cc96ce20f..13265048641b 100644 --- a/tvix/nix-compat/src/store_path/mod.rs +++ b/tvix/nix-compat/src/store_path/mod.rs @@ -2,16 +2,12 @@ use crate::nixbase32; use data_encoding::{DecodeError, BASE64}; use serde::{Deserialize, Serialize}; use std::{ - fmt::{self, Display}, - ops::Deref, + fmt, path::Path, str::{self, FromStr}, }; use thiserror; -#[cfg(target_family = "unix")] -use std::os::unix::ffi::OsStrExt; - mod utils; pub use utils::*; @@ -54,21 +50,40 @@ pub enum Error { /// /// A [StorePath] does not encode any additional subpath "inside" the store /// path. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct StorePath<S> -where - S: std::cmp::Eq + std::cmp::PartialEq, -{ +#[derive(Clone, Debug)] +pub struct StorePath<S> { digest: [u8; DIGEST_SIZE], name: S, } + +impl<S> PartialEq for StorePath<S> +where + S: AsRef<str>, +{ + fn eq(&self, other: &Self) -> bool { + self.digest() == other.digest() && self.name().as_ref() == other.name().as_ref() + } +} + +impl<S> Eq for StorePath<S> where S: AsRef<str> {} + +impl<S> std::hash::Hash for StorePath<S> +where + S: AsRef<str>, +{ + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + state.write(&self.digest); + state.write(self.name.as_ref().as_bytes()); + } +} + /// Like [StorePath], but without a heap allocation for the name. /// Used by [StorePath] for parsing. pub type StorePathRef<'a> = StorePath<&'a str>; impl<S> StorePath<S> where - S: std::cmp::Eq + Deref<Target = str>, + S: AsRef<str>, { pub fn digest(&self) -> &[u8; DIGEST_SIZE] { &self.digest @@ -81,14 +96,14 @@ where pub fn as_ref(&self) -> StorePathRef<'_> { StorePathRef { digest: self.digest, - name: &self.name, + name: self.name.as_ref(), } } pub fn to_owned(&self) -> StorePath<String> { StorePath { digest: self.digest, - name: self.name.to_string(), + name: self.name.as_ref().to_string(), } } @@ -139,7 +154,7 @@ where S: From<&'a str>, { Ok(Self { - name: validate_name(name.as_bytes())?.into(), + name: validate_name(name)?.into(), digest, }) } @@ -160,47 +175,40 @@ where /// Decompose a string into a [StorePath] and a [PathBuf] containing the /// rest of the path, or an error. #[cfg(target_family = "unix")] - pub fn from_absolute_path_full<'a>(s: &'a str) -> Result<(Self, &'a Path), Error> + pub fn from_absolute_path_full<'a, P>(path: &'a P) -> Result<(Self, &'a Path), Error> where S: From<&'a str>, + P: AsRef<std::path::Path> + ?Sized, { // strip [STORE_DIR_WITH_SLASH] from s + let p = path + .as_ref() + .strip_prefix(STORE_DIR_WITH_SLASH) + .map_err(|_e| Error::MissingStoreDir)?; - match s.strip_prefix(STORE_DIR_WITH_SLASH) { - None => Err(Error::MissingStoreDir), - Some(rest) => { - let mut it = Path::new(rest).components(); + let mut it = Path::new(p).components(); - // The first component of the rest must be parse-able as a [StorePath] - if let Some(first_component) = it.next() { - // convert first component to StorePath - let store_path = StorePath::from_bytes(first_component.as_os_str().as_bytes())?; + // The first component of the rest must be parse-able as a [StorePath] + let first_component = it.next().ok_or(Error::InvalidLength)?; + let store_path = StorePath::from_bytes(first_component.as_os_str().as_encoded_bytes())?; - // collect rest - let rest_buf = it.as_path(); + // collect rest + let rest_buf = it.as_path(); - Ok((store_path, rest_buf)) - } else { - Err(Error::InvalidLength) // Well, or missing "/"? - } - } - } + Ok((store_path, rest_buf)) } /// Returns an absolute store path string. /// That is just the string representation, prefixed with the store prefix /// ([STORE_DIR_WITH_SLASH]), - pub fn to_absolute_path(&self) -> String - where - S: Display, - { + pub fn to_absolute_path(&self) -> String { format!("{}{}", STORE_DIR_WITH_SLASH, self) } } impl<S> PartialOrd for StorePath<S> where - S: std::cmp::PartialEq + std::cmp::Eq, + S: AsRef<str>, { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(other)) @@ -211,7 +219,7 @@ where /// of the nixbase32-encoded string. impl<S> Ord for StorePath<S> where - S: std::cmp::PartialEq + std::cmp::Eq, + S: AsRef<str>, { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.digest.iter().rev().cmp(other.digest.iter().rev()) @@ -230,7 +238,7 @@ impl<'a, 'b: 'a> FromStr for StorePath<String> { impl<'a, 'de: 'a, S> Deserialize<'de> for StorePath<S> where - S: std::cmp::Eq + Deref<Target = str> + From<&'a str>, + S: AsRef<str> + From<&'a str>, { fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where @@ -252,7 +260,7 @@ where impl<S> Serialize for StorePath<S> where - S: std::cmp::Eq + Deref<Target = str> + Display, + S: AsRef<str>, { fn serialize<SR>(&self, serializer: SR) -> Result<SR::Ok, SR::Error> where @@ -312,13 +320,18 @@ pub(crate) fn validate_name(s: &(impl AsRef<[u8]> + ?Sized)) -> Result<&str, Err impl<S> fmt::Display for StorePath<S> where - S: fmt::Display + std::cmp::Eq, + S: AsRef<str>, { /// The string representation of a store path starts with a digest (20 /// bytes), [crate::nixbase32]-encoded, followed by a `-`, /// and ends with the name. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}-{}", nixbase32::encode(&self.digest), self.name) + write!( + f, + "{}-{}", + nixbase32::encode(&self.digest), + self.name.as_ref() + ) } } diff --git a/tvix/nix-compat/src/store_path/utils.rs b/tvix/nix-compat/src/store_path/utils.rs index 4bfbb72fcdde..63b6969464d8 100644 --- a/tvix/nix-compat/src/store_path/utils.rs +++ b/tvix/nix-compat/src/store_path/utils.rs @@ -1,6 +1,6 @@ use crate::nixbase32; use crate::nixhash::{CAHash, NixHash}; -use crate::store_path::{Error, StorePath, StorePathRef, STORE_DIR}; +use crate::store_path::{Error, StorePath, STORE_DIR}; use data_encoding::HEXLOWER; use sha2::{Digest, Sha256}; use thiserror; @@ -50,7 +50,7 @@ pub fn build_text_path<'a, S, SP, I, C>( ) -> Result<StorePath<SP>, BuildStorePathError> where S: AsRef<str>, - SP: std::cmp::Eq + std::ops::Deref<Target = str> + std::convert::From<&'a str>, + SP: AsRef<str> + std::convert::From<&'a str>, I: IntoIterator<Item = S>, C: AsRef<[u8]>, { @@ -69,7 +69,7 @@ pub fn build_ca_path<'a, S, SP, I>( ) -> Result<StorePath<SP>, BuildStorePathError> where S: AsRef<str>, - SP: std::cmp::Eq + std::ops::Deref<Target = str> + std::convert::From<&'a str>, + SP: AsRef<str> + std::convert::From<&'a str>, I: IntoIterator<Item = S>, { // self references are only allowed for CAHash::Nar(NixHash::Sha256(_)). @@ -119,26 +119,18 @@ where .map_err(BuildStorePathError::InvalidStorePath) } -/// For given NAR sha256 digest and name, return the new [StorePathRef] this -/// would have, or an error, in case the name is invalid. -pub fn build_nar_based_store_path<'a>( - nar_sha256_digest: &[u8; 32], - name: &'a str, -) -> Result<StorePathRef<'a>, BuildStorePathError> { - let nar_hash_with_mode = CAHash::Nar(NixHash::Sha256(nar_sha256_digest.to_owned())); - - build_ca_path(name, &nar_hash_with_mode, Vec::<String>::new(), false) -} - /// This builds an input-addressed store path. /// /// Input-addresed store paths are always derivation outputs, the "input" in question is the /// derivation and its closure. -pub fn build_output_path<'a>( +pub fn build_output_path<'a, SP>( drv_sha256: &[u8; 32], output_name: &str, output_path_name: &'a str, -) -> Result<StorePathRef<'a>, Error> { +) -> Result<StorePath<SP>, Error> +where + SP: AsRef<str> + std::convert::From<&'a str>, +{ build_store_path_from_fingerprint_parts( &(String::from("output:") + output_name), drv_sha256, @@ -156,13 +148,13 @@ pub fn build_output_path<'a>( /// bytes. /// Inside a StorePath, that digest is printed nixbase32-encoded /// (32 characters). -fn build_store_path_from_fingerprint_parts<'a, S>( +fn build_store_path_from_fingerprint_parts<'a, SP>( ty: &str, inner_digest: &[u8; 32], name: &'a str, -) -> Result<StorePath<S>, Error> +) -> Result<StorePath<SP>, Error> where - S: std::cmp::Eq + std::ops::Deref<Target = str> + std::convert::From<&'a str>, + SP: AsRef<str> + std::convert::From<&'a str>, { let fingerprint = format!( "{ty}:sha256:{}:{STORE_DIR}:{name}", @@ -221,7 +213,10 @@ mod test { use hex_literal::hex; use super::*; - use crate::nixhash::{CAHash, NixHash}; + use crate::{ + nixhash::{CAHash, NixHash}, + store_path::StorePathRef, + }; #[test] fn build_text_path_with_zero_references() { diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index 74adfb49b6a4..9b981fbbd2c0 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -181,7 +181,7 @@ pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>( /// Computes the number of bytes we should add to len (a length in /// bytes) to be aligned on 64 bits (8 bytes). -fn padding_len(len: u64) -> u8 { +pub(crate) fn padding_len(len: u64) -> u8 { let aligned = len.wrapping_add(7) & !7; aligned.wrapping_sub(len) as u8 } diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index 77950496ed6b..a6209a6e6dad 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -299,11 +299,11 @@ fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) #[cfg(test)] mod tests { + use std::sync::LazyLock; use std::time::Duration; use crate::wire::bytes::{padding_len, write_bytes}; use hex_literal::hex; - use lazy_static::lazy_static; use rstest::rstest; use tokio::io::{AsyncReadExt, BufReader}; use tokio_test::io::Builder; @@ -314,9 +314,8 @@ mod tests { /// cases. const MAX_LEN: u64 = 1024; - lazy_static! { - pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024); - } + pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> = + LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024)); /// Helper function, calling the (simpler) write_bytes with the payload. /// We use this to create data we want to read from the wire. diff --git a/tvix/nix-compat/src/wire/bytes/writer.rs b/tvix/nix-compat/src/wire/bytes/writer.rs index f5632771e961..8b9b59aa1b85 100644 --- a/tvix/nix-compat/src/wire/bytes/writer.rs +++ b/tvix/nix-compat/src/wire/bytes/writer.rs @@ -232,19 +232,18 @@ where #[cfg(test)] mod tests { + use std::sync::LazyLock; use std::time::Duration; use crate::wire::bytes::write_bytes; use hex_literal::hex; - use lazy_static::lazy_static; use tokio::io::AsyncWriteExt; use tokio_test::{assert_err, assert_ok, io::Builder}; use super::*; - lazy_static! { - pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024); - } + pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> = + LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024)); /// Helper function, calling the (simpler) write_bytes with the payload. /// We use this to create data we want to see on the wire. diff --git a/tvix/nix-compat/src/nix_daemon/de/bytes.rs b/tvix/nix-compat/src/wire/de/bytes.rs index 7daced54eef7..4c64247f7051 100644 --- a/tvix/nix-compat/src/nix_daemon/de/bytes.rs +++ b/tvix/nix-compat/src/wire/de/bytes.rs @@ -34,7 +34,7 @@ mod test { use rstest::rstest; use tokio_test::io::Builder; - use crate::nix_daemon::de::{NixRead, NixReader}; + use crate::wire::de::{NixRead, NixReader}; #[rstest] #[case::empty("", &hex!("0000 0000 0000 0000"))] diff --git a/tvix/nix-compat/src/nix_daemon/de/collections.rs b/tvix/nix-compat/src/wire/de/collections.rs index cf79f584506a..e1271635e4e6 100644 --- a/tvix/nix-compat/src/nix_daemon/de/collections.rs +++ b/tvix/nix-compat/src/wire/de/collections.rs @@ -64,7 +64,7 @@ mod test { use rstest::rstest; use tokio_test::io::Builder; - use crate::nix_daemon::de::{NixDeserialize, NixRead, NixReader}; + use crate::wire::de::{NixDeserialize, NixRead, NixReader}; #[rstest] #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] diff --git a/tvix/nix-compat/src/nix_daemon/de/int.rs b/tvix/nix-compat/src/wire/de/int.rs index eecf641cfe99..d505de9b1b24 100644 --- a/tvix/nix-compat/src/nix_daemon/de/int.rs +++ b/tvix/nix-compat/src/wire/de/int.rs @@ -45,7 +45,7 @@ mod test { use rstest::rstest; use tokio_test::io::Builder; - use crate::nix_daemon::de::{NixRead, NixReader}; + use crate::wire::de::{NixRead, NixReader}; #[rstest] #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] diff --git a/tvix/nix-compat/src/nix_daemon/de/mock.rs b/tvix/nix-compat/src/wire/de/mock.rs index 31cc3a4897ba..8a1fb817743c 100644 --- a/tvix/nix-compat/src/nix_daemon/de/mock.rs +++ b/tvix/nix-compat/src/wire/de/mock.rs @@ -6,7 +6,7 @@ use std::thread; use bytes::Bytes; use thiserror::Error; -use crate::nix_daemon::ProtocolVersion; +use crate::wire::ProtocolVersion; use super::NixRead; @@ -189,7 +189,7 @@ mod test { use bytes::Bytes; use hex_literal::hex; - use crate::nix_daemon::de::NixRead; + use crate::wire::de::NixRead; use super::{Builder, Error}; diff --git a/tvix/nix-compat/src/nix_daemon/de/mod.rs b/tvix/nix-compat/src/wire/de/mod.rs index f85ccd8fea0e..f85ccd8fea0e 100644 --- a/tvix/nix-compat/src/nix_daemon/de/mod.rs +++ b/tvix/nix-compat/src/wire/de/mod.rs diff --git a/tvix/nix-compat/src/nix_daemon/de/reader.rs b/tvix/nix-compat/src/wire/de/reader.rs index 87c623b2220c..b7825f393c4e 100644 --- a/tvix/nix-compat/src/nix_daemon/de/reader.rs +++ b/tvix/nix-compat/src/wire/de/reader.rs @@ -8,8 +8,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use pin_project_lite::pin_project; use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf}; -use crate::nix_daemon::ProtocolVersion; -use crate::wire::EMPTY_BYTES; +use crate::wire::{ProtocolVersion, EMPTY_BYTES}; use super::{Error, NixRead}; @@ -270,7 +269,7 @@ mod test { use tokio_test::io::Builder; use super::*; - use crate::nix_daemon::de::NixRead; + use crate::wire::de::NixRead; #[tokio::test] async fn test_read_u64() { diff --git a/tvix/nix-compat/src/wire/mod.rs b/tvix/nix-compat/src/wire/mod.rs index a197e3a1f451..c3e88dda05ec 100644 --- a/tvix/nix-compat/src/wire/mod.rs +++ b/tvix/nix-compat/src/wire/mod.rs @@ -3,3 +3,9 @@ mod bytes; pub use bytes::*; + +mod protocol_version; +pub use protocol_version::ProtocolVersion; + +pub mod de; +pub mod ser; diff --git a/tvix/nix-compat/src/nix_daemon/protocol_version.rs b/tvix/nix-compat/src/wire/protocol_version.rs index 19da28d484dd..19da28d484dd 100644 --- a/tvix/nix-compat/src/nix_daemon/protocol_version.rs +++ b/tvix/nix-compat/src/wire/protocol_version.rs diff --git a/tvix/nix-compat/src/wire/ser/bytes.rs b/tvix/nix-compat/src/wire/ser/bytes.rs new file mode 100644 index 000000000000..737edb059b5b --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/bytes.rs @@ -0,0 +1,98 @@ +use bytes::Bytes; + +use super::{NixSerialize, NixWrite}; + +impl NixSerialize for Bytes { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl<'a> NixSerialize for &'a [u8] { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl NixSerialize for String { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +impl NixSerialize for str { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +impl NixSerialize for &str { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown ๐ฆ jumps over 13 lazy ๐ถ.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_str(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown ๐ฆ jumps over 13 lazy ๐ถ.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_string(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value.to_string()).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/collections.rs b/tvix/nix-compat/src/wire/ser/collections.rs new file mode 100644 index 000000000000..478e1d04d809 --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/collections.rs @@ -0,0 +1,94 @@ +use std::collections::BTreeMap; +use std::future::Future; + +use super::{NixSerialize, NixWrite}; + +impl<T> NixSerialize for Vec<T> +where + T: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for value in self.iter() { + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +impl<K, V> NixSerialize for BTreeMap<K, V> +where + K: NixSerialize + Ord + Send + Sync, + V: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for (key, value) in self.iter() { + writer.write_value(key).await?; + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixSerialize, NixWrite, NixWriter}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_vec(#[case] value: Vec<usize>, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + fn empty_map() -> BTreeMap<usize, u64> { + BTreeMap::new() + } + macro_rules! map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_btree_map<E>(#[case] value: E, #[case] data: &[u8]) + where + E: NixSerialize + Send + PartialEq + fmt::Debug, + { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/display.rs b/tvix/nix-compat/src/wire/ser/display.rs new file mode 100644 index 000000000000..a3438d50d8ff --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/display.rs @@ -0,0 +1,8 @@ +use nix_compat_derive::nix_serialize_remote; + +use crate::nixhash; + +nix_serialize_remote!( + #[nix(display)] + nixhash::HashAlgo +); diff --git a/tvix/nix-compat/src/wire/ser/int.rs b/tvix/nix-compat/src/wire/ser/int.rs new file mode 100644 index 000000000000..e68179c71dc7 --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/int.rs @@ -0,0 +1,108 @@ +#[cfg(feature = "nix-compat-derive")] +use nix_compat_derive::nix_serialize_remote; + +use super::{Error, NixSerialize, NixWrite}; + +impl NixSerialize for u64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self).await + } +} + +impl NixSerialize for usize { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + let v = (*self).try_into().map_err(W::Error::unsupported_data)?; + writer.write_number(v).await + } +} + +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u8 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u16 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u32 +); + +impl NixSerialize for bool { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + if *self { + writer.write_number(1).await + } else { + writer.write_number(0).await + } + } +} + +impl NixSerialize for i64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self as u64).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[tokio::test] + async fn test_write_bool(#[case] value: bool, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_u64(#[case] value: u64, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] + #[tokio::test] + async fn test_write_usize(#[case] value: usize, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/mock.rs b/tvix/nix-compat/src/wire/ser/mock.rs new file mode 100644 index 000000000000..7104a94238ff --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/mock.rs @@ -0,0 +1,672 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +#[cfg(test)] +use ::proptest::prelude::TestCaseError; +use thiserror::Error; + +use crate::wire::ProtocolVersion; + +use super::NixWrite; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("unsupported data error '{0}'")] + UnsupportedData(String), + #[error("Invalid enum: {0}")] + InvalidEnum(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong write: expected {0} got {1}")] + WrongWrite(OperationType, OperationType), + #[error("unexpected write: got an extra {0}")] + ExtraWrite(OperationType), + #[error("got an unexpected number {0} in write_number")] + UnexpectedNumber(u64), + #[error("got an unexpected slice '{0:?}' in write_slice")] + UnexpectedSlice(Vec<u8>), + #[error("got an unexpected display '{0:?}' in write_slice")] + UnexpectedDisplay(String), +} + +impl Error { + pub fn unexpected_write_number(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteNumber) + } + + pub fn extra_write_number() -> Error { + Error::ExtraWrite(OperationType::WriteNumber) + } + + pub fn unexpected_write_slice(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteSlice) + } + + pub fn unexpected_write_display(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteDisplay) + } +} + +impl super::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Self::Custom(msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + Self::IO(err.kind(), err.to_string()) + } + + fn unsupported_data<T: fmt::Display>(msg: T) -> Self { + Self::UnsupportedData(msg.to_string()) + } + + fn invalid_enum<T: fmt::Display>(msg: T) -> Self { + Self::InvalidEnum(msg.to_string()) + } +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OperationType { + WriteNumber, + WriteSlice, + WriteDisplay, +} + +impl fmt::Display for OperationType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::WriteNumber => write!(f, "write_number"), + Self::WriteSlice => write!(f, "write_slice"), + Self::WriteDisplay => write!(f, "write_display"), + } + } +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, PartialEq, Eq)] +enum Operation { + WriteNumber(u64, Result<(), Error>), + WriteSlice(Vec<u8>, Result<(), Error>), + WriteDisplay(String, Result<(), Error>), +} + +impl From<Operation> for OperationType { + fn from(value: Operation) -> Self { + match value { + Operation::WriteNumber(_, _) => OperationType::WriteNumber, + Operation::WriteSlice(_, _) => OperationType::WriteSlice, + Operation::WriteDisplay(_, _) => OperationType::WriteDisplay, + } + } +} + +pub struct Builder { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + version: Default::default(), + ops: VecDeque::new(), + } + } + + pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { + self.version = version.into(); + self + } + + pub fn write_number(&mut self, value: u64) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Ok(()))); + self + } + + pub fn write_number_error(&mut self, value: u64, err: Error) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Err(err))); + self + } + + pub fn write_slice(&mut self, value: &[u8]) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Ok(()))); + self + } + + pub fn write_slice_error(&mut self, value: &[u8], err: Error) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Err(err))); + self + } + + pub fn write_display<D>(&mut self, value: D) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Ok(()))); + self + } + + pub fn write_display_error<D>(&mut self, value: D, err: Error) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Err(err))); + self + } + + #[cfg(test)] + fn write_operation_type(&mut self, op: OperationType) -> &mut Self { + match op { + OperationType::WriteNumber => self.write_number(10), + OperationType::WriteSlice => self.write_slice(b"testing"), + OperationType::WriteDisplay => self.write_display("testing"), + } + } + + #[cfg(test)] + fn write_operation(&mut self, op: &Operation) -> &mut Self { + match op { + Operation::WriteNumber(value, Ok(_)) => self.write_number(*value), + Operation::WriteNumber(value, Err(Error::UnexpectedNumber(_))) => { + self.write_number(*value) + } + Operation::WriteNumber(_, Err(Error::ExtraWrite(OperationType::WriteNumber))) => self, + Operation::WriteNumber(_, Err(Error::WrongWrite(op, OperationType::WriteNumber))) => { + self.write_operation_type(*op) + } + Operation::WriteNumber(value, Err(Error::Custom(msg))) => { + self.write_number_error(*value, Error::Custom(msg.clone())) + } + Operation::WriteNumber(value, Err(Error::IO(kind, msg))) => { + self.write_number_error(*value, Error::IO(*kind, msg.clone())) + } + Operation::WriteSlice(value, Ok(_)) => self.write_slice(value), + Operation::WriteSlice(value, Err(Error::UnexpectedSlice(_))) => self.write_slice(value), + Operation::WriteSlice(_, Err(Error::ExtraWrite(OperationType::WriteSlice))) => self, + Operation::WriteSlice(_, Err(Error::WrongWrite(op, OperationType::WriteSlice))) => { + self.write_operation_type(*op) + } + Operation::WriteSlice(value, Err(Error::Custom(msg))) => { + self.write_slice_error(value, Error::Custom(msg.clone())) + } + Operation::WriteSlice(value, Err(Error::IO(kind, msg))) => { + self.write_slice_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Ok(_)) => self.write_display(value), + Operation::WriteDisplay(value, Err(Error::Custom(msg))) => { + self.write_display_error(value, Error::Custom(msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::IO(kind, msg))) => { + self.write_display_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::UnexpectedDisplay(_))) => { + self.write_display(value) + } + Operation::WriteDisplay(_, Err(Error::ExtraWrite(OperationType::WriteDisplay))) => self, + Operation::WriteDisplay(_, Err(Error::WrongWrite(op, OperationType::WriteDisplay))) => { + self.write_operation_type(*op) + } + s => panic!("Invalid operation {:?}", s), + } + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct Mock { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Mock { + #[cfg(test)] + #[allow(dead_code)] + async fn assert_operation(&mut self, op: Operation) { + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + assert_eq!(self.write_display(value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + assert_eq!(self.write_display(value).await, res); + } + } + } + + #[cfg(test)] + async fn prop_assert_operation(&mut self, op: Operation) -> Result<(), TestCaseError> { + use ::proptest::prop_assert_eq; + + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + prop_assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + prop_assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + prop_assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + prop_assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + prop_assert_eq!(self.write_display(&value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + prop_assert_eq!(self.write_display(&value).await, res); + } + } + Ok(()) + } +} + +impl NixWrite for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteNumber(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedNumber(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_number(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteNumber)), + } + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteSlice(expected, ret)) => { + if buf != expected { + return Err(Error::UnexpectedSlice(buf.to_vec())); + } + ret + } + Some(op) => Err(Error::unexpected_write_slice(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteSlice)), + } + } + + async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + let value = msg.to_string(); + match self.ops.pop_front() { + Some(Operation::WriteDisplay(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedDisplay(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_display(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteDisplay)), + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + // No need to panic again + if thread::panicking() { + return; + } + if let Some(op) = self.ops.front() { + panic!("reader dropped with {op:?} operation still unread") + } + } +} + +#[cfg(test)] +mod proptest { + use std::io; + + use proptest::{ + prelude::{any, Arbitrary, BoxedStrategy, Just, Strategy}, + prop_oneof, + }; + + use super::{Error, Operation, OperationType}; + + pub fn arb_write_number_operation() -> impl Strategy<Value = Operation> { + ( + any::<u64>(), + prop_oneof![ + Just(Ok(())), + any::<u64>().prop_map(|v| Err(Error::UnexpectedNumber(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteNumber + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteNumber + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same number", |(v, res)| match res { + Err(Error::UnexpectedNumber(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteNumber(v, res)) + } + + pub fn arb_write_slice_operation() -> impl Strategy<Value = Operation> { + ( + any::<Vec<u8>>(), + prop_oneof![ + Just(Ok(())), + any::<Vec<u8>>().prop_map(|v| Err(Error::UnexpectedSlice(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteSlice + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteSlice + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same slice", |(v, res)| match res { + Err(Error::UnexpectedSlice(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteSlice(v, res)) + } + + #[allow(dead_code)] + pub fn arb_extra_write() -> impl Strategy<Value = Operation> { + prop_oneof![ + any::<u64>().prop_map(|msg| { + Operation::WriteNumber(msg, Err(Error::ExtraWrite(OperationType::WriteNumber))) + }), + any::<Vec<u8>>().prop_map(|msg| { + Operation::WriteSlice(msg, Err(Error::ExtraWrite(OperationType::WriteSlice))) + }), + any::<String>().prop_map(|msg| { + Operation::WriteDisplay(msg, Err(Error::ExtraWrite(OperationType::WriteDisplay))) + }), + ] + } + + pub fn arb_write_display_operation() -> impl Strategy<Value = Operation> { + ( + any::<String>(), + prop_oneof![ + Just(Ok(())), + any::<String>().prop_map(|v| Err(Error::UnexpectedDisplay(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteDisplay + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteDisplay + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same string", |(v, res)| match res { + Err(Error::UnexpectedDisplay(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteDisplay(v, res)) + } + + pub fn arb_operation() -> impl Strategy<Value = Operation> { + prop_oneof![ + arb_write_number_operation(), + arb_write_slice_operation(), + arb_write_display_operation(), + ] + } + + impl Arbitrary for Operation { + type Parameters = (); + type Strategy = BoxedStrategy<Operation>; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + arb_operation().boxed() + } + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use proptest::prelude::any; + use proptest::prelude::TestCaseError; + use proptest::proptest; + + use crate::wire::ser::mock::proptest::arb_extra_write; + use crate::wire::ser::mock::Operation; + use crate::wire::ser::mock::OperationType; + use crate::wire::ser::Error as _; + use crate::wire::ser::NixWrite; + + use super::{Builder, Error}; + + #[tokio::test] + async fn write_number() { + let mut mock = Builder::new().write_number(10).build(); + mock.write_number(10).await.unwrap(); + } + + #[tokio::test] + async fn write_number_error() { + let mut mock = Builder::new() + .write_number_error(10, Error::custom("bad number")) + .build(); + assert_eq!( + Err(Error::custom("bad number")), + mock.write_number(10).await + ); + } + + #[tokio::test] + async fn write_number_unexpected() { + let mut mock = Builder::new().write_slice(b"").build(); + assert_eq!( + Err(Error::unexpected_write_number(OperationType::WriteSlice)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn write_number_unexpected_number() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::UnexpectedNumber(11)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn extra_write_number() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteNumber)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn write_slice() { + let mut mock = Builder::new() + .write_slice(&[]) + .write_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .build(); + mock.write_slice(&[]).await.expect("write_slice empty"); + mock.write_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .await + .expect("write_slice"); + } + + #[tokio::test] + async fn write_slice_error() { + let mut mock = Builder::new() + .write_slice_error(&[], Error::custom("bad slice")) + .build(); + assert_eq!(Err(Error::custom("bad slice")), mock.write_slice(&[]).await); + } + + #[tokio::test] + async fn write_slice_unexpected() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::unexpected_write_slice(OperationType::WriteNumber)), + mock.write_slice(b"").await + ); + } + + #[tokio::test] + async fn write_slice_unexpected_slice() { + let mut mock = Builder::new().write_slice(b"").build(); + assert_eq!( + Err(Error::UnexpectedSlice(b"bad slice".to_vec())), + mock.write_slice(b"bad slice").await + ); + } + + #[tokio::test] + async fn extra_write_slice() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteSlice)), + mock.write_slice(b"extra slice").await + ); + } + + #[tokio::test] + async fn write_display() { + let mut mock = Builder::new().write_display("testing").build(); + mock.write_display("testing").await.unwrap(); + } + + #[tokio::test] + async fn write_display_error() { + let mut mock = Builder::new() + .write_display_error("testing", Error::custom("bad number")) + .build(); + assert_eq!( + Err(Error::custom("bad number")), + mock.write_display("testing").await + ); + } + + #[tokio::test] + async fn write_display_unexpected() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::unexpected_write_display(OperationType::WriteNumber)), + mock.write_display("").await + ); + } + + #[tokio::test] + async fn write_display_unexpected_display() { + let mut mock = Builder::new().write_display("").build(); + assert_eq!( + Err(Error::UnexpectedDisplay("bad display".to_string())), + mock.write_display("bad display").await + ); + } + + #[tokio::test] + async fn extra_write_display() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteDisplay)), + mock.write_display("extra slice").await + ); + } + + #[test] + #[should_panic] + fn operations_left() { + let _ = Builder::new().write_number(10).build(); + } + + #[test] + fn proptest_mock() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + proptest!(|( + operations in any::<Vec<Operation>>(), + extra_operations in proptest::collection::vec(arb_extra_write(), 0..3) + )| { + rt.block_on(async { + let mut builder = Builder::new(); + for op in operations.iter() { + builder.write_operation(op); + } + for op in extra_operations.iter() { + builder.write_operation(op); + } + let mut mock = builder.build(); + for op in operations { + mock.prop_assert_operation(op).await?; + } + for op in extra_operations { + mock.prop_assert_operation(op).await?; + } + Ok(()) as Result<(), TestCaseError> + })?; + }); + } +} diff --git a/tvix/nix-compat/src/wire/ser/mod.rs b/tvix/nix-compat/src/wire/ser/mod.rs new file mode 100644 index 000000000000..ef3c6e2e372f --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/mod.rs @@ -0,0 +1,134 @@ +use std::error::Error as StdError; +use std::future::Future; +use std::{fmt, io}; + +use super::ProtocolVersion; + +mod bytes; +mod collections; +#[cfg(feature = "nix-compat-derive")] +mod display; +mod int; +#[cfg(any(test, feature = "test"))] +pub mod mock; +mod writer; + +pub use writer::{NixWriter, NixWriterBuilder}; + +pub trait Error: Sized + StdError { + fn custom<T: fmt::Display>(msg: T) -> Self; + + fn io_error(err: std::io::Error) -> Self { + Self::custom(format_args!("There was an I/O error {}", err)) + } + + fn unsupported_data<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } + + fn invalid_enum<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } +} + +impl Error for io::Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::Other, msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + err + } + + fn unsupported_data<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::InvalidData, msg.to_string()) + } +} + +pub trait NixWrite: Send { + type Error: Error; + + /// Some types are serialized differently depending on the version + /// of the protocol and so this can be used for implementing that. + fn version(&self) -> ProtocolVersion; + + /// Write a single u64 to the protocol. + fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send; + + /// Write a slice of bytes to the protocol. + fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send; + + /// Write a value that implements `std::fmt::Display` to the protocol. + /// The protocol uses many small string formats and instead of allocating + /// a `String` each time we want to write one an implementation of `NixWrite` + /// can instead use `Display` to dump these formats to a reusable buffer. + fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send + where + D: fmt::Display + Send, + Self: Sized, + { + async move { + let s = msg.to_string(); + self.write_slice(s.as_bytes()).await + } + } + + /// Write a value to the protocol. + /// Uses `NixSerialize::serialize` to write the value. + fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send + where + V: NixSerialize + Send + ?Sized, + Self: Sized, + { + value.serialize(self) + } +} + +impl<T: NixWrite> NixWrite for &mut T { + type Error = T::Error; + + fn version(&self) -> ProtocolVersion { + (**self).version() + } + + fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send { + (**self).write_number(value) + } + + fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send { + (**self).write_slice(buf) + } + + fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send + where + D: fmt::Display + Send, + Self: Sized, + { + (**self).write_display(msg) + } + + fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send + where + V: NixSerialize + Send + ?Sized, + Self: Sized, + { + (**self).write_value(value) + } +} + +pub trait NixSerialize { + /// Write a value to the writer. + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite; +} + +// Noop +impl NixSerialize for () { + async fn serialize<W>(&self, _writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + Ok(()) + } +} diff --git a/tvix/nix-compat/src/wire/ser/writer.rs b/tvix/nix-compat/src/wire/ser/writer.rs new file mode 100644 index 000000000000..da1c2b18c5e2 --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/writer.rs @@ -0,0 +1,306 @@ +use std::fmt::{self, Write as _}; +use std::future::poll_fn; +use std::io::{self, Cursor}; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use bytes::{Buf, BufMut, BytesMut}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use crate::wire::{padding_len, ProtocolVersion, EMPTY_BYTES}; + +use super::{Error, NixWrite}; + +pub struct NixWriterBuilder { + buf: Option<BytesMut>, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, +} + +impl Default for NixWriterBuilder { + fn default() -> Self { + Self { + buf: Default::default(), + reserved_buf_size: 8192, + max_buf_size: 8192, + version: Default::default(), + } + } +} + +impl NixWriterBuilder { + pub fn set_buffer(mut self, buf: BytesMut) -> Self { + self.buf = Some(buf); + self + } + + pub fn set_reserved_buf_size(mut self, size: usize) -> Self { + self.reserved_buf_size = size; + self + } + + pub fn set_max_buf_size(mut self, size: usize) -> Self { + self.max_buf_size = size; + self + } + + pub fn set_version(mut self, version: ProtocolVersion) -> Self { + self.version = version; + self + } + + pub fn build<W>(self, writer: W) -> NixWriter<W> { + let buf = self + .buf + .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size)); + NixWriter { + buf, + inner: writer, + reserved_buf_size: self.reserved_buf_size, + max_buf_size: self.max_buf_size, + version: self.version, + } + } +} + +pin_project! { + pub struct NixWriter<W> { + #[pin] + inner: W, + buf: BytesMut, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, + } +} + +impl NixWriter<Cursor<Vec<u8>>> { + pub fn builder() -> NixWriterBuilder { + NixWriterBuilder::default() + } +} + +impl<W> NixWriter<W> +where + W: AsyncWriteExt, +{ + pub fn new(writer: W) -> NixWriter<W> { + NixWriter::builder().build(writer) + } + + pub fn buffer(&self) -> &[u8] { + &self.buf[..] + } + + pub fn set_version(&mut self, version: ProtocolVersion) { + self.version = version; + } + + /// Remaining capacity in internal buffer + pub fn remaining_mut(&self) -> usize { + self.buf.capacity() - self.buf.len() + } + + fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + let mut this = self.project(); + while !this.buf.is_empty() { + let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?; + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write the buffer", + ))); + } + this.buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl<W> NixWriter<W> +where + W: AsyncWriteExt + Unpin, +{ + async fn flush_buf(&mut self) -> Result<(), io::Error> { + let mut s = Pin::new(self); + poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await + } +} + +impl<W> AsyncWrite for NixWriter<W> +where + W: AsyncWrite, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + // Flush + if self.remaining_mut() < buf.len() { + ready!(self.as_mut().poll_flush_buf(cx))?; + } + let this = self.project(); + if buf.len() > this.buf.capacity() { + this.inner.poll_write(cx, buf) + } else { + this.buf.put_slice(buf); + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + ready!(self.as_mut().poll_flush_buf(cx))?; + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + ready!(self.as_mut().poll_flush_buf(cx))?; + self.project().inner.poll_shutdown(cx) + } +} + +impl<W> NixWrite for NixWriter<W> +where + W: AsyncWrite + Send + Unpin, +{ + type Error = io::Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + let mut buf = [0u8; 8]; + BufMut::put_u64_le(&mut &mut buf[..], value); + self.write_all(&buf).await + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + let padding = padding_len(buf.len() as u64) as usize; + self.write_value(&buf.len()).await?; + self.write_all(buf).await?; + if padding > 0 { + self.write_all(&EMPTY_BYTES[..padding]).await + } else { + Ok(()) + } + } + + async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + // Ensure that buffer has space for at least reserved_buf_size bytes + if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() { + self.flush_buf().await?; + } + let offset = self.buf.len(); + self.buf.put_u64_le(0); + if let Err(err) = write!(self.buf, "{}", msg) { + self.buf.truncate(offset); + return Err(Self::Error::unsupported_data(err)); + } + let len = self.buf.len() - offset - 8; + BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64); + let padding = padding_len(len as u64) as usize; + self.write_all(&EMPTY_BYTES[..padding]).await + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::NixWrite; + + use super::NixWriter; + + #[rstest] + #[case(1, &hex!("0100 0000 0000 0000"))] + #[case::evil(666, &hex!("9A02 0000 0000 0000"))] + #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) { + let mock = Builder::new().write(buf).build(); + let mut writer = NixWriter::new(mock); + + writer.write_number(number).await.unwrap(); + assert_eq!(writer.buffer(), buf); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } + + #[rstest] + #[case::empty(b"", &hex!("0000 0000 0000 0000"))] + #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_write_slice( + #[case] value: &[u8], + #[case] buf: &[u8], + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize, + ) { + let mut builder = Builder::new(); + for chunk in buf.chunks(chunks_size) { + builder.write(chunk); + builder.wait(Duration::ZERO); + } + let mock = builder.build(); + let mut writer = NixWriter::builder().set_max_buf_size(buf_size).build(mock); + + writer.write_slice(value).await.unwrap(); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_write_display( + #[case] value: &str, + #[case] buf: &[u8], + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, + ) { + let mut builder = Builder::new(); + for chunk in buf.chunks(chunks_size) { + builder.write(chunk); + builder.wait(Duration::ZERO); + } + let mock = builder.build(); + let mut writer = NixWriter::builder().build(mock); + + writer.write_display(value).await.unwrap(); + assert_eq!(writer.buffer(), buf); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } +} |