Diffstat (limited to 'tvix/nix-compat/src')
56 files changed, 5867 insertions, 717 deletions
diff --git a/tvix/nix-compat/src/aterm/mod.rs b/tvix/nix-compat/src/aterm/mod.rs index 8806b6caf2e5..bb3b77bc7399 100644 --- a/tvix/nix-compat/src/aterm/mod.rs +++ b/tvix/nix-compat/src/aterm/mod.rs @@ -2,6 +2,6 @@ mod escape; mod parser; pub(crate) use escape::escape_bytes; -pub(crate) use parser::parse_bstr_field; -pub(crate) use parser::parse_str_list; +pub(crate) use parser::parse_bytes_field; pub(crate) use parser::parse_string_field; +pub(crate) use parser::parse_string_list; diff --git a/tvix/nix-compat/src/aterm/parser.rs b/tvix/nix-compat/src/aterm/parser.rs index a30cb40ab08d..a570573a8700 100644 --- a/tvix/nix-compat/src/aterm/parser.rs +++ b/tvix/nix-compat/src/aterm/parser.rs @@ -11,8 +11,10 @@ use nom::multi::separated_list0; use nom::sequence::delimited; use nom::IResult; -/// Parse a bstr and undo any escaping. -fn parse_escaped_bstr(i: &[u8]) -> IResult<&[u8], BString> { +/// Parse a bstr and undo any escaping (which is why this needs to allocate). +// FUTUREWORK: have a version for fields that are known to not need escaping +// (like store paths), and use &str. +fn parse_escaped_bytes(i: &[u8]) -> IResult<&[u8], BString> { escaped_transform( is_not("\"\\"), '\\', @@ -29,14 +31,14 @@ fn parse_escaped_bstr(i: &[u8]) -> IResult<&[u8], BString> { /// Parse a field in double quotes, undo any escaping, and return the unquoted /// and decoded `Vec<u8>`. -pub(crate) fn parse_bstr_field(i: &[u8]) -> IResult<&[u8], BString> { +pub(crate) fn parse_bytes_field(i: &[u8]) -> IResult<&[u8], BString> { // inside double quotes… delimited( nomchar('\"'), // There is alt(( // …either is a bstr after unescaping - parse_escaped_bstr, + parse_escaped_bytes, // …or an empty string. map(tag(b""), |_| BString::default()), )), @@ -45,8 +47,8 @@ pub(crate) fn parse_bstr_field(i: &[u8]) -> IResult<&[u8], BString> { } /// Parse a field in double quotes, undo any escaping, and return the unquoted -/// and decoded string, if it's a valid string. Or fail parsing if the bytes are -/// no valid UTF-8. +/// and decoded [String], if it's valid UTF-8. +/// Or fail parsing if the bytes are no valid UTF-8. pub(crate) fn parse_string_field(i: &[u8]) -> IResult<&[u8], String> { // inside double quotes… delimited( @@ -54,18 +56,18 @@ pub(crate) fn parse_string_field(i: &[u8]) -> IResult<&[u8], String> { // There is alt(( // either is a String after unescaping - nom::combinator::map_opt(parse_escaped_bstr, |escaped_bstr| { - String::from_utf8(escaped_bstr.into()).ok() + nom::combinator::map_opt(parse_escaped_bytes, |escaped_bytes| { + String::from_utf8(escaped_bytes.into()).ok() }), // or an empty string. 
- map(tag(b""), |_| String::new()), + map(tag(b""), |_| "".to_string()), )), nomchar('\"'), )(i) } -/// Parse a list of of string fields (enclosed in brackets) -pub(crate) fn parse_str_list(i: &[u8]) -> IResult<&[u8], Vec<String>> { +/// Parse a list of string fields (enclosed in brackets) +pub(crate) fn parse_string_list(i: &[u8]) -> IResult<&[u8], Vec<String>> { // inside brackets delimited( nomchar('['), @@ -89,7 +91,7 @@ mod tests { #[case] expected: &[u8], #[case] exp_rest: &[u8], ) { - let (rest, parsed) = super::parse_bstr_field(input).expect("must parse"); + let (rest, parsed) = super::parse_bytes_field(input).expect("must parse"); assert_eq!(exp_rest, rest, "expected remainder"); assert_eq!(expected, parsed); } @@ -118,7 +120,7 @@ mod tests { #[case::empty_list(b"[]", vec![], b"")] #[case::empty_list_with_rest(b"[]blub", vec![], b"blub")] fn parse_list(#[case] input: &[u8], #[case] expected: Vec<String>, #[case] exp_rest: &[u8]) { - let (rest, parsed) = super::parse_str_list(input).expect("must parse"); + let (rest, parsed) = super::parse_string_list(input).expect("must parse"); assert_eq!(exp_rest, rest, "expected remainder"); assert_eq!(expected, parsed); } diff --git a/tvix/nix-compat/src/bin/drvfmt.rs b/tvix/nix-compat/src/bin/drvfmt.rs index ddc1f0389f26..fca22c2cb2ae 100644 --- a/tvix/nix-compat/src/bin/drvfmt.rs +++ b/tvix/nix-compat/src/bin/drvfmt.rs @@ -3,6 +3,11 @@ use std::{collections::BTreeMap, io::Read}; use nix_compat::derivation::Derivation; use serde_json::json; +use mimalloc::MiMalloc; + +#[global_allocator] +static GLOBAL: MiMalloc = MiMalloc; + /// construct a serde_json::Value from a Derivation. /// Some environment values can be non-valid UTF-8 strings. /// `serde_json` prints them out really unreadable. diff --git a/tvix/nix-compat/src/derivation/mod.rs b/tvix/nix-compat/src/derivation/mod.rs index 6e12e3ea86e6..445e9cb43143 100644 --- a/tvix/nix-compat/src/derivation/mod.rs +++ b/tvix/nix-compat/src/derivation/mod.rs @@ -36,11 +36,11 @@ pub struct Derivation { /// Map from drv path to output names used from this derivation. #[serde(rename = "inputDrvs")] - pub input_derivations: BTreeMap<StorePath, BTreeSet<String>>, + pub input_derivations: BTreeMap<StorePath<String>, BTreeSet<String>>, /// Plain store paths of additional inputs. #[serde(rename = "inputSrcs")] - pub input_sources: BTreeSet<StorePath>, + pub input_sources: BTreeSet<StorePath<String>>, /// Maps output names to Output. pub outputs: BTreeMap<String, Output>, @@ -127,7 +127,10 @@ impl Derivation { /// the `name` with a `.drv` suffix as name, all [Derivation::input_sources] and /// keys of [Derivation::input_derivations] as references, and the ATerm string of /// the [Derivation] as content. 
- pub fn calculate_derivation_path(&self, name: &str) -> Result<StorePath, DerivationError> { + pub fn calculate_derivation_path( + &self, + name: &str, + ) -> Result<StorePath<String>, DerivationError> { // append .drv to the name let name = &format!("{}.drv", name); @@ -141,7 +144,6 @@ impl Derivation { .collect(); build_text_path(name, self.to_aterm_bytes(), references) - .map(|s| s.to_owned()) .map_err(|_e| DerivationError::InvalidOutputName(name.to_string())) } @@ -210,7 +212,7 @@ impl Derivation { self.input_derivations .iter() .map(|(drv_path, output_names)| { - let hash = fn_lookup_hash_derivation_modulo(&drv_path.into()); + let hash = fn_lookup_hash_derivation_modulo(&drv_path.as_ref()); (hash, output_names.to_owned()) }), @@ -255,8 +257,8 @@ impl Derivation { // For fixed output derivation we use [build_ca_path], otherwise we // use [build_output_path] with [hash_derivation_modulo]. - let abs_store_path = if let Some(ref hwm) = output.ca_hash { - build_ca_path(&path_name, hwm, Vec::<String>::new(), false).map_err(|e| { + let store_path = if let Some(ref hwm) = output.ca_hash { + build_ca_path(&path_name, hwm, Vec::<&str>::new(), false).map_err(|e| { DerivationError::InvalidOutputDerivationPath(output_name.to_string(), e) })? } else { @@ -268,11 +270,11 @@ impl Derivation { })? }; - output.path = Some(abs_store_path.to_owned()); self.environment.insert( output_name.to_string(), - abs_store_path.to_absolute_path().into(), + store_path.to_absolute_path().into(), ); + output.path = Some(store_path); } Ok(()) diff --git a/tvix/nix-compat/src/derivation/output.rs b/tvix/nix-compat/src/derivation/output.rs index 266617f587f8..0b81ef3c3155 100644 --- a/tvix/nix-compat/src/derivation/output.rs +++ b/tvix/nix-compat/src/derivation/output.rs @@ -1,5 +1,4 @@ use crate::nixhash::CAHash; -use crate::store_path::StorePathRef; use crate::{derivation::OutputError, store_path::StorePath}; use serde::de::Unexpected; use serde::{Deserialize, Serialize}; @@ -10,7 +9,7 @@ use std::borrow::Cow; #[derive(Clone, Debug, Default, Eq, PartialEq, Serialize)] pub struct Output { /// Store path of build result. - pub path: Option<StorePath>, + pub path: Option<StorePath<String>>, #[serde(flatten)] pub ca_hash: Option<CAHash>, // we can only represent a subset here. @@ -33,10 +32,10 @@ impl<'de> Deserialize<'de> for Output { &"a string", ))?; - let path = StorePathRef::from_absolute_path(path.as_bytes()) + let path = StorePath::from_absolute_path(path.as_bytes()) .map_err(|_| serde::de::Error::invalid_value(Unexpected::Str(path), &"StorePath"))?; Ok(Self { - path: Some(path.to_owned()), + path: Some(path), ca_hash: CAHash::from_map::<D>(&fields)?, }) } diff --git a/tvix/nix-compat/src/derivation/parse_error.rs b/tvix/nix-compat/src/derivation/parse_error.rs index fc97f1a9883b..f625d9aeb724 100644 --- a/tvix/nix-compat/src/derivation/parse_error.rs +++ b/tvix/nix-compat/src/derivation/parse_error.rs @@ -20,7 +20,7 @@ pub enum ErrorKind { DuplicateInputDerivationOutputName(String, String), #[error("duplicate input source: {0}")] - DuplicateInputSource(StorePath), + DuplicateInputSource(StorePath<String>), #[error("nix hash error: {0}")] NixHashError(nixhash::Error), diff --git a/tvix/nix-compat/src/derivation/parser.rs b/tvix/nix-compat/src/derivation/parser.rs index 2775294960fe..a94ed2281a86 100644 --- a/tvix/nix-compat/src/derivation/parser.rs +++ b/tvix/nix-compat/src/derivation/parser.rs @@ -3,7 +3,6 @@ //! //! 
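`StorePath` is now generic over its name's string type, so `StorePath<String>` owns its name while `StorePathRef` remains the borrowed form, and `from_absolute_path` is available on the generic type directly, as the `Output` deserializer above uses. A hedged usage sketch; the store path literal is the classic example digest from the Nix manual:

```rust
use nix_compat::store_path::StorePath;

fn main() {
    let abs = "/nix/store/00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432";

    // Parse the absolute path into an owned StorePath<String>…
    let path: StorePath<String> =
        StorePath::from_absolute_path(abs.as_bytes()).expect("must parse");

    // …and round-trip it, as the debug assertion in
    // string_to_store_path below also checks.
    assert_eq!(path.to_absolute_path(), abs);
}
```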
[ATerm]: http://program-transformation.org/Tools/ATermFormat.html -use bstr::BString; use nom::bytes::complete::tag; use nom::character::complete::char as nomchar; use nom::combinator::{all_consuming, map_res}; @@ -14,7 +13,7 @@ use thiserror; use crate::derivation::parse_error::{into_nomerror, ErrorKind, NomError, NomResult}; use crate::derivation::{write, CAHash, Derivation, Output}; -use crate::store_path::{self, StorePath, StorePathRef}; +use crate::store_path::{self, StorePath}; use crate::{aterm, nixhash}; #[derive(Debug, thiserror::Error)] @@ -73,7 +72,7 @@ fn parse_output(i: &[u8]) -> NomResult<&[u8], (String, Output)> { terminated(aterm::parse_string_field, nomchar(',')), terminated(aterm::parse_string_field, nomchar(',')), terminated(aterm::parse_string_field, nomchar(',')), - aterm::parse_bstr_field, + aterm::parse_bytes_field, ))(i) .map_err(into_nomerror) }, @@ -102,7 +101,7 @@ fn parse_output(i: &[u8]) -> NomResult<&[u8], (String, Output)> { path: if output_path.is_empty() { None } else { - Some(string_to_store_path(i, output_path)?) + Some(string_to_store_path(i, &output_path)?) }, ca_hash: hash_with_mode, }, @@ -132,12 +131,12 @@ fn parse_outputs(i: &[u8]) -> NomResult<&[u8], BTreeMap<String, Output>> { match res { Ok((rst, outputs_lst)) => { - let mut outputs: BTreeMap<String, Output> = BTreeMap::default(); + let mut outputs = BTreeMap::default(); for (output_name, output) in outputs_lst.into_iter() { if outputs.contains_key(&output_name) { return Err(nom::Err::Failure(NomError { input: i, - code: ErrorKind::DuplicateMapKey(output_name), + code: ErrorKind::DuplicateMapKey(output_name.to_string()), })); } outputs.insert(output_name, output); @@ -149,11 +148,13 @@ fn parse_outputs(i: &[u8]) -> NomResult<&[u8], BTreeMap<String, Output>> { } } -fn parse_input_derivations(i: &[u8]) -> NomResult<&[u8], BTreeMap<StorePath, BTreeSet<String>>> { - let (i, input_derivations_list) = parse_kv::<Vec<String>, _>(aterm::parse_str_list)(i)?; +fn parse_input_derivations( + i: &[u8], +) -> NomResult<&[u8], BTreeMap<StorePath<String>, BTreeSet<String>>> { + let (i, input_derivations_list) = parse_kv(aterm::parse_string_list)(i)?; // This is a HashMap of drv paths to a list of output names. 
- let mut input_derivations: BTreeMap<StorePath, BTreeSet<String>> = BTreeMap::new(); + let mut input_derivations: BTreeMap<StorePath<String>, BTreeSet<_>> = BTreeMap::new(); for (input_derivation, output_names) in input_derivations_list { let mut new_output_names = BTreeSet::new(); @@ -170,7 +171,7 @@ fn parse_input_derivations(i: &[u8]) -> NomResult<&[u8], BTreeMap<StorePath, BTr new_output_names.insert(output_name); } - let input_derivation: StorePath = string_to_store_path(i, input_derivation)?; + let input_derivation = string_to_store_path(i, input_derivation.as_str())?; input_derivations.insert(input_derivation, new_output_names); } @@ -178,16 +179,16 @@ fn parse_input_derivations(i: &[u8]) -> NomResult<&[u8], BTreeMap<StorePath, BTr Ok((i, input_derivations)) } -fn parse_input_sources(i: &[u8]) -> NomResult<&[u8], BTreeSet<StorePath>> { - let (i, input_sources_lst) = aterm::parse_str_list(i).map_err(into_nomerror)?; +fn parse_input_sources(i: &[u8]) -> NomResult<&[u8], BTreeSet<StorePath<String>>> { + let (i, input_sources_lst) = aterm::parse_string_list(i).map_err(into_nomerror)?; let mut input_sources: BTreeSet<_> = BTreeSet::new(); for input_source in input_sources_lst.into_iter() { - let input_source: StorePath = string_to_store_path(i, input_source)?; + let input_source = string_to_store_path(i, input_source.as_str())?; if input_sources.contains(&input_source) { return Err(nom::Err::Failure(NomError { input: i, - code: ErrorKind::DuplicateInputSource(input_source), + code: ErrorKind::DuplicateInputSource(input_source.to_owned()), })); } else { input_sources.insert(input_source); @@ -197,24 +198,23 @@ fn parse_input_sources(i: &[u8]) -> NomResult<&[u8], BTreeSet<StorePath>> { Ok((i, input_sources)) } -fn string_to_store_path( - i: &[u8], - path_str: String, -) -> Result<StorePath, nom::Err<NomError<&[u8]>>> { - #[cfg(debug_assertions)] - let path_str2 = path_str.clone(); - - let path: StorePath = StorePathRef::from_absolute_path(path_str.as_bytes()) - .map_err(|e: store_path::Error| { +fn string_to_store_path<'a, 'i, S>( + i: &'i [u8], + path_str: &'a str, +) -> Result<StorePath<S>, nom::Err<NomError<&'i [u8]>>> +where + S: std::clone::Clone + AsRef<str> + std::convert::From<&'a str>, +{ + let path = + StorePath::from_absolute_path(path_str.as_bytes()).map_err(|e: store_path::Error| { nom::Err::Failure(NomError { input: i, code: e.into(), }) - })? - .to_owned(); + })?; #[cfg(debug_assertions)] - assert_eq!(path_str2, path.to_absolute_path()); + assert_eq!(path_str, path.to_absolute_path()); Ok(path) } @@ -240,9 +240,9 @@ pub fn parse_derivation(i: &[u8]) -> NomResult<&[u8], Derivation> { // // parse builder |i| terminated(aterm::parse_string_field, nomchar(','))(i).map_err(into_nomerror), // // parse arguments - |i| terminated(aterm::parse_str_list, nomchar(','))(i).map_err(into_nomerror), + |i| terminated(aterm::parse_string_list, nomchar(','))(i).map_err(into_nomerror), // parse environment - parse_kv::<BString, _>(aterm::parse_bstr_field), + parse_kv(aterm::parse_bytes_field), )), nomchar(')'), ) @@ -329,6 +329,7 @@ where mod tests { use crate::store_path::StorePathRef; use std::collections::{BTreeMap, BTreeSet}; + use std::sync::LazyLock; use crate::{ derivation::{ @@ -338,49 +339,48 @@ mod tests { }; use bstr::{BString, ByteSlice}; use hex_literal::hex; - use lazy_static::lazy_static; use rstest::rstest; const DIGEST_SHA256: [u8; 32] = hex!("a5ce9c155ed09397614646c9717fc7cd94b1023d7b76b618d409e4fefd6e9d39"); - lazy_static! 
{ - pub static ref NIXHASH_SHA256: NixHash = NixHash::Sha256(DIGEST_SHA256); - static ref EXP_MULTI_OUTPUTS: BTreeMap<String, Output> = { - let mut b = BTreeMap::new(); - b.insert( - "lib".to_string(), - Output { - path: Some( - StorePath::from_bytes( - b"2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib", - ) + static NIXHASH_SHA256: NixHash = NixHash::Sha256(DIGEST_SHA256); + static EXP_MULTI_OUTPUTS: LazyLock<BTreeMap<String, Output>> = LazyLock::new(|| { + let mut b = BTreeMap::new(); + b.insert( + "lib".to_string(), + Output { + path: Some( + StorePath::from_bytes(b"2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib") .unwrap(), - ), - ca_hash: None, - }, - ); - b.insert( - "out".to_string(), - Output { - path: Some( - StorePath::from_bytes( - b"55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".as_bytes(), - ) - .unwrap(), - ), - ca_hash: None, - }, - ); - b - }; - static ref EXP_AB_MAP: BTreeMap<String, BString> = { - let mut b = BTreeMap::new(); - b.insert("a".to_string(), b"1".as_bstr().to_owned()); - b.insert("b".to_string(), b"2".as_bstr().to_owned()); - b - }; - static ref EXP_INPUT_DERIVATIONS_SIMPLE: BTreeMap<StorePath, BTreeSet<String>> = { + ), + ca_hash: None, + }, + ); + b.insert( + "out".to_string(), + Output { + path: Some( + StorePath::from_bytes( + b"55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".as_bytes(), + ) + .unwrap(), + ), + ca_hash: None, + }, + ); + b + }); + + static EXP_AB_MAP: LazyLock<BTreeMap<String, BString>> = LazyLock::new(|| { + let mut b = BTreeMap::new(); + b.insert("a".to_string(), b"1".into()); + b.insert("b".to_string(), b"2".into()); + b + }); + + static EXP_INPUT_DERIVATIONS_SIMPLE: LazyLock<BTreeMap<StorePath<String>, BTreeSet<String>>> = + LazyLock::new(|| { let mut b = BTreeMap::new(); b.insert( StorePath::from_bytes(b"8bjm87p310sb7r2r0sg4xrynlvg86j8k-hello-2.12.1.tar.gz.drv") @@ -402,21 +402,22 @@ mod tests { }, ); b - }; - static ref EXP_INPUT_DERIVATIONS_SIMPLE_ATERM: String = { - format!( - "[(\"{0}\",[\"out\"]),(\"{1}\",[\"out\",\"lib\"])]", - "/nix/store/8bjm87p310sb7r2r0sg4xrynlvg86j8k-hello-2.12.1.tar.gz.drv", - "/nix/store/p3jc8aw45dza6h52v81j7lk69khckmcj-bash-5.2-p15.drv" - ) - }; - static ref EXP_INPUT_SOURCES_SIMPLE: BTreeSet<String> = { - let mut b = BTreeSet::new(); - b.insert("/nix/store/55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".to_string()); - b.insert("/nix/store/2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib".to_string()); - b - }; - } + }); + + static EXP_INPUT_DERIVATIONS_SIMPLE_ATERM: LazyLock<String> = LazyLock::new(|| { + format!( + "[(\"{0}\",[\"out\"]),(\"{1}\",[\"out\",\"lib\"])]", + "/nix/store/8bjm87p310sb7r2r0sg4xrynlvg86j8k-hello-2.12.1.tar.gz.drv", + "/nix/store/p3jc8aw45dza6h52v81j7lk69khckmcj-bash-5.2-p15.drv" + ) + }); + + static EXP_INPUT_SOURCES_SIMPLE: LazyLock<BTreeSet<String>> = LazyLock::new(|| { + let mut b = BTreeSet::new(); + b.insert("/nix/store/55lwldka5nyxa08wnvlizyqw02ihy8ic-has-multi-out".to_string()); + b.insert("/nix/store/2vixb94v0hy2xc6p7mbnxxcyc095yyia-has-multi-out-lib".to_string()); + b + }); /// Ensure parsing KVs works #[rstest] @@ -427,8 +428,8 @@ mod tests { #[case] expected: &BTreeMap<String, BString>, #[case] exp_rest: &[u8], ) { - let (rest, parsed) = super::parse_kv::<BString, _>(crate::aterm::parse_bstr_field)(input) - .expect("must parse"); + let (rest, parsed) = + super::parse_kv(crate::aterm::parse_bytes_field)(input).expect("must parse"); assert_eq!(exp_rest, rest, "expected remainder"); assert_eq!(*expected, parsed); } @@ -437,8 +438,7 @@ mod tests { #[test] fn 
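These fixtures migrate from the `lazy_static!` macro to `std::sync::LazyLock`, which has been stable since Rust 1.80 and needs no external crate. The pattern in isolation (hypothetical `EXAMPLE` static; the initializer closure runs once, on first dereference):

```rust
use std::collections::BTreeMap;
use std::sync::LazyLock;

static EXAMPLE: LazyLock<BTreeMap<String, u32>> = LazyLock::new(|| {
    let mut b = BTreeMap::new();
    b.insert("a".to_string(), 1);
    b
});

fn main() {
    // First access runs the closure; later accesses reuse the value.
    assert_eq!(EXAMPLE.get("a"), Some(&1));
}
```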
parse_kv_fail_dup_keys() { let input: &'static [u8] = b"[(\"a\",\"1\"),(\"a\",\"2\")]"; - let e = super::parse_kv::<BString, _>(crate::aterm::parse_bstr_field)(input) - .expect_err("must fail"); + let e = super::parse_kv(crate::aterm::parse_bytes_field)(input).expect_err("must fail"); match e { nom::Err::Failure(e) => { @@ -454,7 +454,7 @@ mod tests { #[case::simple(EXP_INPUT_DERIVATIONS_SIMPLE_ATERM.as_bytes(), &EXP_INPUT_DERIVATIONS_SIMPLE)] fn parse_input_derivations( #[case] input: &'static [u8], - #[case] expected: &BTreeMap<StorePath, BTreeSet<String>>, + #[case] expected: &BTreeMap<StorePath<String>, BTreeSet<String>>, ) { let (rest, parsed) = super::parse_input_derivations(input).expect("must parse"); diff --git a/tvix/nix-compat/src/derivation/write.rs b/tvix/nix-compat/src/derivation/write.rs index 735b781574e1..a8b43fad4cc6 100644 --- a/tvix/nix-compat/src/derivation/write.rs +++ b/tvix/nix-compat/src/derivation/write.rs @@ -6,7 +6,7 @@ use crate::aterm::escape_bytes; use crate::derivation::{ca_kind_prefix, output::Output}; use crate::nixbase32; -use crate::store_path::{StorePath, StorePathRef, STORE_DIR_WITH_SLASH}; +use crate::store_path::{StorePath, STORE_DIR_WITH_SLASH}; use bstr::BString; use data_encoding::HEXLOWER; @@ -32,34 +32,23 @@ pub const QUOTE: char = '"'; /// the context a lot. pub(crate) trait AtermWriteable { fn aterm_write(&self, writer: &mut impl Write) -> std::io::Result<()>; - - fn aterm_bytes(&self) -> Vec<u8> { - let mut bytes = Vec::new(); - self.aterm_write(&mut bytes) - .expect("unexpected write errors to Vec"); - bytes - } } -impl AtermWriteable for StorePathRef<'_> { +impl<S> AtermWriteable for StorePath<S> +where + S: AsRef<str>, +{ fn aterm_write(&self, writer: &mut impl Write) -> std::io::Result<()> { write_char(writer, QUOTE)?; writer.write_all(STORE_DIR_WITH_SLASH.as_bytes())?; writer.write_all(nixbase32::encode(self.digest()).as_bytes())?; write_char(writer, '-')?; - writer.write_all(self.name().as_bytes())?; + writer.write_all(self.name().as_ref().as_bytes())?; write_char(writer, QUOTE)?; Ok(()) } } -impl AtermWriteable for StorePath { - fn aterm_write(&self, writer: &mut impl Write) -> std::io::Result<()> { - let r: StorePathRef = self.into(); - r.aterm_write(writer) - } -} - impl AtermWriteable for String { fn aterm_write(&self, writer: &mut impl Write) -> std::io::Result<()> { write_field(writer, self, true) @@ -186,7 +175,7 @@ pub(crate) fn write_input_derivations( pub(crate) fn write_input_sources( writer: &mut impl Write, - input_sources: &BTreeSet<StorePath>, + input_sources: &BTreeSet<StorePath<String>>, ) -> Result<(), io::Error> { write_char(writer, BRACKET_OPEN)?; write_array_elements( diff --git a/tvix/nix-compat/src/lib.rs b/tvix/nix-compat/src/lib.rs index a71ede3eecf0..4c327fa4569b 100644 --- a/tvix/nix-compat/src/lib.rs +++ b/tvix/nix-compat/src/lib.rs @@ -1,8 +1,12 @@ +extern crate self as nix_compat; + pub(crate) mod aterm; pub mod derivation; pub mod nar; pub mod narinfo; +pub mod nix_http; pub mod nixbase32; +pub mod nixcpp; pub mod nixhash; pub mod path_info; pub mod store_path; @@ -10,9 +14,7 @@ pub mod store_path; #[cfg(feature = "wire")] pub mod wire; -#[cfg(feature = "wire")] -mod nix_daemon; -#[cfg(feature = "wire")] +#[cfg(feature = "daemon")] +pub mod nix_daemon; +#[cfg(feature = "daemon")] pub use nix_daemon::worker_protocol; -#[cfg(feature = "wire")] -pub use nix_daemon::ProtocolVersion; diff --git a/tvix/nix-compat/src/nar/listing/mod.rs b/tvix/nix-compat/src/nar/listing/mod.rs new file mode 100644 index 
000000000000..5a9a3b4d3613 --- /dev/null +++ b/tvix/nix-compat/src/nar/listing/mod.rs @@ -0,0 +1,128 @@ +//! Parser for the Nix archive listing format, aka .ls. +//! +//! LS files are produced by the C++ Nix implementation via `write-nar-listing=1` query parameter +//! passed to a store implementation when transferring store paths. +//! +//! Listing files contains metadata about a file and its offset in the corresponding NAR. +//! +//! NOTE: LS entries does not offer any integrity field to validate the retrieved file at the provided +//! offset. Validating the contents is the caller's responsibility. + +use std::{ + collections::HashMap, + path::{Component, Path}, +}; + +use serde::Deserialize; + +#[cfg(test)] +mod test; + +#[derive(Debug, thiserror::Error)] +pub enum ListingError { + // TODO: add an enum of what component was problematic + // reusing `std::path::Component` is not possible as it contains a lifetime. + /// An unsupported path component can be: + /// - either a Windows prefix (`C:\\`, `\\share\\`) + /// - either a parent directory (`..`) + /// - either a root directory (`/`) + #[error("unsupported path component")] + UnsupportedPathComponent, + #[error("invalid encoding for entry component")] + InvalidEncoding, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ListingEntry { + Regular { + size: u64, + #[serde(default)] + executable: bool, + #[serde(rename = "narOffset")] + nar_offset: u64, + }, + Directory { + // It's tempting to think that the key should be a `Vec<u8>` + // but Nix does not support that and will fail to emit a listing version 1 for any non-UTF8 + // encodeable string. + entries: HashMap<String, ListingEntry>, + }, + Symlink { + target: String, + }, +} + +impl ListingEntry { + /// Given a relative path without `..` component, this will locate, relative to this entry, a + /// deeper entry. + /// + /// If the path is invalid, a listing error [`ListingError`] will be returned. + /// If the entry cannot be found, `None` will be returned. + pub fn locate<P: AsRef<Path>>(&self, path: P) -> Result<Option<&ListingEntry>, ListingError> { + // We perform a simple DFS on the components of the path + // while rejecting dangerous components, e.g. `..` or `/` + // Files and symlinks are *leaves*, i.e. we return them + let mut cur = self; + for component in path.as_ref().components() { + match component { + Component::CurDir => continue, + Component::RootDir | Component::Prefix(_) | Component::ParentDir => { + return Err(ListingError::UnsupportedPathComponent) + } + Component::Normal(file_or_dir_name) => { + if let Self::Directory { entries } = cur { + // As Nix cannot encode non-UTF8 components in the listing (see comment on + // the `Directory` enum variant), invalid encodings path components are + // errors. + let entry_name = file_or_dir_name + .to_str() + .ok_or(ListingError::InvalidEncoding)?; + + if let Some(new_entry) = entries.get(entry_name) { + cur = new_entry; + } else { + return Ok(None); + } + } else { + return Ok(None); + } + } + } + } + + // By construction, we found the node that corresponds to the path traversal. 
+ Ok(Some(cur)) + } +} + +#[derive(Debug)] +pub struct ListingVersion<const V: u8>; + +#[derive(Debug, thiserror::Error)] +#[error("Invalid version: {0}")] +struct ListingVersionError(u8); + +impl<'de, const V: u8> Deserialize<'de> for ListingVersion<V> { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let value = u8::deserialize(deserializer)?; + if value == V { + Ok(ListingVersion::<V>) + } else { + Err(serde::de::Error::custom(ListingVersionError(value))) + } + } +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +#[non_exhaustive] +pub enum Listing { + V1 { + root: ListingEntry, + version: ListingVersion<1>, + }, +} diff --git a/tvix/nix-compat/src/nar/listing/test.rs b/tvix/nix-compat/src/nar/listing/test.rs new file mode 100644 index 000000000000..5b2ac3f166fe --- /dev/null +++ b/tvix/nix-compat/src/nar/listing/test.rs @@ -0,0 +1,59 @@ +use std::{collections::HashMap, path::PathBuf, str::FromStr}; + +use crate::nar; + +#[test] +fn weird_paths() { + let root = nar::listing::ListingEntry::Directory { + entries: HashMap::new(), + }; + + root.locate("../../../../etc/passwd") + .expect_err("Failed to reject `../` fragment in a path during traversal"); + + // Gated on Windows as C:\\ is parsed as `Component::Normal(_)` on Linux. + #[cfg(target_os = "windows")] + root.locate("C:\\\\Windows\\System32") + .expect_err("Failed to reject Windows-style prefixes"); + + root.locate("/etc/passwd") + .expect_err("Failed to reject absolute UNIX paths"); +} + +#[test] +fn nixos_release() { + let listing_bytes = include_bytes!("../tests/nixos-release.ls"); + let listing: nar::listing::Listing = serde_json::from_slice(listing_bytes).unwrap(); + + let nar::listing::Listing::V1 { root, .. } = listing; + assert!(matches!(root, nar::listing::ListingEntry::Directory { .. })); + + let build_products = root + .locate(PathBuf::from_str("nix-support/hydra-build-products").unwrap()) + .expect("Failed to locate a known file in a directory") + .expect("File was unexpectedly not found in the listing"); + + assert!(matches!( + build_products, + nar::listing::ListingEntry::Regular { .. } + )); + + let nonexisting_file = root + .locate(PathBuf::from_str("nix-support/does-not-exist").unwrap()) + .expect("Failed to locate an unknown file in a directory"); + + assert!( + nonexisting_file.is_none(), + "Non-existing file was unexpectedly found in the listing" + ); + + let existing_dir = root + .locate(PathBuf::from_str("nix-support").unwrap()) + .expect("Failed to locate a known directory in a directory") + .expect("Directory was expectedly found in the listing"); + + assert!(matches!( + existing_dir, + nar::listing::ListingEntry::Directory { .. } + )); +} diff --git a/tvix/nix-compat/src/nar/mod.rs b/tvix/nix-compat/src/nar/mod.rs index 058977f4fcd1..d0e8ee8a412f 100644 --- a/tvix/nix-compat/src/nar/mod.rs +++ b/tvix/nix-compat/src/nar/mod.rs @@ -1,4 +1,5 @@ -mod wire; +pub(crate) mod wire; +pub mod listing; pub mod reader; pub mod writer; diff --git a/tvix/nix-compat/src/nar/reader/async/mod.rs b/tvix/nix-compat/src/nar/reader/async/mod.rs new file mode 100644 index 000000000000..0808fba38c47 --- /dev/null +++ b/tvix/nix-compat/src/nar/reader/async/mod.rs @@ -0,0 +1,173 @@ +use std::{ + mem::MaybeUninit, + pin::Pin, + task::{self, Poll}, +}; + +use tokio::io::{self, AsyncBufRead, AsyncRead, ErrorKind::InvalidData}; + +// Required reading for understanding this module. 
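Putting the new `nar::listing` module together: a hedged end-to-end sketch deserializing a listing and locating an entry. The inline JSON is a made-up single-file listing, not one of the crate's fixtures; note that outside the crate the `#[non_exhaustive]` `Listing` enum requires a wildcard match arm:

```rust
use nix_compat::nar::listing::{Listing, ListingEntry};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let listing: Listing = serde_json::from_str(
        r#"{"version":1,"root":{"type":"directory","entries":{"hello":{"type":"regular","size":5,"narOffset":120}}}}"#,
    )?;

    let root = match listing {
        Listing::V1 { root, .. } => root,
        _ => unreachable!("only version 1 exists today"),
    };

    // Present entries come back as Ok(Some(_)), absent ones as Ok(None);
    // `..`, root, and prefix components are rejected with a ListingError.
    assert!(matches!(
        root.locate("hello")?,
        Some(ListingEntry::Regular { size: 5, .. })
    ));
    assert!(root.locate("missing")?.is_none());
    root.locate("../escape").expect_err("rejected");

    Ok(())
}
```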
+use crate::{ + nar::{self, wire::PadPar}, + wire::{self, BytesReader}, +}; + +mod read; +#[cfg(test)] +mod test; + +pub type Reader<'a> = dyn AsyncBufRead + Unpin + Send + 'a; + +/// Start reading a NAR file from `reader`. +pub async fn open<'a, 'r>(reader: &'a mut Reader<'r>) -> io::Result<Node<'a, 'r>> { + read::token(reader, &nar::wire::TOK_NAR).await?; + Node::new(reader).await +} + +pub enum Node<'a, 'r: 'a> { + Symlink { + target: Vec<u8>, + }, + File { + executable: bool, + reader: FileReader<'a, 'r>, + }, + Directory(DirReader<'a, 'r>), +} + +impl<'a, 'r: 'a> Node<'a, 'r> { + /// Start reading a [Node], matching the next [wire::Node]. + /// + /// Reading the terminating [wire::TOK_PAR] is done immediately for [Node::Symlink], + /// but is otherwise left to [DirReader] or [BytesReader]. + async fn new(reader: &'a mut Reader<'r>) -> io::Result<Self> { + Ok(match read::tag(reader).await? { + nar::wire::Node::Sym => { + let target = wire::read_bytes(reader, 1..=nar::wire::MAX_TARGET_LEN).await?; + + if target.contains(&0) { + return Err(InvalidData.into()); + } + + read::token(reader, &nar::wire::TOK_PAR).await?; + + Node::Symlink { target } + } + tag @ (nar::wire::Node::Reg | nar::wire::Node::Exe) => Node::File { + executable: tag == nar::wire::Node::Exe, + reader: FileReader { + inner: BytesReader::new_internal(reader, ..).await?, + }, + }, + nar::wire::Node::Dir => Node::Directory(DirReader::new(reader)), + }) + } +} + +/// File contents, readable through the [AsyncRead] trait. +/// +/// It comes with some caveats: +/// * You must always read the entire file, unless you intend to abandon the entire archive reader. +/// * You must abandon the entire archive reader upon the first error. +/// +/// It's fine to read exactly `reader.len()` bytes without ever seeing an explicit EOF. +pub struct FileReader<'a, 'r> { + inner: BytesReader<&'a mut Reader<'r>, PadPar>, +} + +impl<'a, 'r> FileReader<'a, 'r> { + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn len(&self) -> u64 { + self.inner.len() + } +} + +impl<'a, 'r> AsyncRead for FileReader<'a, 'r> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut io::ReadBuf, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.get_mut().inner).poll_read(cx, buf) + } +} + +impl<'a, 'r> AsyncBufRead for FileReader<'a, 'r> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<&[u8]>> { + Pin::new(&mut self.get_mut().inner).poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + Pin::new(&mut self.get_mut().inner).consume(amt) + } +} + +/// A directory iterator, yielding a sequence of [Node]s. +/// It must be fully consumed before reading further from the [DirReader] that produced it, if any. +pub struct DirReader<'a, 'r> { + reader: &'a mut Reader<'r>, + /// Previous directory entry name. + /// We have to hang onto this to enforce name monotonicity. + prev_name: Vec<u8>, +} + +pub struct Entry<'a, 'r> { + pub name: &'a [u8], + pub node: Node<'a, 'r>, +} + +impl<'a, 'r> DirReader<'a, 'r> { + fn new(reader: &'a mut Reader<'r>) -> Self { + Self { + reader, + prev_name: vec![], + } + } + + /// Read the next [Entry] from the directory. + /// + /// We explicitly don't implement [Iterator], since treating this as + /// a regular Rust iterator will surely lead you astray. + /// + /// * You must always consume the entire iterator, unless you abandon the entire archive reader. + /// * You must abandon the entire archive reader on the first error. 
+ /// * You must abandon the directory reader upon the first [None]. + /// * Even if you know the amount of elements up front, you must keep reading until you encounter [None]. + pub async fn next(&mut self) -> io::Result<Option<Entry<'_, 'r>>> { + // COME FROM the previous iteration: if we've already read an entry, + // read its terminating TOK_PAR here. + if !self.prev_name.is_empty() { + read::token(self.reader, &nar::wire::TOK_PAR).await?; + } + + if let nar::wire::Entry::None = read::tag(self.reader).await? { + return Ok(None); + } + + let mut name = [MaybeUninit::uninit(); nar::wire::MAX_NAME_LEN + 1]; + let name = + wire::read_bytes_buf(self.reader, &mut name, 1..=nar::wire::MAX_NAME_LEN).await?; + + if name.contains(&0) || name.contains(&b'/') || name == b"." || name == b".." { + return Err(InvalidData.into()); + } + + // Enforce strict monotonicity of directory entry names. + if &self.prev_name[..] >= name { + return Err(InvalidData.into()); + } + + self.prev_name.clear(); + self.prev_name.extend_from_slice(name); + + read::token(self.reader, &nar::wire::TOK_NOD).await?; + + Ok(Some(Entry { + name: &self.prev_name, + node: Node::new(self.reader).await?, + })) + } +} diff --git a/tvix/nix-compat/src/nar/reader/async/read.rs b/tvix/nix-compat/src/nar/reader/async/read.rs new file mode 100644 index 000000000000..2adf894922c5 --- /dev/null +++ b/tvix/nix-compat/src/nar/reader/async/read.rs @@ -0,0 +1,69 @@ +use tokio::io::{ + self, AsyncReadExt, + ErrorKind::{InvalidData, UnexpectedEof}, +}; + +use crate::nar::wire::Tag; + +use super::Reader; + +/// Consume a known token from the reader. +pub async fn token<const N: usize>(reader: &mut Reader<'_>, token: &[u8; N]) -> io::Result<()> { + let mut buf = [0u8; N]; + + // This implements something similar to [AsyncReadExt::read_exact], but verifies that + // the input data matches the token while we read it. These two slices respectively + // represent the remaining token to be verified, and the remaining input buffer. + let mut token = &token[..]; + let mut buf = &mut buf[..]; + + while !token.is_empty() { + match reader.read(buf).await? { + 0 => { + return Err(UnexpectedEof.into()); + } + n => { + let (t, b); + (t, token) = token.split_at(n); + (b, buf) = buf.split_at_mut(n); + + if t != b { + return Err(InvalidData.into()); + } + } + } + } + + Ok(()) +} + +/// Consume a [Tag] from the reader. 
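The `token` helper above deliberately verifies input against the expected token chunk-by-chunk as it reads, rather than buffering with `read_exact` and comparing afterwards, so malformed input fails fast. A simplified synchronous analogue of the same idea (a sketch, not the crate's code):

```rust
use std::io::{
    self,
    ErrorKind::{InvalidData, UnexpectedEof},
    Read,
};

fn expect_token<const N: usize>(reader: &mut impl Read, token: &[u8; N]) -> io::Result<()> {
    let mut token = &token[..];
    let mut buf = [0u8; N];
    let mut buf = &mut buf[..];

    while !token.is_empty() {
        match reader.read(buf)? {
            0 => return Err(UnexpectedEof.into()),
            n => {
                // Compare the chunk we just read against the token prefix.
                let (t, b);
                (t, token) = token.split_at(n);
                (b, buf) = buf.split_at_mut(n);
                if t != b {
                    return Err(InvalidData.into());
                }
            }
        }
    }
    Ok(())
}

fn main() -> io::Result<()> {
    expect_token(&mut &b"nix-archive-1"[..], b"nix-archive-1")
}
```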
+pub async fn tag<T: Tag>(reader: &mut Reader<'_>) -> io::Result<T> { + let mut buf = T::make_buf(); + let buf = buf.as_mut(); + + // first read the known minimum length… + reader.read_exact(&mut buf[..T::MIN]).await?; + + // then decide which tag we're expecting + let tag = T::from_u8(buf[T::OFF]).ok_or(InvalidData)?; + let (head, tail) = tag.as_bytes().split_at(T::MIN); + + // make sure what we've read so far is valid + if buf[..T::MIN] != *head { + return Err(InvalidData.into()); + } + + // …then read the rest, if any + if !tail.is_empty() { + let rest = tail.len(); + reader.read_exact(&mut buf[..rest]).await?; + + // and make sure it's what we expect + if buf[..rest] != *tail { + return Err(InvalidData.into()); + } + } + + Ok(tag) +} diff --git a/tvix/nix-compat/src/nar/reader/async/test.rs b/tvix/nix-compat/src/nar/reader/async/test.rs new file mode 100644 index 000000000000..7bc1f8942f50 --- /dev/null +++ b/tvix/nix-compat/src/nar/reader/async/test.rs @@ -0,0 +1,310 @@ +use tokio::io::AsyncReadExt; + +mod nar { + pub use crate::nar::reader::r#async as reader; +} + +#[tokio::test] +async fn symlink() { + let mut f = std::io::Cursor::new(include_bytes!("../../tests/symlink.nar")); + let node = nar::reader::open(&mut f).await.unwrap(); + + match node { + nar::reader::Node::Symlink { target } => { + assert_eq!( + &b"/nix/store/somewhereelse"[..], + &target, + "target must match" + ); + } + _ => panic!("unexpected type"), + } +} + +#[tokio::test] +async fn file() { + let mut f = std::io::Cursor::new(include_bytes!("../../tests/helloworld.nar")); + let node = nar::reader::open(&mut f).await.unwrap(); + + match node { + nar::reader::Node::File { + executable, + mut reader, + } => { + assert!(!executable); + let mut buf = vec![]; + reader + .read_to_end(&mut buf) + .await + .expect("read must succeed"); + assert_eq!(&b"Hello World!"[..], &buf); + } + _ => panic!("unexpected type"), + } +} + +#[tokio::test] +async fn complicated() { + let mut f = std::io::Cursor::new(include_bytes!("../../tests/complicated.nar")); + let node = nar::reader::open(&mut f).await.unwrap(); + + match node { + nar::reader::Node::Directory(mut dir_reader) => { + // first entry is .keep, an empty regular file. + must_read_file( + ".keep", + dir_reader + .next() + .await + .expect("next must succeed") + .expect("must be some"), + ) + .await; + + // second entry is aa, a symlink to /nix/store/somewhereelse + must_be_symlink( + "aa", + "/nix/store/somewhereelse", + dir_reader + .next() + .await + .expect("next must be some") + .expect("must be some"), + ); + + { + // third entry is a directory called "keep" + let entry = dir_reader + .next() + .await + .expect("next must be some") + .expect("must be some"); + + assert_eq!(b"keep", entry.name); + + match entry.node { + nar::reader::Node::Directory(mut subdir_reader) => { + { + // first entry is .keep, an empty regular file. 
+ let entry = subdir_reader + .next() + .await + .expect("next must succeed") + .expect("must be some"); + + must_read_file(".keep", entry).await; + } + + // we must read the None + assert!( + subdir_reader + .next() + .await + .expect("next must succeed") + .is_none(), + "keep directory contains only .keep" + ); + } + _ => panic!("unexpected type for keep/.keep"), + } + }; + + // reading more entries yields None (and we actually must read until this) + assert!(dir_reader.next().await.expect("must succeed").is_none()); + } + _ => panic!("unexpected type"), + } +} + +#[tokio::test] +#[should_panic] +#[ignore = "TODO: async poisoning"] +async fn file_read_abandoned() { + let mut f = std::io::Cursor::new(include_bytes!("../../tests/complicated.nar")); + let node = nar::reader::open(&mut f).await.unwrap(); + + match node { + nar::reader::Node::Directory(mut dir_reader) => { + // first entry is .keep, an empty regular file. + { + let entry = dir_reader + .next() + .await + .expect("next must succeed") + .expect("must be some"); + + assert_eq!(b".keep", entry.name); + // don't bother to finish reading it. + }; + + // this should panic (not return an error), because we are meant to abandon the archive reader now. + assert!(dir_reader.next().await.expect("must succeed").is_none()); + } + _ => panic!("unexpected type"), + } +} + +#[tokio::test] +#[should_panic] +#[ignore = "TODO: async poisoning"] +async fn dir_read_abandoned() { + let mut f = std::io::Cursor::new(include_bytes!("../../tests/complicated.nar")); + let node = nar::reader::open(&mut f).await.unwrap(); + + match node { + nar::reader::Node::Directory(mut dir_reader) => { + // first entry is .keep, an empty regular file. + must_read_file( + ".keep", + dir_reader + .next() + .await + .expect("next must succeed") + .expect("must be some"), + ) + .await; + + // second entry is aa, a symlink to /nix/store/somewhereelse + must_be_symlink( + "aa", + "/nix/store/somewhereelse", + dir_reader + .next() + .await + .expect("next must be some") + .expect("must be some"), + ); + + { + // third entry is a directory called "keep" + let entry = dir_reader + .next() + .await + .expect("next must be some") + .expect("must be some"); + + assert_eq!(b"keep", entry.name); + + match entry.node { + nar::reader::Node::Directory(_) => { + // don't finish using it, which poisons the archive reader + } + _ => panic!("unexpected type for keep/.keep"), + } + }; + + // this should panic, because we didn't finish reading the child subdirectory + assert!(dir_reader.next().await.expect("must succeed").is_none()); + } + _ => panic!("unexpected type"), + } +} + +#[tokio::test] +#[should_panic] +#[ignore = "TODO: async poisoning"] +async fn dir_read_after_none() { + let mut f = std::io::Cursor::new(include_bytes!("../../tests/complicated.nar")); + let node = nar::reader::open(&mut f).await.unwrap(); + + match node { + nar::reader::Node::Directory(mut dir_reader) => { + // first entry is .keep, an empty regular file. 
+ must_read_file( + ".keep", + dir_reader + .next() + .await + .expect("next must succeed") + .expect("must be some"), + ) + .await; + + // second entry is aa, a symlink to /nix/store/somewhereelse + must_be_symlink( + "aa", + "/nix/store/somewhereelse", + dir_reader + .next() + .await + .expect("next must be some") + .expect("must be some"), + ); + + { + // third entry is a directory called "keep" + let entry = dir_reader + .next() + .await + .expect("next must be some") + .expect("must be some"); + + assert_eq!(b"keep", entry.name); + + match entry.node { + nar::reader::Node::Directory(mut subdir_reader) => { + // first entry is .keep, an empty regular file. + must_read_file( + ".keep", + subdir_reader + .next() + .await + .expect("next must succeed") + .expect("must be some"), + ) + .await; + + // we must read the None + assert!( + subdir_reader + .next() + .await + .expect("next must succeed") + .is_none(), + "keep directory contains only .keep" + ); + } + _ => panic!("unexpected type for keep/.keep"), + } + }; + + // reading more entries yields None (and we actually must read until this) + assert!(dir_reader.next().await.expect("must succeed").is_none()); + + // this should panic, because we already got a none so we're meant to stop. + dir_reader.next().await.unwrap(); + unreachable!() + } + _ => panic!("unexpected type"), + } +} + +async fn must_read_file(name: &'static str, entry: nar::reader::Entry<'_, '_>) { + assert_eq!(name.as_bytes(), entry.name); + + match entry.node { + nar::reader::Node::File { + executable, + mut reader, + } => { + assert!(!executable); + assert_eq!(reader.read(&mut [0]).await.unwrap(), 0); + } + _ => panic!("unexpected type for {}", name), + } +} + +fn must_be_symlink( + name: &'static str, + exp_target: &'static str, + entry: nar::reader::Entry<'_, '_>, +) { + assert_eq!(name.as_bytes(), entry.name); + + match entry.node { + nar::reader::Node::Symlink { target } => { + assert_eq!(exp_target.as_bytes(), &target); + } + _ => panic!("unexpected type for {}", name), + } +} diff --git a/tvix/nix-compat/src/nar/reader/mod.rs b/tvix/nix-compat/src/nar/reader/mod.rs index ecf9c3d78912..eef3b10f3c28 100644 --- a/tvix/nix-compat/src/nar/reader/mod.rs +++ b/tvix/nix-compat/src/nar/reader/mod.rs @@ -16,6 +16,9 @@ use std::marker::PhantomData; // Required reading for understanding this module. use crate::nar::wire; +#[cfg(all(feature = "async", feature = "wire"))] +pub mod r#async; + mod read; #[cfg(test)] mod test; @@ -26,9 +29,11 @@ struct ArchiveReader<'a, 'r> { inner: &'a mut Reader<'r>, /// In debug mode, also track when we need to abandon this archive reader. + /// /// The archive reader must be abandoned when: /// * An error is encountered at any point /// * A file or directory reader is dropped before being read entirely. + /// /// All of these checks vanish in release mode. status: ArchiveReaderStatus<'a>, } @@ -261,11 +266,11 @@ pub struct DirReader<'a, 'r> { reader: ArchiveReader<'a, 'r>, /// Previous directory entry name. /// We have to hang onto this to enforce name monotonicity. - prev_name: Option<Vec<u8>>, + prev_name: Vec<u8>, } pub struct Entry<'a, 'r> { - pub name: Vec<u8>, + pub name: &'a [u8], pub node: Node<'a, 'r>, } @@ -273,7 +278,7 @@ impl<'a, 'r> DirReader<'a, 'r> { fn new(reader: ArchiveReader<'a, 'r>) -> Self { Self { reader, - prev_name: None, + prev_name: vec![], } } @@ -292,7 +297,7 @@ impl<'a, 'r> DirReader<'a, 'r> { // COME FROM the previous iteration: if we've already read an entry, // read its terminating TOK_PAR here. 
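For the common case of a NAR containing a single file, the new async reader can be driven as below. A hedged sketch; the `FileReader` contract above applies, so we read exactly `len()` bytes and would abandon the reader on any error:

```rust
use nix_compat::nar::reader::r#async as nar_reader;
use tokio::io::AsyncReadExt;

/// Extract the contents of a single-file NAR.
async fn read_file_nar(nar_bytes: &[u8]) -> std::io::Result<Vec<u8>> {
    let mut cursor = std::io::Cursor::new(nar_bytes);

    match nar_reader::open(&mut cursor).await? {
        nar_reader::Node::File { mut reader, .. } => {
            // No explicit EOF is signalled; read exactly len() bytes.
            let mut buf = vec![0u8; reader.len() as usize];
            reader.read_exact(&mut buf).await?;
            Ok(buf)
        }
        _ => Err(std::io::ErrorKind::InvalidData.into()),
    }
}
```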
- if self.prev_name.is_some() { + if !self.prev_name.is_empty() { try_or_poison!(self.reader, read::token(self.reader.inner, &wire::TOK_PAR)); } @@ -303,9 +308,10 @@ impl<'a, 'r> DirReader<'a, 'r> { return Ok(None); } + let mut name = [0; wire::MAX_NAME_LEN + 1]; let name = try_or_poison!( self.reader, - read::bytes(self.reader.inner, wire::MAX_NAME_LEN) + read::bytes_buf(self.reader.inner, &mut name, wire::MAX_NAME_LEN) ); if name.is_empty() @@ -319,24 +325,18 @@ impl<'a, 'r> DirReader<'a, 'r> { } // Enforce strict monotonicity of directory entry names. - match &mut self.prev_name { - None => { - self.prev_name = Some(name.clone()); - } - Some(prev_name) => { - if *prev_name >= name { - self.reader.status.poison(); - return Err(InvalidData.into()); - } - - name[..].clone_into(prev_name); - } + if &self.prev_name[..] >= name { + self.reader.status.poison(); + return Err(InvalidData.into()); } + self.prev_name.clear(); + self.prev_name.extend_from_slice(name); + try_or_poison!(self.reader, read::token(self.reader.inner, &wire::TOK_NOD)); Ok(Some(Entry { - name, + name: &self.prev_name, // Don't need to worry about poisoning here: Node::new will do it for us if needed node: Node::new(self.reader.child())?, })) diff --git a/tvix/nix-compat/src/nar/reader/read.rs b/tvix/nix-compat/src/nar/reader/read.rs index 1ce161376424..9938581f2a2e 100644 --- a/tvix/nix-compat/src/nar/reader/read.rs +++ b/tvix/nix-compat/src/nar/reader/read.rs @@ -15,6 +15,38 @@ pub fn u64(reader: &mut Reader) -> io::Result<u64> { Ok(u64::from_le_bytes(buf)) } +/// Consume a byte string from the reader into a provided buffer, +/// returning the data bytes. +pub fn bytes_buf<'a, const N: usize>( + reader: &mut Reader, + buf: &'a mut [u8; N], + max_len: usize, +) -> io::Result<&'a [u8]> { + assert_eq!(N % 8, 0); + assert!(max_len <= N); + + // read the length, and reject excessively large values + let len = self::u64(reader)?; + if len > max_len as u64 { + return Err(InvalidData.into()); + } + // we know the length fits in a usize now + let len = len as usize; + + // read the data and padding into a buffer + let buf_len = (len + 7) & !7; + reader.read_exact(&mut buf[..buf_len])?; + + // verify that the padding is all zeroes + for &b in &buf[len..buf_len] { + if b != 0 { + return Err(InvalidData.into()); + } + } + + Ok(&buf[..len]) +} + /// Consume a byte string of up to `max_len` bytes from the reader. pub fn bytes(reader: &mut Reader, max_len: usize) -> io::Result<Vec<u8>> { assert!(max_len <= isize::MAX as usize); diff --git a/tvix/nix-compat/src/nar/reader/test.rs b/tvix/nix-compat/src/nar/reader/test.rs index 02dc4767c916..63e4fb289ffc 100644 --- a/tvix/nix-compat/src/nar/reader/test.rs +++ b/tvix/nix-compat/src/nar/reader/test.rs @@ -71,7 +71,7 @@ fn complicated() { .expect("next must be some") .expect("must be some"); - assert_eq!(&b"keep"[..], &entry.name); + assert_eq!(b"keep", entry.name); match entry.node { nar::reader::Node::Directory(mut subdir_reader) => { @@ -117,7 +117,7 @@ fn file_read_abandoned() { .expect("next must succeed") .expect("must be some"); - assert_eq!(&b".keep"[..], &entry.name); + assert_eq!(b".keep", entry.name); // don't bother to finish reading it. 
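The new `bytes_buf` above reads a length-prefixed byte string into a caller-provided buffer and validates its padding, sparing the sync reader a `Vec` allocation per directory entry. The wire framing it checks, as a self-contained sketch of the `(len + 7) & !7` rule (the expected bytes match the prefix of the `TOK_ENT` constant further below):

```rust
/// Frame a byte string as NAR does: u64 little-endian length,
/// the bytes, then NUL padding to the next 8-byte boundary.
fn nar_string(s: &[u8]) -> Vec<u8> {
    let mut out = (s.len() as u64).to_le_bytes().to_vec();
    out.extend_from_slice(s);
    let padded_len = (out.len() + 7) & !7;
    out.resize(padded_len, 0);
    out
}

fn main() {
    // "entry" framed exactly as in TOK_ENT: length 5, data, 3 NULs.
    assert_eq!(nar_string(b"entry"), b"\x05\0\0\0\0\0\0\0entry\0\0\0");
}
```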
}; @@ -162,7 +162,7 @@ fn dir_read_abandoned() { .expect("next must be some") .expect("must be some"); - assert_eq!(&b"keep"[..], &entry.name); + assert_eq!(b"keep", entry.name); match entry.node { nar::reader::Node::Directory(_) => { @@ -213,7 +213,7 @@ fn dir_read_after_none() { .expect("next must be some") .expect("must be some"); - assert_eq!(&b"keep"[..], &entry.name); + assert_eq!(b"keep", entry.name); match entry.node { nar::reader::Node::Directory(mut subdir_reader) => { @@ -248,7 +248,7 @@ fn dir_read_after_none() { } fn must_read_file(name: &'static str, entry: nar::reader::Entry<'_, '_>) { - assert_eq!(name.as_bytes(), &entry.name); + assert_eq!(name.as_bytes(), entry.name); match entry.node { nar::reader::Node::File { @@ -267,7 +267,7 @@ fn must_be_symlink( exp_target: &'static str, entry: nar::reader::Entry<'_, '_>, ) { - assert_eq!(name.as_bytes(), &entry.name); + assert_eq!(name.as_bytes(), entry.name); match entry.node { nar::reader::Node::Symlink { target } => { diff --git a/tvix/nix-compat/src/nar/tests/nixos-release.ls b/tvix/nix-compat/src/nar/tests/nixos-release.ls new file mode 100644 index 000000000000..9dd350b7cf86 --- /dev/null +++ b/tvix/nix-compat/src/nar/tests/nixos-release.ls @@ -0,0 +1 @@ +{"root":{"entries":{"iso":{"entries":{"nixos-minimal-new-kernel-no-zfs-24.11pre660688.bee6b69aad74-x86_64-linux.iso":{"narOffset":440,"size":1051721728,"type":"regular"}},"type":"directory"},"nix-support":{"entries":{"hydra-build-products":{"narOffset":1051722544,"size":211,"type":"regular"},"system":{"narOffset":1051722944,"size":13,"type":"regular"}},"type":"directory"}},"type":"directory"},"version":1} \ No newline at end of file diff --git a/tvix/nix-compat/src/nar/wire/mod.rs b/tvix/nix-compat/src/nar/wire/mod.rs index b9e021249543..67654129ee1d 100644 --- a/tvix/nix-compat/src/nar/wire/mod.rs +++ b/tvix/nix-compat/src/nar/wire/mod.rs @@ -39,7 +39,7 @@ //! TOK_NAR ::= "nix-archive-1" "(" "type" //! TOK_SYM ::= "symlink" "target" //! TOK_REG ::= "regular" "contents" -//! TOK_EXE ::= "regular" "executable" "" +//! TOK_EXE ::= "regular" "executable" "" "contents" //! TOK_DIR ::= "directory" //! TOK_ENT ::= "entry" "(" "name" //! 
TOK_NOD ::= "node" "(" "type" @@ -90,6 +90,25 @@ pub const TOK_DIR: [u8; 24] = *b"\x09\0\0\0\0\0\0\0directory\0\0\0\0\0\0\0"; pub const TOK_ENT: [u8; 48] = *b"\x05\0\0\0\0\0\0\0entry\0\0\0\x01\0\0\0\0\0\0\0(\0\0\0\0\0\0\0\x04\0\0\0\0\0\0\0name\0\0\0\0"; pub const TOK_NOD: [u8; 48] = *b"\x04\0\0\0\0\0\0\0node\0\0\0\0\x01\0\0\0\0\0\0\0(\0\0\0\0\0\0\0\x04\0\0\0\0\0\0\0type\0\0\0\0"; pub const TOK_PAR: [u8; 16] = *b"\x01\0\0\0\0\0\0\0)\0\0\0\0\0\0\0"; +#[cfg(feature = "async")] +#[allow(dead_code)] +const TOK_PAD_PAR: [u8; 24] = *b"\0\0\0\0\0\0\0\0\x01\0\0\0\0\0\0\0)\0\0\0\0\0\0\0"; + +#[cfg(feature = "async")] +#[allow(dead_code)] +#[derive(Debug)] +pub(crate) enum PadPar {} + +#[cfg(all(feature = "async", feature = "wire"))] +impl crate::wire::reader::Tag for PadPar { + const PATTERN: &'static [u8] = &TOK_PAD_PAR; + + type Buf = [u8; 24]; + + fn make_buf() -> Self::Buf { + [0; 24] + } +} #[test] fn tokens() { @@ -102,6 +121,8 @@ fn tokens() { (&TOK_ENT, &["entry", "(", "name"]), (&TOK_NOD, &["node", "(", "type"]), (&TOK_PAR, &[")"]), + #[cfg(feature = "async")] + (&TOK_PAD_PAR, &["", ")"]), ]; for &(tok, xs) in cases { diff --git a/tvix/nix-compat/src/nar/writer/sync.rs b/tvix/nix-compat/src/nar/writer/sync.rs index 6270129028fa..b441479ac60b 100644 --- a/tvix/nix-compat/src/nar/writer/sync.rs +++ b/tvix/nix-compat/src/nar/writer/sync.rs @@ -35,11 +35,8 @@ use std::io::{ Write, }; -/// Convenience type alias for types implementing [`Write`]. -pub type Writer<'a> = dyn Write + Send + 'a; - /// Create a new NAR, writing the output to the specified writer. -pub fn open<'a, 'w: 'a>(writer: &'a mut Writer<'w>) -> io::Result<Node<'a, 'w>> { +pub fn open<W: Write>(writer: &mut W) -> io::Result<Node<W>> { let mut node = Node { writer }; node.write(&wire::TOK_NAR)?; Ok(node) @@ -49,11 +46,11 @@ pub fn open<'a, 'w: 'a>(writer: &'a mut Writer<'w>) -> io::Result<Node<'a, 'w>> /// /// A NAR can be thought of as a tree of nodes represented by this type. Each /// node can be a file, a symlink or a directory containing other nodes. -pub struct Node<'a, 'w: 'a> { - writer: &'a mut Writer<'w>, +pub struct Node<'a, W: Write> { + writer: &'a mut W, } -impl<'a, 'w> Node<'a, 'w> { +impl<'a, W: Write> Node<'a, W> { fn write(&mut self, data: &[u8]) -> io::Result<()> { self.writer.write_all(data) } @@ -123,12 +120,59 @@ impl<'a, 'w> Node<'a, 'w> { Ok(()) } + /// Make this node a single file but let the user handle the writing of the file contents. + /// The user gets access to a writer to write the file contents to, plus a struct they must + /// invoke a function on to finish writing the NAR file. + /// + /// It is the caller's responsibility to write the correct number of bytes to the writer and + /// invoke [`FileManualWrite::close`], or invalid archives will be produced silently. + /// + /// ```rust + /// # use std::io::BufReader; + /// # use std::io::Write; + /// # + /// # // Output location to write the NAR to. + /// # let mut sink: Vec<u8> = Vec::new(); + /// # + /// # // Instantiate writer for this output location. 
+ /// # let mut nar = nix_compat::nar::writer::open(&mut sink)?; + /// # + /// let contents = "Hello world\n".as_bytes(); + /// let size = contents.len() as u64; + /// let executable = false; + /// + /// let (writer, skip) = nar + /// .file_manual_write(executable, size)?; + /// + /// // Write the contents + /// writer.write_all(&contents)?; + /// + /// // Close the file node + /// skip.close(writer)?; + /// # Ok::<(), std::io::Error>(()) + /// ``` + pub fn file_manual_write( + mut self, + executable: bool, + size: u64, + ) -> io::Result<(&'a mut W, FileManualWrite)> { + self.write(if executable { + &wire::TOK_EXE + } else { + &wire::TOK_REG + })?; + + self.write(&size.to_le_bytes())?; + + Ok((self.writer, FileManualWrite { size })) + } + /// Make this node a directory, the content of which is set using the /// resulting [`Directory`] value. /// /// It is the caller's responsibility to invoke [`Directory::close`], /// or invalid archives will be produced silently. - pub fn directory(mut self) -> io::Result<Directory<'a, 'w>> { + pub fn directory(mut self) -> io::Result<Directory<'a, W>> { self.write(&wire::TOK_DIR)?; Ok(Directory::new(self)) } @@ -145,13 +189,13 @@ fn into_name(_name: &[u8]) -> Name { } /// Content of a NAR node that represents a directory. -pub struct Directory<'a, 'w> { - node: Node<'a, 'w>, +pub struct Directory<'a, W: Write> { + node: Node<'a, W>, prev_name: Option<Name>, } -impl<'a, 'w> Directory<'a, 'w> { - fn new(node: Node<'a, 'w>) -> Self { +impl<'a, W: Write> Directory<'a, W> { + fn new(node: Node<'a, W>) -> Self { Self { node, prev_name: None, @@ -166,7 +210,7 @@ impl<'a, 'w> Directory<'a, 'w> { /// It is the caller's responsibility to ensure that directory entries are /// written in order of ascending name. If this is not ensured, this method /// may panic or silently produce invalid archives. - pub fn entry(&mut self, name: &[u8]) -> io::Result<Node<'_, 'w>> { + pub fn entry(&mut self, name: &[u8]) -> io::Result<Node<'_, W>> { debug_assert!( name.len() <= wire::MAX_NAME_LEN, "name.len() > {}", @@ -222,3 +266,24 @@ impl<'a, 'w> Directory<'a, 'w> { Ok(()) } } + +/// Content of a NAR node that represents a file whose contents are being written out manually. +/// Returned by the `file_manual_write` function. +#[must_use] +pub struct FileManualWrite { + size: u64, +} + +impl FileManualWrite { + /// Finish writing the file structure to the NAR after having manually written the file contents. + /// + /// **Important:** This *must* be called with the writer returned by file_manual_write after + /// the file contents have been manually and fully written. Otherwise the resulting NAR file + /// will be invalid. 
+ pub fn close<W: Write>(self, writer: &mut W) -> io::Result<()> { + let mut node = Node { writer }; + node.pad(self.size)?; + node.write(&wire::TOK_PAR)?; + Ok(()) + } +} diff --git a/tvix/nix-compat/src/narinfo/mod.rs b/tvix/nix-compat/src/narinfo/mod.rs index b1c10bceb200..35146a927b39 100644 --- a/tvix/nix-compat/src/narinfo/mod.rs +++ b/tvix/nix-compat/src/narinfo/mod.rs @@ -27,13 +27,15 @@ use std::{ use crate::{nixbase32, nixhash::CAHash, store_path::StorePathRef}; mod fingerprint; -mod public_keys; mod signature; +mod signing_keys; +mod verifying_keys; pub use fingerprint::fingerprint; - -pub use public_keys::{Error as PubKeyError, PubKey}; -pub use signature::{Error as SignatureError, Signature}; +pub use signature::{Error as SignatureError, Signature, SignatureRef}; +pub use signing_keys::parse_keypair; +pub use signing_keys::{Error as SigningKeyError, SigningKey}; +pub use verifying_keys::{Error as VerifyingKeyError, VerifyingKey}; #[derive(Debug)] pub struct NarInfo<'a> { @@ -49,7 +51,7 @@ pub struct NarInfo<'a> { pub references: Vec<StorePathRef<'a>>, // authenticity /// Ed25519 signature over the path fingerprint - pub signatures: Vec<Signature<'a>>, + pub signatures: Vec<SignatureRef<'a>>, /// Content address (for content-defined paths) pub ca: Option<CAHash>, // derivation metadata @@ -244,7 +246,7 @@ impl<'a> NarInfo<'a> { }; } "Sig" => { - let val = Signature::parse(val) + let val = SignatureRef::parse(val) .map_err(|e| Error::UnableToParseSignature(signatures.len(), e))?; signatures.push(val); @@ -297,6 +299,21 @@ impl<'a> NarInfo<'a> { self.references.iter(), ) } + + /// Adds a signature, using the passed signer to sign. + /// This is generic over algo implementations / providers, + /// so users can bring their own signers. + pub fn add_signature<S>(&mut self, signer: &'a SigningKey<S>) + where + S: ed25519::signature::Signer<ed25519::Signature>, + { + // calculate the fingerprint to sign + let fp = self.fingerprint(); + + let sig = signer.sign(fp.as_bytes()); + + self.signatures.push(sig); + } } impl Display for NarInfo<'_> { @@ -392,10 +409,16 @@ pub enum Error { } #[cfg(test)] +const DUMMY_KEYPAIR: &str = "cache.example.com-1:cCta2MEsRNuYCgWYyeRXLyfoFpKhQJKn8gLMeXWAb7vIpRKKo/3JoxJ24OYa3DxT2JVV38KjK/1ywHWuMe2JEw=="; +#[cfg(test)] +const DUMMY_VERIFYING_KEY: &str = + "cache.example.com-1:yKUSiqP9yaMSduDmGtw8U9iVVd/Coyv9csB1rjHtiRM="; + +#[cfg(test)] mod test { use hex_literal::hex; - use lazy_static::lazy_static; use pretty_assertions::assert_eq; + use std::sync::LazyLock; use std::{io, str}; use crate::{ @@ -405,20 +428,18 @@ mod test { use super::{Flags, NarInfo}; - lazy_static! 
{
- static ref CASES: &'static [&'static str] = {
- let data = zstd::decode_all(io::Cursor::new(include_bytes!(
- "../../testdata/narinfo.zst"
- )))
- .unwrap();
- let data = str::from_utf8(Vec::leak(data)).unwrap();
- Vec::leak(
- data.split_inclusive("\n\n")
- .map(|s| s.strip_suffix('\n').unwrap())
- .collect::<Vec<_>>(),
- )
- };
- }
+ static CASES: LazyLock<&'static [&'static str]> = LazyLock::new(|| {
+ let data = zstd::decode_all(io::Cursor::new(include_bytes!(
+ "../../testdata/narinfo.zst"
+ )))
+ .unwrap();
+ let data = str::from_utf8(Vec::leak(data)).unwrap();
+ Vec::leak(
+ data.split_inclusive("\n\n")
+ .map(|s| s.strip_suffix('\n').unwrap())
+ .collect::<Vec<_>>(),
+ )
+ });

 #[test]
 fn roundtrip() {
@@ -524,4 +545,46 @@ Sig: cache.nixos.org-1:HhaiY36Uk3XV1JGe9d9xHnzAapqJXprU1YZZzSzxE97jCuO5RR7vlG2kF
 parsed.nar_hash,
 );
 }
+
+ /// Adds a signature to a NARInfo, using key material parsed from DUMMY_KEYPAIR.
+ /// It then ensures signature verification with the parsed
+ /// DUMMY_VERIFYING_KEY succeeds.
+ #[test]
+ fn sign() {
+ let mut narinfo = NarInfo::parse(
+ r#"StorePath: /nix/store/0vpqfxbkx0ffrnhbws6g9qwhmliksz7f-perl-HTTP-Cookies-6.01
URL: nar/0i5biw0g01514llhfswxy6xfav8lxxdq1xg6ik7hgsqbpw0f06yi.nar.xz
Compression: xz
FileHash: sha256:0i5biw0g01514llhfswxy6xfav8lxxdq1xg6ik7hgsqbpw0f06yi
FileSize: 7120
NarHash: sha256:0h1bm4sj1cnfkxgyhvgi8df1qavnnv94sd0v09wcrm971602shfg
NarSize: 22552
References:
CA: fixed:r:sha1:1ak1ymbmsfx7z8kh09jzkr3a4dvkrfjw
"#,
+ )
+ .expect("should parse");
+
+ let fp = narinfo.fingerprint();
+
+ // load our keypair from the fixtures
+ let (signing_key, _verifying_key) =
+ super::parse_keypair(super::DUMMY_KEYPAIR).expect("must succeed");
+
+ // add signature
+ narinfo.add_signature(&signing_key);
+
+ // ensure the signature is added
+ let new_sig = narinfo.signatures.last().unwrap();
+ assert_eq!(signing_key.name(), *new_sig.name());
+
+ // verify the new signature against the verifying key
+ let verifying_key = super::VerifyingKey::parse(super::DUMMY_VERIFYING_KEY)
+ .expect("parsing dummy verifying key");
+
+ assert!(
+ verifying_key.verify(&fp, new_sig),
+ "expect signature to be valid"
+ );
+ }
}
diff --git a/tvix/nix-compat/src/narinfo/signature.rs b/tvix/nix-compat/src/narinfo/signature.rs
index fd197e771d98..2005a5cb60df 100644
--- a/tvix/nix-compat/src/narinfo/signature.rs
+++ b/tvix/nix-compat/src/narinfo/signature.rs
@@ -1,21 +1,44 @@
-use std::fmt::{self, Display};
+use std::{
+ fmt::{self, Display},
+ ops::Deref,
+};

 use data_encoding::BASE64;
-use ed25519_dalek::SIGNATURE_LENGTH;
 use serde::{Deserialize, Serialize};

+const SIGNATURE_LENGTH: usize = std::mem::size_of::<ed25519::SignatureBytes>();
+
 #[derive(Clone, Debug, Eq, PartialEq)]
-pub struct Signature<'a> {
- name: &'a str,
- bytes: [u8; SIGNATURE_LENGTH],
+pub struct Signature<S> {
+ name: S,
+ bytes: ed25519::SignatureBytes,
 }

-impl<'a> Signature<'a> {
- pub fn new(name: &'a str, bytes: [u8; SIGNATURE_LENGTH]) -> Self {
+/// Type alias of a [Signature] using a `&str` as `name` field.
+pub type SignatureRef<'a> = Signature<&'a str>;
+
+/// Represents the signatures that Nix emits.
+/// It consists of a name (an identifier for a public key), and an ed25519
+/// signature (64 bytes).
+/// It is generic over the string type that's used for the name, and there's
+/// [SignatureRef] as a type alias for one containing &str.
+impl<S> Signature<S>
+where
+ S: Deref<Target = str>,
+{
+ /// Constructs a new [Signature] from a name and the 64 signature bytes.
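+ ///
+ /// A quick sketch of constructing one directly (hypothetical name,
+ /// all-zero signature bytes):
+ ///
+ /// ```ignore
+ /// let sig = SignatureRef::new("cache.example.org-1", [0u8; 64]);
+ /// ```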
+ pub fn new(name: S, bytes: ed25519::SignatureBytes) -> Self {
 Self { name, bytes }
 }

- pub fn parse(input: &'a str) -> Result<Self, Error> {
+ /// Parses a [Signature] from a string containing the name, a colon, and 64
+ /// base64-encoded bytes (plus padding).
+ /// These strings are commonly seen in the `Signature:` field of a NARInfo
+ /// file.
+ pub fn parse<'a>(input: &'a str) -> Result<Self, Error>
+ where
+ S: From<&'a str>,
+ {
 let (name, bytes64) = input.split_once(':').ok_or(Error::MissingSeparator)?;

 if name.is_empty()
@@ -39,14 +62,19 @@ impl<'a> Signature<'a> {
 Err(_) => return Err(Error::DecodeError(input.to_string())),
 }

- Ok(Signature { name, bytes })
+ Ok(Self {
+ name: name.into(),
+ bytes,
+ })
 }

- pub fn name(&self) -> &'a str {
- self.name
+ /// Returns the name field of the signature.
+ pub fn name(&self) -> &S {
+ &self.name
 }

- pub fn bytes(&self) -> &[u8; SIGNATURE_LENGTH] {
+ /// Returns the 64 signature bytes.
+ pub fn bytes(&self) -> &ed25519::SignatureBytes {
 &self.bytes
 }

@@ -56,9 +84,27 @@ impl<'a> Signature<'a> {
 verifying_key.verify_strict(fingerprint, &signature).is_ok()
 }
+
+ /// Constructs a [SignatureRef] from this signature.
+ pub fn as_ref(&self) -> SignatureRef<'_> {
+ SignatureRef {
+ name: self.name.deref(),
+ bytes: self.bytes,
+ }
+ }
+
+ /// Constructs an owned [`Signature<String>`] from this signature, cloning the name.
+ pub fn to_owned(&self) -> Signature<String> {
+ Signature {
+ name: self.name.to_string(),
+ bytes: self.bytes,
+ }
+ }
 }

-impl<'de: 'a, 'a> Deserialize<'de> for Signature<'a> {
+impl<'a, 'de, S> Deserialize<'de> for Signature<S>
+where
+ S: Deref<Target = str> + From<&'a str>,
+ 'de: 'a,
+{
 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 where
 D: serde::Deserializer<'de>,
@@ -70,10 +116,13 @@ impl<'de: 'a, 'a> Deserialize<'de> for Signature<'a> {
 }
 }

-impl<'a> Serialize for Signature<'a> {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+impl<S: Display> Serialize for Signature<S>
+where
+ S: Deref<Target = str>,
+{
+ fn serialize<SR>(&self, serializer: SR) -> Result<SR::Ok, SR::Error>
 where
- S: serde::Serializer,
+ SR: serde::Serializer,
 {
 let string: String = self.to_string();
@@ -81,7 +130,26 @@
 }
 }

-#[derive(Debug, thiserror::Error)]
+impl<S> Display for Signature<S>
+where
+ S: Display,
+{
+ fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
+ write!(w, "{}:{}", self.name, BASE64.encode(&self.bytes))
+ }
+}
+
+impl<S> std::hash::Hash for Signature<S>
+where
+ S: AsRef<str>,
+{
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write(self.name.as_ref().as_bytes());
+ state.write(&self.bytes);
+ }
+}
+
+#[derive(Debug, thiserror::Error, PartialEq, Eq)]
 pub enum Error {
 #[error("Invalid name: {0}")]
 InvalidName(String),
@@ -93,43 +161,29 @@ pub enum Error {
 DecodeError(String),
 }

-impl Display for Signature<'_> {
- fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result {
- write!(w, "{}:{}", self.name, BASE64.encode(&self.bytes))
- }
-}
-
 #[cfg(test)]
 mod test {
 use data_encoding::BASE64;
 use ed25519_dalek::VerifyingKey;
 use hex_literal::hex;
- use lazy_static::lazy_static;
+ use std::sync::LazyLock;

 use super::Signature;
 use rstest::rstest;

 const FINGERPRINT: &str =
"1;/nix/store/syd87l2rxw8cbsxmxl853h0r6pdwhwjr-curl-7.82.0-bin;sha256:1b4sb93wp679q4zx9k1ignby1yna3z7c4c2ri3wphylbc2dwsys0;196040;/nix/store/0jqd0rlxzra1rs38rdxl43yh6rxchgc6-curl-7.82.0,/nix/store/6w8g7njm4mck5dmjxws0z1xnrxvl81xa-glibc-2.34-115,/nix/store/j5jxw3iy7bbz4a57fh9g2xm2gxmyal8h-zlib-1.2.12,/nix/store/yxvjs9drzsphm9pcf42a4byzj1kb9m7k-openssl-1.1.1n"; - // The signing key labelled as `cache.nixos.org-1`, - lazy_static! { - static ref PUB_CACHE_NIXOS_ORG_1: VerifyingKey = ed25519_dalek::VerifyingKey::from_bytes( + /// The signing key labelled as `cache.nixos.org-1`, + static PUB_CACHE_NIXOS_ORG_1: LazyLock<VerifyingKey> = LazyLock::new(|| { + ed25519_dalek::VerifyingKey::from_bytes( BASE64 .decode(b"6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY=") .unwrap()[..] .try_into() - .unwrap() - ) - .unwrap(); - static ref PUB_TEST_1: VerifyingKey = ed25519_dalek::VerifyingKey::from_bytes( - BASE64 - .decode(b"tLAEn+EeaBUJYqEpTd2yeerr7Ic6+0vWe+aXL/vYUpE=") - .unwrap()[..] - .try_into() - .unwrap() + .unwrap(), ) - .unwrap(); - } + .expect("embedded public key is valid") + }); #[rstest] #[case::valid_cache_nixos_org_1(&PUB_CACHE_NIXOS_ORG_1, &"cache.nixos.org-1:TsTTb3WGTZKphvYdBHXwo6weVILmTytUjLB+vcX89fOjjRicCHmKA4RCPMVLkj6TMJ4GMX3HPVWRdD1hkeKZBQ==", FINGERPRINT, true)] @@ -143,7 +197,7 @@ mod test { #[case] fp: &str, #[case] expect_valid: bool, ) { - let sig = Signature::parse(sig_str).expect("must parse"); + let sig = Signature::<&str>::parse(sig_str).expect("must parse"); assert_eq!(expect_valid, sig.verify(fp.as_bytes(), verifying_key)); } @@ -158,7 +212,7 @@ mod test { "u01BybwQhyI5H1bW1EIWXssMDhDDIvXOG5uh8Qzgdyjz6U1qg6DHhMAvXZOUStIj6X5t4/ufFgR8i3fjf0bMAw==" )] fn parse_fail(#[case] input: &'static str) { - Signature::parse(input).expect_err("must fail"); + Signature::<&str>::parse(input).expect_err("must fail"); } #[test] @@ -177,8 +231,29 @@ mod test { let serialized = serde_json::to_string(&signature_actual).expect("must serialize"); assert_eq!(signature_str_json, &serialized); - let deserialized: Signature<'_> = + let deserialized: Signature<&str> = serde_json::from_str(signature_str_json).expect("must deserialize"); assert_eq!(&signature_actual, &deserialized); } + + /// Construct a [Signature], using different String types for the name field. + #[test] + fn signature_owned() { + let signature1 = Signature::<String>::parse("cache.nixos.org-1:TsTTb3WGTZKphvYdBHXwo6weVILmTytUjLB+vcX89fOjjRicCHmKA4RCPMVLkj6TMJ4GMX3HPVWRdD1hkeKZBQ==").expect("must parse"); + let signature2 = Signature::<smol_str::SmolStr>::parse("cache.nixos.org-1:TsTTb3WGTZKphvYdBHXwo6weVILmTytUjLB+vcX89fOjjRicCHmKA4RCPMVLkj6TMJ4GMX3HPVWRdD1hkeKZBQ==").expect("must parse"); + let signature3 = Signature::<&str>::parse("cache.nixos.org-1:TsTTb3WGTZKphvYdBHXwo6weVILmTytUjLB+vcX89fOjjRicCHmKA4RCPMVLkj6TMJ4GMX3HPVWRdD1hkeKZBQ==").expect("must parse"); + + assert!( + signature1.verify(FINGERPRINT.as_bytes(), &PUB_CACHE_NIXOS_ORG_1), + "must verify" + ); + assert!( + signature2.verify(FINGERPRINT.as_bytes(), &PUB_CACHE_NIXOS_ORG_1), + "must verify" + ); + assert!( + signature3.verify(FINGERPRINT.as_bytes(), &PUB_CACHE_NIXOS_ORG_1), + "must verify" + ); + } } diff --git a/tvix/nix-compat/src/narinfo/signing_keys.rs b/tvix/nix-compat/src/narinfo/signing_keys.rs new file mode 100644 index 000000000000..cf513b7ba475 --- /dev/null +++ b/tvix/nix-compat/src/narinfo/signing_keys.rs @@ -0,0 +1,119 @@ +//! This module provides tooling to parse private key (pairs) produced by Nix +//! and its +//! 
`nix-store --generate-binary-cache-key name path.secret path.pub` command.
+//! It produces `ed25519_dalek` keys, but the `NarInfo::add_signature` function
+//! is generic, allowing other signers.
+
+use data_encoding::BASE64;
+use ed25519_dalek::{PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH};
+
+use super::{SignatureRef, VerifyingKey};
+
+pub struct SigningKey<S> {
+ name: String,
+ signing_key: S,
+}
+
+impl<S> SigningKey<S>
+where
+ S: ed25519::signature::Signer<ed25519::Signature>,
+{
+ /// Constructs a signing key, using a name and a signing key.
+ pub fn new(name: String, signing_key: S) -> Self {
+ Self { name, signing_key }
+ }
+
+ /// Signs a fingerprint using the internal signing key, returns the [SignatureRef].
+ pub(crate) fn sign<'a>(&'a self, fp: &[u8]) -> SignatureRef<'a> {
+ SignatureRef::new(&self.name, self.signing_key.sign(fp).to_bytes())
+ }
+
+ pub fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+/// Parses a SigningKey / VerifyingKey from a string in the format that Nix uses.
+pub fn parse_keypair(
+ input: &str,
+) -> Result<(SigningKey<ed25519_dalek::SigningKey>, VerifyingKey), Error> {
+ let (name, bytes64) = input.split_once(':').ok_or(Error::MissingSeparator)?;
+
+ if name.is_empty()
+ || !name
+ .chars()
+ .all(|c| char::is_alphanumeric(c) || c == '-' || c == '.')
+ {
+ return Err(Error::InvalidName(name.to_string()));
+ }
+
+ const DECODED_BYTES_LEN: usize = SECRET_KEY_LENGTH + PUBLIC_KEY_LENGTH;
+ if bytes64.len() != BASE64.encode_len(DECODED_BYTES_LEN) {
+ return Err(Error::InvalidSigningKeyLen(bytes64.len()));
+ }
+
+ let mut buf = [0; DECODED_BYTES_LEN + 2]; // 64 bytes + 2 bytes padding
+ let mut bytes = [0; DECODED_BYTES_LEN];
+ match BASE64.decode_mut(bytes64.as_bytes(), &mut buf) {
+ Ok(len) if len == DECODED_BYTES_LEN => {
+ bytes.copy_from_slice(&buf[..DECODED_BYTES_LEN]);
+ }
+ Ok(_) => unreachable!(),
+ // keeping DecodePartial gets annoying lifetime-wise
+ Err(_) => return Err(Error::DecodeError(input.to_string())),
+ }
+
+ let bytes_signing_key: [u8; SECRET_KEY_LENGTH] = {
+ let mut b = [0u8; SECRET_KEY_LENGTH];
+ b.copy_from_slice(&bytes[0..SECRET_KEY_LENGTH]);
+ b
+ };
+ let bytes_verifying_key: [u8; PUBLIC_KEY_LENGTH] = {
+ let mut b = [0u8; PUBLIC_KEY_LENGTH];
+ b.copy_from_slice(&bytes[SECRET_KEY_LENGTH..]);
+ b
+ };
+
+ let signing_key = SigningKey::new(
+ name.to_string(),
+ ed25519_dalek::SigningKey::from_bytes(&bytes_signing_key),
+ );
+
+ let verifying_key = VerifyingKey::new(
+ name.to_string(),
+ ed25519_dalek::VerifyingKey::from_bytes(&bytes_verifying_key)
+ .map_err(Error::InvalidVerifyingKey)?,
+ );
+
+ Ok((signing_key, verifying_key))
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("Invalid name: {0}")]
+ InvalidName(String),
+ #[error("Missing separator")]
+ MissingSeparator,
+ #[error("Invalid signing key len: {0}")]
+ InvalidSigningKeyLen(usize),
+ #[error("Unable to base64-decode signing key: {0}")]
+ DecodeError(String),
+ #[error("VerifyingKey error: {0}")]
+ InvalidVerifyingKey(ed25519_dalek::SignatureError),
+}
+
+#[cfg(test)]
+mod test {
+ use crate::narinfo::DUMMY_KEYPAIR;
+ #[test]
+ fn parse() {
+ let (_signing_key, _verifying_key) =
+ super::parse_keypair(DUMMY_KEYPAIR).expect("must succeed");
+ }
+
+ #[test]
+ fn parse_fail() {
+ assert!(super::parse_keypair("cache.example.com-1:cCta2MEsRNuYCgWYyeRXLyfoFpKhQJKn8gLMeXWAb7vIpRKKo/3JoxJ24OYa3DxT2JVV38KjK/1ywHWuMe2JE").is_err());
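+ // this one is missing the ':' separator between the name and the key material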
assert!(super::parse_keypair("cache.example.com-1cCta2MEsRNuYCgWYyeRXLyfoFpKhQJKn8gLMeXWAb7vIpRKKo/3JoxJ24OYa3DxT2JVV38KjK/1ywHWuMe2JE").is_err()); + } +} diff --git a/tvix/nix-compat/src/narinfo/public_keys.rs b/tvix/nix-compat/src/narinfo/verifying_keys.rs index 27dd90e096db..67ef2e3a459c 100644 --- a/tvix/nix-compat/src/narinfo/public_keys.rs +++ b/tvix/nix-compat/src/narinfo/verifying_keys.rs @@ -4,21 +4,21 @@ use std::fmt::Display; use data_encoding::BASE64; -use ed25519_dalek::{VerifyingKey, PUBLIC_KEY_LENGTH}; +use ed25519_dalek::PUBLIC_KEY_LENGTH; -use super::Signature; +use super::SignatureRef; /// This represents a ed25519 public key and "name". /// These are normally passed in the `trusted-public-keys` Nix config option, /// and consist of a name and base64-encoded ed25519 pubkey, separated by a `:`. -#[derive(Debug)] -pub struct PubKey { +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct VerifyingKey { name: String, - verifying_key: VerifyingKey, + verifying_key: ed25519_dalek::VerifyingKey, } -impl PubKey { - pub fn new(name: String, verifying_key: VerifyingKey) -> Self { +impl VerifyingKey { + pub fn new(name: String, verifying_key: ed25519_dalek::VerifyingKey) -> Self { Self { name, verifying_key, @@ -37,7 +37,7 @@ impl PubKey { } if bytes64.len() != BASE64.encode_len(PUBLIC_KEY_LENGTH) { - return Err(Error::InvalidPubKeyLen(bytes64.len())); + return Err(Error::InvalidVerifyingKeyLen(bytes64.len())); } let mut buf = [0; PUBLIC_KEY_LENGTH + 1]; @@ -51,7 +51,8 @@ impl PubKey { Err(_) => return Err(Error::DecodeError(input.to_string())), } - let verifying_key = VerifyingKey::from_bytes(&bytes).map_err(Error::InvalidVerifyingKey)?; + let verifying_key = + ed25519_dalek::VerifyingKey::from_bytes(&bytes).map_err(Error::InvalidVerifyingKey)?; Ok(Self { name: name.to_string(), @@ -68,8 +69,8 @@ impl PubKey { /// which means the name in the signature has to match, /// and the signature bytes themselves need to be a valid signature made by /// the signing key identified by [Self::verifying key]. 
- pub fn verify(&self, fingerprint: &str, signature: &Signature) -> bool { - if self.name() != signature.name() { + pub fn verify(&self, fingerprint: &str, signature: &SignatureRef<'_>) -> bool { + if self.name() != *signature.name() { return false; } @@ -84,14 +85,14 @@ pub enum Error { #[error("Missing separator")] MissingSeparator, #[error("Invalid pubkey len: {0}")] - InvalidPubKeyLen(usize), + InvalidVerifyingKeyLen(usize), #[error("VerifyingKey error: {0}")] InvalidVerifyingKey(ed25519_dalek::SignatureError), #[error("Unable to base64-decode pubkey: {0}")] DecodeError(String), } -impl Display for PubKey { +impl Display for VerifyingKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, @@ -108,9 +109,9 @@ mod test { use ed25519_dalek::PUBLIC_KEY_LENGTH; use rstest::rstest; - use crate::narinfo::Signature; + use crate::narinfo::SignatureRef; - use super::PubKey; + use super::VerifyingKey; const FINGERPRINT: &str = "1;/nix/store/syd87l2rxw8cbsxmxl853h0r6pdwhwjr-curl-7.82.0-bin;sha256:1b4sb93wp679q4zx9k1ignby1yna3z7c4c2ri3wphylbc2dwsys0;196040;/nix/store/0jqd0rlxzra1rs38rdxl43yh6rxchgc6-curl-7.82.0,/nix/store/6w8g7njm4mck5dmjxws0z1xnrxvl81xa-glibc-2.34-115,/nix/store/j5jxw3iy7bbz4a57fh9g2xm2gxmyal8h-zlib-1.2.12,/nix/store/yxvjs9drzsphm9pcf42a4byzj1kb9m7k-openssl-1.1.1n"; #[rstest] @@ -122,7 +123,7 @@ mod test { #[case] exp_name: &'static str, #[case] exp_verifying_key_bytes: &[u8; PUBLIC_KEY_LENGTH], ) { - let pubkey = PubKey::parse(input).expect("must parse"); + let pubkey = VerifyingKey::parse(input).expect("must parse"); assert_eq!(exp_name, pubkey.name()); assert_eq!(exp_verifying_key_bytes, pubkey.verifying_key.as_bytes()); } @@ -132,7 +133,7 @@ mod test { #[case::missing_padding("cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY")] #[case::wrong_length("cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDS")] fn parse_fail(#[case] input: &'static str) { - PubKey::parse(input).expect_err("must fail"); + VerifyingKey::parse(input).expect_err("must fail"); } #[rstest] @@ -144,8 +145,8 @@ mod test { #[case] signature_str: &'static str, #[case] expected: bool, ) { - let pubkey = PubKey::parse(pubkey_str).expect("must parse"); - let signature = Signature::parse(signature_str).expect("must parse"); + let pubkey = VerifyingKey::parse(pubkey_str).expect("must parse"); + let signature = SignatureRef::parse(signature_str).expect("must parse"); assert_eq!(expected, pubkey.verify(fingerprint, &signature)); } diff --git a/tvix/nix-compat/src/nix_daemon/handler.rs b/tvix/nix-compat/src/nix_daemon/handler.rs new file mode 100644 index 000000000000..65c5c2d60d08 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/handler.rs @@ -0,0 +1,294 @@ +use std::{future::Future, sync::Arc}; + +use bytes::Bytes; +use tokio::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + sync::Mutex, +}; +use tracing::{debug, warn}; + +use super::{ + types::QueryValidPaths, + worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST}, + NixDaemonIO, +}; + +use crate::{ + store_path::StorePath, + wire::{ + de::{NixRead, NixReader}, + ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder}, + ProtocolVersion, + }, +}; + +use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR}; + +/// Handles a single connection with a nix client. 
+///
+/// As part of its [`initialization`] it performs the handshake with the client
+/// and determines the [ProtocolVersion] and [ClientSettings] to use for the remainder of the session.
+///
+/// Once initialized, [`handle_client`] needs to be called to handle the rest of the session;
+/// it delegates all operation handling to an instance of [NixDaemonIO].
+///
+/// [`initialization`]: NixDaemon::initialize
+#[allow(dead_code)]
+pub struct NixDaemon<IO, R, W> {
+ io: Arc<IO>,
+ protocol_version: ProtocolVersion,
+ client_settings: ClientSettings,
+ reader: NixReader<R>,
+ writer: Arc<Mutex<NixWriter<W>>>,
+}
+
+impl<IO, R, W> NixDaemon<IO, R, W>
+where
+ IO: NixDaemonIO + Sync + Send,
+{
+ pub fn new(
+ io: Arc<IO>,
+ protocol_version: ProtocolVersion,
+ client_settings: ClientSettings,
+ reader: NixReader<R>,
+ writer: NixWriter<W>,
+ ) -> Self {
+ Self {
+ io,
+ protocol_version,
+ client_settings,
+ reader,
+ writer: Arc::new(Mutex::new(writer)),
+ }
+ }
+}
+
+impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>>
+where
+ RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
+ IO: NixDaemonIO + Sync + Send,
+{
+ /// Async constructor for NixDaemon.
+ ///
+ /// Performs the initial handshake with the client and retrieves the client's preferred
+ /// settings.
+ ///
+ /// The resulting daemon can handle the client session by calling [NixDaemon::handle_client].
+ pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error>
+ where
+ RW: AsyncReadExt + AsyncWriteExt + Send + Unpin,
+ {
+ let protocol_version =
+ server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?;
+
+ connection.write_u64_le(STDERR_LAST).await?;
+ let (reader, writer) = split(connection);
+ let mut reader = NixReader::builder()
+ .set_version(protocol_version)
+ .build(reader);
+ let mut writer = NixWriterBuilder::default()
+ .set_version(protocol_version)
+ .build(writer);
+
+ // The first op is always SetOptions
+ let operation: Operation = reader.read_value().await?;
+ if operation != Operation::SetOptions {
+ return Err(std::io::Error::other(format!(
+ "Expected SetOptions operation, but got {operation:?}"
+ )));
+ }
+ let client_settings: ClientSettings = reader.read_value().await?;
+ writer.write_number(STDERR_LAST).await?;
+ writer.flush().await?;
+
+ Ok(Self::new(
+ io,
+ protocol_version,
+ client_settings,
+ reader,
+ writer,
+ ))
+ }
+
+ /// Main client connection loop, reads client's requests and responds to them accordingly.
+ pub async fn handle_client(&mut self) -> Result<(), std::io::Error> {
+ let io = self.io.clone();
+ loop {
+ let op_code = self.reader.read_number().await?;
+ match TryInto::<Operation>::try_into(op_code) {
+ // Note: please keep operations sorted in ascending order of their numerical op number.
+ Ok(operation) => match operation {
+ Operation::IsValidPath => {
+ let path: StorePath<String> = self.reader.read_value().await?;
+ self.handle(io.is_valid_path(&path)).await?
+ }
+ // Note: this operation does not currently delegate to NixDaemonIO.
+ // The general idea is that we will pass relevant ClientSettings
+ // into individual NixDaemonIO method calls if the need arises.
+ // For now we just store the settings in the NixDaemon for future use.
+ Operation::SetOptions => {
+ self.client_settings = self.reader.read_value().await?;
+ self.handle(async { Ok(()) }).await?
+ }
+ Operation::QueryPathInfo => {
+ let path: StorePath<String> = self.reader.read_value().await?;
+ self.handle(io.query_path_info(&path)).await?
+ }
+ Operation::QueryPathFromHashPart => {
+ let hash: Bytes = self.reader.read_value().await?;
+ self.handle(io.query_path_from_hash_part(&hash)).await?
+ }
+ Operation::QueryValidPaths => {
+ let query: QueryValidPaths = self.reader.read_value().await?;
+ self.handle(io.query_valid_paths(&query)).await?
+ }
+ Operation::QueryValidDerivers => {
+ let path: StorePath<String> = self.reader.read_value().await?;
+ self.handle(io.query_valid_derivers(&path)).await?
+ }
+ // FUTUREWORK: These are just stubs that return an empty list.
+ // It's important not to return an error here, so that the
+ // local-overlay:// store keeps working properly. While it will not see
+ // certain referrers and realizations, it will not fail on various
+ // operations like gc and optimize store. At the same time, returning an
+ // empty list here shouldn't break any of the local-overlay store's
+ // invariants.
+ Operation::QueryReferrers | Operation::QueryRealisation => {
+ let _: String = self.reader.read_value().await?;
+ self.handle(async move {
+ warn!(
+ ?operation,
+ "This operation is not implemented. Returning empty result..."
+ );
+ Ok(Vec::<StorePath<String>>::new())
+ })
+ .await?
+ }
+ _ => {
+ return Err(std::io::Error::other(format!(
+ "Operation {operation:?} is not implemented"
+ )));
+ }
+ },
+ _ => {
+ return Err(std::io::Error::other(format!(
+ "Unknown operation code received: {op_code}"
+ )));
+ }
+ }
+ }
+ }
+
+ /// Handles the operation and sends the response or error to the client.
+ ///
+ /// As per the nix daemon protocol, after sending the request, the client expects zero or more
+ /// log lines/activities followed by either
+ /// * STDERR_LAST and the response bytes
+ /// * STDERR_ERROR and the error
+ ///
+ /// This is a helper method, awaiting the passed-in future and then
+ /// handling log lines/activities as described above.
+ async fn handle<T>(
+ &mut self,
+ future: impl Future<Output = std::io::Result<T>>,
+ ) -> Result<(), std::io::Error>
+ where
+ T: NixSerialize + Send,
+ {
+ let result = future.await;
+ let mut writer = self.writer.lock().await;
+
+ match result {
+ Ok(r) => {
+ // the protocol requires that we first indicate that we are done sending logs
+ // by sending STDERR_LAST and then the response.
+ writer.write_number(STDERR_LAST).await?;
+ writer.write_value(&r).await?;
+ writer.flush().await
+ }
+ Err(e) => {
+ debug!(err = ?e, "IO error");
+ writer.write_number(STDERR_ERROR).await?;
+ writer.write_value(&NixError::new(format!("{e:?}"))).await?;
+ writer.flush().await
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::{io::Result, sync::Arc};
+
+ use tokio::io::AsyncWriteExt;
+
+ use crate::{
+ nix_daemon::types::UnkeyedValidPathInfo,
+ wire::ProtocolVersion,
+ worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
+ };
+
+ struct MockDaemonIO {}
+
+ impl NixDaemonIO for MockDaemonIO {
+ async fn query_path_info(
+ &self,
+ _path: &crate::store_path::StorePath<String>,
+ ) -> Result<Option<UnkeyedValidPathInfo>> {
+ Ok(None)
+ }
+
+ async fn query_path_from_hash_part(
+ &self,
+ _hash: &[u8],
+ ) -> Result<Option<UnkeyedValidPathInfo>> {
+ Ok(None)
+ }
+ }
+
+ #[tokio::test]
+ async fn test_daemon_initialization() {
+ let mut builder = tokio_test::io::Builder::new();
+ let test_conn = builder
+ .read(&WORKER_MAGIC_1.to_le_bytes())
+ .write(&WORKER_MAGIC_2.to_le_bytes())
+ // Our version is 1.37
+ .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // The client's version is 1.35
+ .read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // cpu affinity
+ .read(&[0; 8])
+ // reservespace
+ .read(&[0; 8])
+ // version (size)
+ .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // version (data == 2.18.2 + padding)
+ .write(&[50, 46, 49, 56, 46, 50, 0, 0])
+ // Trusted (1 == client trusted)
+ .write(&[1, 0, 0, 0, 0, 0, 0, 0])
+ // STDERR_LAST
+ .write(&[115, 116, 108, 97, 0, 0, 0, 0]);
+
+ let mut bytes = Vec::new();
+ let mut writer = NixWriter::new(&mut bytes);
+ writer
+ .write_value(&ClientSettings::default())
+ .await
+ .unwrap();
+ writer.flush().await.unwrap();
+
+ let test_conn = test_conn
+ // SetOptions op
+ .read(&[19, 0, 0, 0, 0, 0, 0, 0])
+ .read(&bytes)
+ // STDERR_LAST
+ .write(&[115, 116, 108, 97, 0, 0, 0, 0])
+ .build();
+
+ let daemon = NixDaemon::initialize(Arc::new(MockDaemonIO {}), test_conn)
+ .await
+ .unwrap();
+ assert_eq!(daemon.client_settings, ClientSettings::default());
+ assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
+ }
+}
diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs
index fe652377d1b4..e475263d2302 100644
--- a/tvix/nix-compat/src/nix_daemon/mod.rs
+++ b/tvix/nix-compat/src/nix_daemon/mod.rs
@@ -1,4 +1,211 @@
 pub mod worker_protocol;

-mod protocol_version;
-pub use protocol_version::ProtocolVersion;
+use std::io::Result;
+
+use futures::future::try_join_all;
+use tracing::warn;
+use types::{QueryValidPaths, UnkeyedValidPathInfo};
+
+use crate::store_path::StorePath;
+
+pub mod handler;
+pub mod types;
+
+/// Represents all possible operations over the nix-daemon protocol.
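+///
+/// Implementors only need to provide [NixDaemonIO::query_path_info] and
+/// [NixDaemonIO::query_path_from_hash_part]; the remaining operations have
+/// default implementations expressed in terms of those two. A minimal sketch
+/// (the `Dummy` type is purely illustrative):
+///
+/// ```ignore
+/// struct Dummy;
+///
+/// impl NixDaemonIO for Dummy {
+///     async fn query_path_info(
+///         &self,
+///         _path: &StorePath<String>,
+///     ) -> std::io::Result<Option<UnkeyedValidPathInfo>> {
+///         Ok(None)
+///     }
+///
+///     async fn query_path_from_hash_part(
+///         &self,
+///         _hash: &[u8],
+///     ) -> std::io::Result<Option<UnkeyedValidPathInfo>> {
+///         Ok(None)
+///     }
+/// }
+/// ```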
+pub trait NixDaemonIO: Sync { + fn is_valid_path( + &self, + path: &StorePath<String>, + ) -> impl std::future::Future<Output = Result<bool>> + Send { + async move { Ok(self.query_path_info(path).await?.is_some()) } + } + + fn query_path_info( + &self, + path: &StorePath<String>, + ) -> impl std::future::Future<Output = Result<Option<UnkeyedValidPathInfo>>> + Send; + + fn query_path_from_hash_part( + &self, + hash: &[u8], + ) -> impl std::future::Future<Output = Result<Option<UnkeyedValidPathInfo>>> + Send; + + fn query_valid_paths( + &self, + request: &QueryValidPaths, + ) -> impl std::future::Future<Output = Result<Vec<UnkeyedValidPathInfo>>> + Send { + async move { + if request.substitute { + warn!("tvix does not yet support substitution, ignoring the 'substitute' flag..."); + } + // Using try_join_all here to avoid returning partial results to the client. + // The only reason query_path_info can fail is due to transient IO errors, + // so we return such errors to the client as opposed to only returning paths + // that succeeded. + let result = + try_join_all(request.paths.iter().map(|path| self.query_path_info(path))).await?; + let result: Vec<UnkeyedValidPathInfo> = result.into_iter().flatten().collect(); + Ok(result) + } + } + + fn query_valid_derivers( + &self, + path: &StorePath<String>, + ) -> impl std::future::Future<Output = Result<Vec<StorePath<String>>>> + Send { + async move { + let result = self.query_path_info(path).await?; + let result: Vec<_> = result.into_iter().filter_map(|info| info.deriver).collect(); + Ok(result) + } + } +} + +#[cfg(test)] +mod tests { + + use crate::{nix_daemon::types::QueryValidPaths, store_path::StorePath}; + + use super::{types::UnkeyedValidPathInfo, NixDaemonIO}; + + // Very simple mock + // Unable to use mockall as it does not support unboxed async traits. 
+ pub struct MockNixDaemonIO { + query_path_info_result: Option<UnkeyedValidPathInfo>, + } + + impl NixDaemonIO for MockNixDaemonIO { + async fn query_path_info( + &self, + _path: &StorePath<String>, + ) -> std::io::Result<Option<UnkeyedValidPathInfo>> { + Ok(self.query_path_info_result.clone()) + } + + async fn query_path_from_hash_part( + &self, + _hash: &[u8], + ) -> std::io::Result<Option<UnkeyedValidPathInfo>> { + Ok(None) + } + } + + #[tokio::test] + async fn test_is_valid_path_returns_true() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: Some(UnkeyedValidPathInfo::default()), + }; + + let result = io + .is_valid_path(&path) + .await + .expect("expected to get a non-empty response"); + assert!(result, "expected to get true"); + } + + #[tokio::test] + async fn test_is_valid_path_returns_false() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: None, + }; + + let result = io + .is_valid_path(&path) + .await + .expect("expected to get a non-empty response"); + assert!(!result, "expected to get false"); + } + + #[tokio::test] + async fn test_query_valid_paths_returns_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: None, + }; + + let result = io + .query_valid_paths(&QueryValidPaths { + paths: vec![path], + substitute: false, + }) + .await + .expect("expected to get a non-empty response"); + assert_eq!(result, vec![], "expected to get empty response"); + } + + #[tokio::test] + async fn test_query_valid_paths_returns_non_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: Some(UnkeyedValidPathInfo::default()), + }; + + let result = io + .query_valid_paths(&QueryValidPaths { + paths: vec![path], + substitute: false, + }) + .await + .expect("expected to get a non-empty response"); + assert_eq!( + result, + vec![UnkeyedValidPathInfo::default()], + "expected to get non empty response" + ); + } + + #[tokio::test] + async fn test_query_valid_derivers_returns_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: None, + }; + + let result = io + .query_valid_derivers(&path) + .await + .expect("expected to get a non-empty response"); + assert_eq!(result, vec![], "expected to get empty response"); + } + + #[tokio::test] + async fn test_query_valid_derivers_returns_non_empty_response() { + let path = + StorePath::<String>::from_bytes("z6r3bn5l51679pwkvh9nalp6c317z34m-hello".as_bytes()) + .unwrap(); + let deriver = StorePath::<String>::from_bytes( + "z6r3bn5l51679pwkvh9nalp6c317z34m-hello.drv".as_bytes(), + ) + .unwrap(); + let io = MockNixDaemonIO { + query_path_info_result: Some(UnkeyedValidPathInfo { + deriver: Some(deriver.clone()), + nar_hash: "".to_owned(), + references: vec![], + registration_time: 0, + nar_size: 1, + ultimate: true, + signatures: vec![], + ca: None, + }), + }; + + let result = io + .query_valid_derivers(&path) + .await + .expect("expected to get a non-empty response"); + assert_eq!(result, vec![deriver], "expected to get non empty 
response"); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/types.rs b/tvix/nix-compat/src/nix_daemon/types.rs new file mode 100644 index 000000000000..bf7b1e6f6e58 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/types.rs @@ -0,0 +1,176 @@ +use crate::{ + narinfo::Signature, + nixhash::CAHash, + store_path::StorePath, + wire::{ + de::{NixDeserialize, NixRead}, + ser::{NixSerialize, NixWrite}, + }, +}; +use nix_compat_derive::{NixDeserialize, NixSerialize}; +use std::future::Future; + +/// Marker type that consumes/sends and ignores a u64. +#[derive(Clone, Debug, NixDeserialize, NixSerialize)] +#[nix(from = "u64", into = "u64")] +pub struct IgnoredZero; +impl From<u64> for IgnoredZero { + fn from(_: u64) -> Self { + IgnoredZero + } +} + +impl From<IgnoredZero> for u64 { + fn from(_: IgnoredZero) -> Self { + 0 + } +} + +#[derive(Debug, NixSerialize)] +pub struct TraceLine { + have_pos: IgnoredZero, + hint: String, +} + +/// Represents an error returned by the nix-daemon to its client. +/// +/// Adheres to the format described in serialization.md +#[derive(NixSerialize)] +pub struct NixError { + #[nix(version = "26..")] + type_: &'static str, + + #[nix(version = "26..")] + level: u64, + + #[nix(version = "26..")] + name: &'static str, + + msg: String, + #[nix(version = "26..")] + have_pos: IgnoredZero, + + #[nix(version = "26..")] + traces: Vec<TraceLine>, + + #[nix(version = "..=25")] + exit_status: u64, +} + +impl NixError { + pub fn new(msg: String) -> Self { + Self { + type_: "Error", + level: 0, // error + name: "Error", + msg, + have_pos: IgnoredZero {}, + traces: vec![], + exit_status: 1, + } + } +} + +nix_compat_derive::nix_serialize_remote!(#[nix(display)] Signature<String>); + +impl NixSerialize for CAHash { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_value(&self.to_nix_nixbase32_string()).await + } +} + +impl NixSerialize for Option<CAHash> { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + match self { + Some(value) => writer.write_value(value).await, + None => writer.write_value("").await, + } + } +} + +impl NixSerialize for Option<UnkeyedValidPathInfo> { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + match self { + Some(value) => { + writer.write_value(&true).await?; + writer.write_value(value).await + } + None => writer.write_value(&false).await, + } + } +} + +// Custom implementation since FromStr does not use from_absolute_path +impl NixDeserialize for StorePath<String> { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + use crate::wire::de::Error; + if let Some(buf) = reader.try_read_bytes().await? { + let result = StorePath::<String>::from_absolute_path(&buf); + result.map(Some).map_err(R::Error::invalid_data) + } else { + Ok(None) + } + } +} + +// Custom implementation since Display does not use absolute paths. +impl<S> NixSerialize for StorePath<S> +where + S: AsRef<str>, +{ + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + let sp = self.to_absolute_path(); + async move { writer.write_value(&sp).await } + } +} + +// Writes StorePath or an empty string. 
+impl NixSerialize for Option<StorePath<String>> {
+ async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error>
+ where
+ W: NixWrite,
+ {
+ match self {
+ Some(value) => writer.write_value(value).await,
+ None => writer.write_value("").await,
+ }
+ }
+}
+
+#[derive(NixSerialize, Debug, Clone, Default, PartialEq)]
+pub struct UnkeyedValidPathInfo {
+ pub deriver: Option<StorePath<String>>,
+ pub nar_hash: String,
+ pub references: Vec<StorePath<String>>,
+ pub registration_time: u64,
+ pub nar_size: u64,
+ pub ultimate: bool,
+ pub signatures: Vec<Signature<String>>,
+ pub ca: Option<CAHash>,
+}
+
+/// Request type for [super::worker_protocol::Operation::QueryValidPaths]
+#[derive(NixDeserialize)]
+pub struct QueryValidPaths {
+ // Paths to query
+ pub paths: Vec<StorePath<String>>,
+
+ // Whether to try and substitute the paths.
+ #[nix(version = "27..")]
+ pub substitute: bool,
+}
diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
index 7e3adc0db2ff..1ef9b9ab02d7 100644
--- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
+++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs
@@ -1,19 +1,21 @@
 use std::{
- collections::HashMap,
+ cmp::min,
+ collections::BTreeMap,
 io::{Error, ErrorKind},
 };

-use enum_primitive_derive::Primitive;
-use num_traits::{FromPrimitive, ToPrimitive};
+use nix_compat_derive::{NixDeserialize, NixSerialize};
+use num_enum::{FromPrimitive, IntoPrimitive, TryFromPrimitive};
 use tokio::io::{AsyncReadExt, AsyncWriteExt};

 use crate::wire;

-use super::ProtocolVersion;
+use crate::wire::ProtocolVersion;

-static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc"
-static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio"
+pub(crate) static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc"
+pub(crate) static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio"
 pub static STDERR_LAST: u64 = 0x616c7473; // "alts"
+pub(crate) static STDERR_ERROR: u64 = 0x63787470; // "cxtp"

 /// | Nix version | Protocol |
 /// |-----------------|----------|
@@ -54,7 +56,11 @@ pub static MAX_SETTING_SIZE: usize = 1024;
 /// Note: for now, we're using the Nix 2.20 operation description. The
 /// operations marked as obsolete are obsolete for Nix 2.20, not
 /// necessarily for Nix 2.3. We'll revisit this later on.
-#[derive(Debug, PartialEq, Primitive)]
+#[derive(
+ Clone, Debug, PartialEq, TryFromPrimitive, IntoPrimitive, NixDeserialize, NixSerialize,
+)]
+#[nix(try_from = "u64", into = "u64")]
+#[repr(u64)]
 pub enum Operation {
 IsValidPath = 1,
 HasSubstitutes = 3,
@@ -105,8 +111,13 @@ pub enum Operation {
 /// Log verbosity. In the Nix wire protocol, the client requests a
 /// verbosity level to the daemon, which in turn does not produce any
 /// log below this verbosity.
-#[derive(Debug, PartialEq, Primitive)]
+#[derive(
+ Debug, PartialEq, FromPrimitive, IntoPrimitive, NixDeserialize, NixSerialize, Default, Clone,
+)]
+#[nix(from = "u64", into = "u64")]
+#[repr(u64)]
 pub enum Verbosity {
+ #[default]
 LvlError = 0,
 LvlWarn = 1,
 LvlNotice = 2,
@@ -119,7 +130,7 @@ pub enum Verbosity {
 /// Settings requested by the client. These settings are applied to a
 /// connection between the daemon and a client.
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, NixDeserialize, NixSerialize, Default)]
 pub struct ClientSettings {
 pub keep_failed: bool,
 pub keep_going: bool,
@@ -127,70 +138,21 @@ pub struct ClientSettings {
 pub verbosity: Verbosity,
 pub max_build_jobs: u64,
 pub max_silent_time: u64,
- pub verbose_build: bool,
+ pub use_build_hook: bool,
+ pub verbose_build: u64,
+ pub log_type: u64,
+ pub print_build_trace: u64,
 pub build_cores: u64,
 pub use_substitutes: bool,
+
 /// Key/Value dictionary in charge of overriding the settings set
 /// by the Nix config file.
 ///
 /// Some settings can be safely overridden,
 /// some others require the user running the Nix client to be part
 /// of the trusted users group.
- pub overrides: HashMap<String, String>,
-}
-
-/// Reads the client settings from the wire.
-///
-/// Note: this function **only** reads the settings. It does not
-/// manage the log state with the daemon. You'll have to do that on
-/// your own. A minimal log implementation will consist in sending
-/// back [STDERR_LAST] to the client after reading the client
-/// settings.
-///
-/// FUTUREWORK: write serialization.
-pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
- r: &mut R,
- client_version: ProtocolVersion,
-) -> std::io::Result<ClientSettings> {
- let keep_failed = r.read_u64_le().await? != 0;
- let keep_going = r.read_u64_le().await? != 0;
- let try_fallback = r.read_u64_le().await? != 0;
- let verbosity_uint = r.read_u64_le().await?;
- let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| {
- Error::new(
- ErrorKind::InvalidData,
- format!("Can't convert integer {} to verbosity", verbosity_uint),
- )
- })?;
- let max_build_jobs = r.read_u64_le().await?;
- let max_silent_time = r.read_u64_le().await?;
- _ = r.read_u64_le().await?; // obsolete useBuildHook
- let verbose_build = r.read_u64_le().await? != 0;
- _ = r.read_u64_le().await?; // obsolete logType
- _ = r.read_u64_le().await?; // obsolete printBuildTrace
- let build_cores = r.read_u64_le().await?;
- let use_substitutes = r.read_u64_le().await? != 0;
- let mut overrides = HashMap::new();
- if client_version.minor() >= 12 {
- let num_overrides = r.read_u64_le().await?;
- for _ in 0..num_overrides {
- let name = wire::read_string(r, 0..=MAX_SETTING_SIZE).await?;
- let value = wire::read_string(r, 0..=MAX_SETTING_SIZE).await?;
- overrides.insert(name, value);
- }
- }
- Ok(ClientSettings {
- keep_failed,
- keep_going,
- try_fallback,
- verbosity,
- max_build_jobs,
- max_silent_time,
- verbose_build,
- build_cores,
- use_substitutes,
- overrides,
- })
+ #[nix(version = "12..")]
+ pub overrides: BTreeMap<String, String>,
 }

 /// Performs the initial handshake the server is sending to a connecting client.
@@ -209,7 +171,7 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
 ///
 /// # Return
 ///
-/// The protocol version of the client.
+/// The protocol version to use for further comms, min(client_version, our_version).
 pub async fn server_handshake_client<'a, RW: 'a>(
 mut conn: &'a mut RW,
 nix_version: &str,
@@ -239,46 +201,46 @@ where
 format!("The nix client version {} is too old", client_version),
 ));
 }
- if client_version.minor() >= 14 {
+ let picked_version = min(PROTOCOL_VERSION, client_version);
+ if picked_version.minor() >= 14 {
 // Obsolete CPU affinity.
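 // (the first u64 read below is a flag; when it is non-zero, a second u64
 // carrying the actual affinity follows and is discarded)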
 let read_affinity = conn.read_u64_le().await?;
 if read_affinity != 0 {
 let _cpu_affinity = conn.read_u64_le().await?;
 };
 }
- if client_version.minor() >= 11 {
+ if picked_version.minor() >= 11 {
 // Obsolete reserveSpace
 let _reserve_space = conn.read_u64_le().await?;
 }
- if client_version.minor() >= 33 {
+ if picked_version.minor() >= 33 {
 // Nix version. We're plain lying, we're not Nix, but eh…
 // Setting it to the 2.3 lineage. Not 100% sure this is a
 // good idea.
 wire::write_bytes(&mut conn, nix_version).await?;
 conn.flush().await?;
 }
- if client_version.minor() >= 35 {
+ if picked_version.minor() >= 35 {
 write_worker_trust_level(&mut conn, trusted).await?;
 }
- Ok(client_version)
+ Ok(picked_version)
 }
}

/// Read a worker [Operation] from the wire.
pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> {
 let op_number = r.read_u64_le().await?;
- Operation::from_u64(op_number).ok_or(Error::new(
- ErrorKind::InvalidData,
- format!("Invalid OP number {}", op_number),
- ))
+ Operation::try_from(op_number).map_err(|_| {
+ Error::new(
+ ErrorKind::InvalidData,
+ format!("Invalid OP number {}", op_number),
+ )
+ })
}

/// Write a worker [Operation] to the wire.
-pub async fn write_op<W: AsyncWriteExt + Unpin>(w: &mut W, op: &Operation) -> std::io::Result<()> {
- let op = Operation::to_u64(op).ok_or(Error::new(
- ErrorKind::Other,
- format!("Can't convert the OP {:?} to u64", op),
- ))?;
+pub async fn write_op<W: AsyncWriteExt + Unpin>(w: &mut W, op: Operation) -> std::io::Result<()> {
+ let op: u64 = op.into();
 w.write_u64(op).await
}

@@ -307,8 +269,6 @@
#[cfg(test)]
mod tests {
 use super::*;
- use hex_literal::hex;
-
- use tokio_test::io::Builder;

 #[tokio::test]
 async fn test_init_handshake() {
@@ -330,105 +290,63 @@ mod tests {
 // Trusted (1 == client trusted)
 .write(&[1, 0, 0, 0, 0, 0, 0, 0])
 .build();
- let client_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
+ let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
 .await
 .unwrap();
- assert_eq!(client_version, PROTOCOL_VERSION)
+ assert_eq!(picked_version, PROTOCOL_VERSION)
 }

 #[tokio::test]
- async fn test_read_client_settings_without_overrides() {
- // Client settings bits captured from a Nix 2.3.17 run w/ sockdump (protocol version 21).
- let wire_bits = hex!(
- "00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 02 00 00 00 00 00 00 00 \
- 10 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 01 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 01 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00"
- );
- let mut mock = Builder::new().read(&wire_bits).build();
- let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21))
+ async fn test_init_handshake_with_newer_client_should_use_older_version() {
+ let mut test_conn = tokio_test::io::Builder::new()
+ .read(&WORKER_MAGIC_1.to_le_bytes())
+ .write(&WORKER_MAGIC_2.to_le_bytes())
+ .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // Client is newer than us.
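+ // Protocol versions are encoded as a little-endian u64, minor version in
+ // the low byte and major in the second byte, so [38, 1, ...] is 1.38.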
+ .read(&[38, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // cpu affinity
+ .read(&[0; 8])
+ // reservespace
+ .read(&[0; 8])
+ // version (size)
+ .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // version (data == 2.18.2 + padding)
+ .write(&[50, 46, 49, 56, 46, 50, 0, 0])
+ // Trusted (1 == client trusted)
+ .write(&[1, 0, 0, 0, 0, 0, 0, 0])
+ .build();
+ let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
 .await
- .expect("should parse");
- let expected = ClientSettings {
- keep_failed: false,
- keep_going: false,
- try_fallback: false,
- verbosity: Verbosity::LvlNotice,
- max_build_jobs: 16,
- max_silent_time: 0,
- verbose_build: false,
- build_cores: 0,
- use_substitutes: true,
- overrides: HashMap::new(),
- };
- assert_eq!(settings, expected);
+ .unwrap();
+
+ assert_eq!(picked_version, PROTOCOL_VERSION)
 }

 #[tokio::test]
- async fn test_read_client_settings_with_overrides() {
- // Client settings bits captured from a Nix 2.3.17 run w/ sockdump (protocol version 21).
- let wire_bits = hex!(
- "00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 02 00 00 00 00 00 00 00 \
- 10 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 01 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 00 00 00 00 00 00 00 00 \
- 01 00 00 00 00 00 00 00 \
- 02 00 00 00 00 00 00 00 \
- 0c 00 00 00 00 00 00 00 \
- 61 6c 6c 6f 77 65 64 2d \
- 75 72 69 73 00 00 00 00 \
- 1e 00 00 00 00 00 00 00 \
- 68 74 74 70 73 3a 2f 2f \
- 62 6f 72 64 65 61 75 78 \
- 2e 67 75 69 78 2e 67 6e \
- 75 2e 6f 72 67 2f 00 00 \
- 0d 00 00 00 00 00 00 00 \
- 61 6c 6c 6f 77 65 64 2d \
- 75 73 65 72 73 00 00 00 \
- 0b 00 00 00 00 00 00 00 \
- 6a 65 61 6e 20 70 69 65 \
- 72 72 65 00 00 00 00 00"
- );
- let mut mock = Builder::new().read(&wire_bits).build();
- let settings = read_client_settings(&mut mock, ProtocolVersion::from_parts(1, 21))
+ async fn test_init_handshake_with_older_client_should_use_older_version() {
+ let mut test_conn = tokio_test::io::Builder::new()
+ .read(&WORKER_MAGIC_1.to_le_bytes())
+ .write(&WORKER_MAGIC_2.to_le_bytes())
+ .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // Client is older than us.
+ .read(&[24, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
+ // cpu affinity
+ .read(&[0; 8])
+ // reservespace
+ .read(&[0; 8])
+ // NOTE: we are not writing version and trust since the client is too old.
+ // version (size) + //.write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + // version (data == 2.18.2 + padding) + //.write(&[50, 46, 49, 56, 46, 50, 0, 0]) + // Trusted (1 == client trusted + //.write(&[1, 0, 0, 0, 0, 0, 0, 0]) + .build(); + let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted) .await - .expect("should parse"); - let overrides = HashMap::from([ - ( - String::from("allowed-uris"), - String::from("https://bordeaux.guix.gnu.org/"), - ), - (String::from("allowed-users"), String::from("jean pierre")), - ]); - let expected = ClientSettings { - keep_failed: false, - keep_going: false, - try_fallback: false, - verbosity: Verbosity::LvlNotice, - max_build_jobs: 16, - max_silent_time: 0, - verbose_build: false, - build_cores: 0, - use_substitutes: true, - overrides, - }; - assert_eq!(settings, expected); + .unwrap(); + + assert_eq!(picked_version, ProtocolVersion::from_parts(1, 24)) } } diff --git a/tvix/nix-compat/src/nix_http/mod.rs b/tvix/nix-compat/src/nix_http/mod.rs new file mode 100644 index 000000000000..89ba147b8071 --- /dev/null +++ b/tvix/nix-compat/src/nix_http/mod.rs @@ -0,0 +1,115 @@ +use tracing::trace; + +use crate::nixbase32; + +/// The mime type used for NAR files, both compressed and uncompressed +pub const MIME_TYPE_NAR: &str = "application/x-nix-nar"; +/// The mime type used for NARInfo files +pub const MIME_TYPE_NARINFO: &str = "text/x-nix-narinfo"; +/// The mime type used for the `nix-cache-info` file +pub const MIME_TYPE_CACHE_INFO: &str = "text/x-nix-cache-info"; + +/// Parses a `14cx20k6z4hq508kqi2lm79qfld5f9mf7kiafpqsjs3zlmycza0k.nar` +/// string and returns the nixbase32-decoded digest, as well as the compression +/// suffix (which might be empty). +pub fn parse_nar_str(s: &str) -> Option<([u8; 32], &str)> { + if !s.is_char_boundary(52) { + trace!("invalid string, no char boundary at 52"); + return None; + } + + let (hash_str, suffix) = s.split_at(52); + + // we know hash_str is 52 bytes, so it's ok to unwrap here. + let hash_str_fixed: [u8; 52] = hash_str.as_bytes().try_into().unwrap(); + + match suffix.strip_prefix(".nar") { + Some(compression_suffix) => match nixbase32::decode_fixed(hash_str_fixed) { + Err(e) => { + trace!(err=%e, "invalid nixbase32 encoding"); + None + } + Ok(digest) => Some((digest, compression_suffix)), + }, + None => { + trace!("no .nar suffix"); + None + } + } +} + +/// Parses a `3mzh8lvgbynm9daj7c82k2sfsfhrsfsy.narinfo` string and returns the +/// nixbase32-decoded digest. +pub fn parse_narinfo_str(s: &str) -> Option<[u8; 20]> { + if !s.is_char_boundary(32) { + trace!("invalid string, no char boundary at 32"); + return None; + } + + match s.split_at(32) { + (hash_str, ".narinfo") => { + // we know this is 32 bytes, so it's ok to unwrap here. 
+ let hash_str_fixed: [u8; 32] = hash_str.as_bytes().try_into().unwrap(); + + match nixbase32::decode_fixed(hash_str_fixed) { + Err(e) => { + trace!(err=%e, "invalid nixbase32 encoding"); + None + } + Ok(digest) => Some(digest), + } + } + _ => { + trace!("invalid string, no .narinfo suffix"); + None + } + } +} + +#[cfg(test)] +mod test { + use super::{parse_nar_str, parse_narinfo_str}; + use hex_literal::hex; + + #[test] + fn parse_nar_str_success() { + assert_eq!( + ( + hex!("13a8cf7ca57f68a9f1752acee36a72a55187d3a954443c112818926f26109d91"), + "" + ), + parse_nar_str("14cx20k6z4hq508kqi2lm79qfld5f9mf7kiafpqsjs3zlmycza0k.nar").unwrap() + ); + + assert_eq!( + ( + hex!("13a8cf7ca57f68a9f1752acee36a72a55187d3a954443c112818926f26109d91"), + ".xz" + ), + parse_nar_str("14cx20k6z4hq508kqi2lm79qfld5f9mf7kiafpqsjs3zlmycza0k.nar.xz").unwrap() + ) + } + + #[test] + fn parse_nar_str_failure() { + assert!(parse_nar_str("14cx20k6z4hq508kqi2lm79qfld5f9mf7kiafpqsjs3zlmycza0").is_none()); + assert!( + parse_nar_str("14cx20k6z4hq508kqi2lm79qfld5f9mf7kiafpqsjs3zlmycza0🦊.nar").is_none() + ) + } + #[test] + fn parse_narinfo_str_success() { + assert_eq!( + hex!("8a12321522fd91efbd60ebb2481af88580f61600"), + parse_narinfo_str("00bgd045z0d4icpbc2yyz4gx48ak44la.narinfo").unwrap() + ); + } + + #[test] + fn parse_narinfo_str_failure() { + assert!(parse_narinfo_str("00bgd045z0d4icpbc2yyz4gx48ak44la").is_none()); + assert!(parse_narinfo_str("/00bgd045z0d4icpbc2yyz4gx48ak44la").is_none()); + assert!(parse_narinfo_str("000000").is_none()); + assert!(parse_narinfo_str("00bgd045z0d4icpbc2yyz4gx48ak44l🦊.narinfo").is_none()); + } +} diff --git a/tvix/nix-compat/src/nixbase32.rs b/tvix/nix-compat/src/nixbase32.rs index b7ffc1dc2bcd..8d34e4cedce6 100644 --- a/tvix/nix-compat/src/nixbase32.rs +++ b/tvix/nix-compat/src/nixbase32.rs @@ -62,6 +62,12 @@ pub fn decode(input: impl AsRef<[u8]>) -> Result<Vec<u8>, DecodeError> { let input = input.as_ref(); let output_len = decode_len(input.len()); + if input.len() != encode_len(output_len) { + return Err(DecodeError { + position: input.len().min(encode_len(output_len)), + kind: DecodeKind::Length, + }); + } let mut output: Vec<u8> = vec![0x00; output_len]; decode_inner(input, &mut output)?; @@ -163,6 +169,10 @@ mod tests { #[case::invalid_encoding_1("zz", None)] // this is an even more specific example - it'd decode as 00000000 11 #[case::invalid_encoding_2("c0", None)] + // This has an invalid length + #[case::invalid_encoding_3("0", None)] + // This has an invalid length + #[case::invalid_encoding_4("0zz", None)] #[test] fn decode(#[case] enc: &str, #[case] dec: Option<&[u8]>) { match dec { @@ -201,6 +211,11 @@ mod tests { #[test] fn decode_len() { assert_eq!(super::decode_len(0), 0); + assert_eq!(super::decode_len(1), 0); + assert_eq!(super::decode_len(2), 1); + assert_eq!(super::decode_len(3), 1); + assert_eq!(super::decode_len(4), 2); + assert_eq!(super::decode_len(5), 3); assert_eq!(super::decode_len(32), 20); } } diff --git a/tvix/nix-compat/src/nixcpp/conf.rs b/tvix/nix-compat/src/nixcpp/conf.rs new file mode 100644 index 000000000000..68308115f988 --- /dev/null +++ b/tvix/nix-compat/src/nixcpp/conf.rs @@ -0,0 +1,202 @@ +use std::{fmt::Display, str::FromStr}; + +/// Represents configuration as stored in /etc/nix/nix.conf. +/// This list is not exhaustive, feel free to add more. 
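+///
+/// A minimal usage sketch (path and printed field purely illustrative):
+///
+/// ```ignore
+/// let contents = std::fs::read_to_string("/etc/nix/nix.conf").expect("read nix.conf");
+/// let cfg = NixConfig::parse(&contents).expect("parse nix.conf");
+/// println!("max-jobs: {:?}", cfg.max_jobs);
+/// ```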
+#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct NixConfig<'a> { + pub allowed_users: Option<Vec<&'a str>>, + pub auto_optimise_store: Option<bool>, + pub cores: Option<u64>, + pub max_jobs: Option<u64>, + pub require_sigs: Option<bool>, + pub sandbox: Option<SandboxSetting>, + pub sandbox_fallback: Option<bool>, + pub substituters: Option<Vec<&'a str>>, + pub system_features: Option<Vec<&'a str>>, + pub trusted_public_keys: Option<Vec<crate::narinfo::VerifyingKey>>, + pub trusted_substituters: Option<Vec<&'a str>>, + pub trusted_users: Option<Vec<&'a str>>, + pub extra_platforms: Option<Vec<&'a str>>, + pub extra_sandbox_paths: Option<Vec<&'a str>>, + pub experimental_features: Option<Vec<&'a str>>, + pub builders_use_substitutes: Option<bool>, +} + +impl<'a> NixConfig<'a> { + /// Parses configuration from a file like `/etc/nix/nix.conf`, returning + /// a [NixConfig] with all values contained in there. + /// It does not support parsing multiple config files, merging semantics, + /// and also does not understand `include` and `!include` statements. + pub fn parse(input: &'a str) -> Result<Self, Error> { + let mut out = Self::default(); + + for line in input.lines() { + // strip comments at the end of the line + let line = if let Some((line, _comment)) = line.split_once('#') { + line + } else { + line + }; + + // skip comments and empty lines + if line.trim().is_empty() { + continue; + } + + let (tag, val) = line + .split_once('=') + .ok_or_else(|| Error::InvalidLine(line.to_string()))?; + + // trim whitespace + let tag = tag.trim(); + let val = val.trim(); + + #[inline] + fn parse_val<'a>(this: &mut NixConfig<'a>, tag: &str, val: &'a str) -> Option<()> { + match tag { + "allowed-users" => { + this.allowed_users = Some(val.split_whitespace().collect()); + } + "auto-optimise-store" => { + this.auto_optimise_store = Some(val.parse::<bool>().ok()?); + } + "cores" => { + this.cores = Some(val.parse().ok()?); + } + "max-jobs" => { + this.max_jobs = Some(val.parse().ok()?); + } + "require-sigs" => { + this.require_sigs = Some(val.parse().ok()?); + } + "sandbox" => this.sandbox = Some(val.parse().ok()?), + "sandbox-fallback" => this.sandbox_fallback = Some(val.parse().ok()?), + "substituters" => this.substituters = Some(val.split_whitespace().collect()), + "system-features" => { + this.system_features = Some(val.split_whitespace().collect()) + } + "trusted-public-keys" => { + this.trusted_public_keys = Some( + val.split_whitespace() + .map(crate::narinfo::VerifyingKey::parse) + .collect::<Result<Vec<crate::narinfo::VerifyingKey>, _>>() + .ok()?, + ) + } + "trusted-substituters" => { + this.trusted_substituters = Some(val.split_whitespace().collect()) + } + "trusted-users" => this.trusted_users = Some(val.split_whitespace().collect()), + "extra-platforms" => { + this.extra_platforms = Some(val.split_whitespace().collect()) + } + "extra-sandbox-paths" => { + this.extra_sandbox_paths = Some(val.split_whitespace().collect()) + } + "experimental-features" => { + this.experimental_features = Some(val.split_whitespace().collect()) + } + "builders-use-substitutes" => { + this.builders_use_substitutes = Some(val.parse().ok()?) + } + _ => return None, + } + Some(()) + } + + parse_val(&mut out, tag, val) + .ok_or_else(|| Error::InvalidValue(tag.to_string(), val.to_string()))? 
+ } + + Ok(out) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Invalid line: {0}")] + InvalidLine(String), + #[error("Unrecognized key: {0}")] + UnrecognizedKey(String), + #[error("Invalid value '{1}' for key '{0}'")] + InvalidValue(String, String), +} + +/// Valid values for the Nix 'sandbox' setting +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum SandboxSetting { + True, + False, + Relaxed, +} + +impl Display for SandboxSetting { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SandboxSetting::True => write!(f, "true"), + SandboxSetting::False => write!(f, "false"), + SandboxSetting::Relaxed => write!(f, "relaxed"), + } + } +} + +impl FromStr for SandboxSetting { + type Err = &'static str; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(match s { + "true" => Self::True, + "false" => Self::False, + "relaxed" => Self::Relaxed, + _ => return Err("invalid value"), + }) + } +} + +#[cfg(test)] +mod tests { + use crate::{narinfo::VerifyingKey, nixcpp::conf::SandboxSetting}; + + use super::NixConfig; + + #[test] + pub fn test_parse() { + let config = NixConfig::parse(include_str!("../../testdata/nix.conf")).expect("must parse"); + + assert_eq!( + NixConfig { + allowed_users: Some(vec!["*"]), + auto_optimise_store: Some(false), + cores: Some(0), + max_jobs: Some(8), + require_sigs: Some(true), + sandbox: Some(SandboxSetting::True), + sandbox_fallback: Some(false), + substituters: Some(vec!["https://nix-community.cachix.org", "https://cache.nixos.org/"]), + system_features: Some(vec!["nixos-test", "benchmark", "big-parallel", "kvm"]), + trusted_public_keys: Some(vec![ + VerifyingKey::parse("cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY=") + .expect("failed to parse pubkey"), + VerifyingKey::parse("nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=") + .expect("failed to parse pubkey") + ]), + trusted_substituters: Some(vec![]), + trusted_users: Some(vec!["flokli"]), + extra_platforms: Some(vec!["aarch64-linux", "i686-linux"]), + extra_sandbox_paths: Some(vec![ + "/run/binfmt", "/nix/store/swwyxyqpazzvbwx8bv40z7ih144q841f-qemu-aarch64-binfmt-P-x86_64-unknown-linux-musl" + ]), + experimental_features: Some(vec!["nix-command"]), + builders_use_substitutes: Some(true) + }, + config + ); + + // parse a config file using some non-space whitespaces, as well as comments right after the lines. + // ensure it contains the same data as initially parsed. + let other_config = NixConfig::parse(include_str!("../../testdata/other_nix.conf")) + .expect("other config must parse"); + + assert_eq!(config, other_config); + } +} diff --git a/tvix/nix-compat/src/nixcpp/mod.rs b/tvix/nix-compat/src/nixcpp/mod.rs new file mode 100644 index 000000000000..57518de8cc52 --- /dev/null +++ b/tvix/nix-compat/src/nixcpp/mod.rs @@ -0,0 +1,9 @@ +//! Contains code parsing some of the Nixcpp config files etc. +//! left by Nix *on the local disk*. +//! +//! This is only for Nix' own state/config. +//! +//! More "standardized" protocols, like parts of the Nix HTTP Binary Cache +//! protocol live elsewhere. + +pub mod conf; diff --git a/tvix/nix-compat/src/nixhash/ca_hash.rs b/tvix/nix-compat/src/nixhash/ca_hash.rs index 2bf5f966cefe..e6cbaf5b710a 100644 --- a/tvix/nix-compat/src/nixhash/ca_hash.rs +++ b/tvix/nix-compat/src/nixhash/ca_hash.rs @@ -47,12 +47,33 @@ impl CAHash { } } + /// Returns a colon-separated string consisting of mode, recursiveness and + /// hash algo. Used as a prefix in various string representations. 
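+    ///
+    /// For example, a [CAHash::Nar] over a sha256 hash yields `fixed:r:sha256`,
+    /// while the same hash in flat mode yields `fixed:sha256`.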
+    pub fn algo_str(&self) -> &'static str {
+        match self.mode() {
+            HashMode::Flat => match self.hash().as_ref() {
+                NixHash::Md5(_) => "fixed:md5",
+                NixHash::Sha1(_) => "fixed:sha1",
+                NixHash::Sha256(_) => "fixed:sha256",
+                NixHash::Sha512(_) => "fixed:sha512",
+            },
+            HashMode::Nar => match self.hash().as_ref() {
+                NixHash::Md5(_) => "fixed:r:md5",
+                NixHash::Sha1(_) => "fixed:r:sha1",
+                NixHash::Sha256(_) => "fixed:r:sha256",
+                NixHash::Sha512(_) => "fixed:r:sha512",
+            },
+            HashMode::Text => "text:sha256",
+        }
+    }
+
     /// Constructs a [CAHash] from the textual representation,
     /// which is one of the three:
     /// - `text:sha256:$nixbase32sha256digest`
     /// - `fixed:r:$algo:$nixbase32digest`
     /// - `fixed:$algo:$nixbase32digest`
-    /// which is the format that's used in the NARInfo for example.
+    ///
+    /// These formats are used in NARInfo, for example.
     pub fn from_nix_hex_str(s: &str) -> Option<Self> {
         let (tag, s) = s.split_once(':')?;
 
@@ -76,13 +97,11 @@ impl CAHash {
     /// Formats a [CAHash] in the Nix default hash format, which is the format
     /// that's used in NARInfos for example.
     pub fn to_nix_nixbase32_string(&self) -> String {
-        match self {
-            CAHash::Flat(nh) => format!("fixed:{}", nh.to_nix_nixbase32_string()),
-            CAHash::Nar(nh) => format!("fixed:r:{}", nh.to_nix_nixbase32_string()),
-            CAHash::Text(digest) => {
-                format!("text:sha256:{}", nixbase32::encode(digest))
-            }
-        }
+        format!(
+            "{}:{}",
+            self.algo_str(),
+            nixbase32::encode(self.hash().digest_as_bytes())
+        )
     }
 
     /// This takes a serde_json::Map and turns it into this structure. This is necessary because we
     /// need to know whether we have an invalid or a missing NixHashWithMode structure in another
     /// structure, e.g. Output.
     /// This means we have this combinatorial situation:
+    ///
     /// - no hash, no hashAlgo: no [CAHash] so we return Ok(None).
     /// - present hash, missing hashAlgo: invalid, we will return missing_field
     /// - missing hash, present hashAlgo: same
     /// - present hash, present hashAlgo: either we return ourselves or a type/value validation
-    /// error.
+    ///   error.
+    ///
     /// This function is for internal consumption regarding those needs until we have a better
     /// solution. With that said, let's explain how this works.
     ///
diff --git a/tvix/nix-compat/src/path_info.rs b/tvix/nix-compat/src/path_info.rs
index f289ebde338c..63512805fe09 100644
--- a/tvix/nix-compat/src/path_info.rs
+++ b/tvix/nix-compat/src/path_info.rs
@@ -1,4 +1,4 @@
-use crate::{nixbase32, nixhash::NixHash, store_path::StorePathRef};
+use crate::{narinfo::SignatureRef, nixbase32, nixhash::NixHash, store_path::StorePathRef};
 use serde::{Deserialize, Serialize};
 use std::collections::BTreeSet;
 
@@ -15,7 +15,7 @@ pub struct ExportedPathInfo<'a> {
     #[serde(
         rename = "narHash",
         serialize_with = "to_nix_nixbase32_string",
-        deserialize_with = "from_nix_nixbase32_string"
+        deserialize_with = "from_nix_hash_string"
     )]
     pub nar_sha256: [u8; 32],
 
@@ -25,11 +25,17 @@ pub struct ExportedPathInfo<'a> {
     #[serde(borrow)]
     pub path: StorePathRef<'a>,
 
+    #[serde(borrow)]
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub deriver: Option<StorePathRef<'a>>,
+
     /// The list of other Store Paths this Store Path refers to.
     /// StorePathRef does Ord by the nixbase32-encoded string repr, so this is correct.
     pub references: BTreeSet<StorePathRef<'a>>,
     // more recent versions of Nix also have a `valid: true` field here, Nix 2.3 doesn't,
     // and nothing seems to use it.
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub signatures: Vec<SignatureRef<'a>>,
 }
 
 /// ExportedPathInfo are ordered by their `path` field.
@@ -56,18 +62,49 @@ where
 /// The length of a sha256 digest, nixbase32-encoded.
 const NIXBASE32_SHA256_ENCODE_LEN: usize = nixbase32::encode_len(32);
 
-fn from_nix_nixbase32_string<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
+fn from_nix_hash_string<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
 where
     D: serde::Deserializer<'de>,
 {
     let str: &'de str = Deserialize::deserialize(deserializer)?;
+    if let Some(digest_str) = str.strip_prefix("sha256:") {
+        return from_nix_nixbase32_string::<D>(digest_str);
+    }
+    if let Some(digest_str) = str.strip_prefix("sha256-") {
+        return from_sri_string::<D>(digest_str);
+    }
+    Err(serde::de::Error::invalid_value(
+        serde::de::Unexpected::Str(str),
+        &"expected a valid nixbase32 or SRI narHash",
+    ))
+}
 
-    let digest_str = str.strip_prefix("sha256:").ok_or_else(|| {
-        serde::de::Error::invalid_value(serde::de::Unexpected::Str(str), &"sha256:…")
-    })?;
+fn from_sri_string<'de, D>(str: &str) -> Result<[u8; 32], D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
+    let digest: [u8; 32] = data_encoding::BASE64
+        .decode(str.as_bytes())
+        .map_err(|_| {
+            serde::de::Error::invalid_value(
+                serde::de::Unexpected::Str(str),
+                &"valid base64 encoded string",
+            )
+        })?
+        .try_into()
+        .map_err(|_| {
+            serde::de::Error::invalid_value(serde::de::Unexpected::Str(str), &"valid digest len")
+        })?;
+    Ok(digest)
+}
+
+fn from_nix_nixbase32_string<'de, D>(str: &str) -> Result<[u8; 32], D::Error>
+where
+    D: serde::Deserializer<'de>,
+{
     let digest_str: [u8; NIXBASE32_SHA256_ENCODE_LEN] =
-        digest_str.as_bytes().try_into().map_err(|_| {
+        str.as_bytes().try_into().map_err(|_| {
             serde::de::Error::invalid_value(serde::de::Unexpected::Str(str), &"valid digest len")
         })?;
 
@@ -110,10 +147,49 @@ mod tests {
                 b"7n0mbqydcipkpbxm24fab066lxk68aqk-libunistring-1.1"
             )
             .expect("must parse"),
+            deriver: None,
             references: BTreeSet::from_iter([StorePathRef::from_bytes(
                 b"7n0mbqydcipkpbxm24fab066lxk68aqk-libunistring-1.1"
             )
             .unwrap()]),
+            signatures: vec![],
         },
         deserialized.first().unwrap()
     );
 }
 
+    /// Ensure we can parse output from `nix path-info --json`
+    #[test]
+    fn serialize_deserialize_from_path_info() {
+        // JSON extracted from
+        // nix path-info /nix/store/z6r3bn5l51679pwkvh9nalp6c317z34m-libcxx-16.0.6-dev --json --closure-size
+        let pathinfos_str_json = r#"[{"closureSize":10756176,"deriver":"/nix/store/vs9976cyyxpykvdnlv7x85fpp3shn6ij-libcxx-16.0.6.drv","narHash":"sha256-E73Nt0NAKGxCnsyBFDUaCAbA+wiF5qjq1O9J7WrnT0E=","narSize":7020664,"path":"/nix/store/z6r3bn5l51679pwkvh9nalp6c317z34m-libcxx-16.0.6-dev","references":["/nix/store/lzzd5jgybnpfj86xkcpnd54xgwc4m457-libcxx-16.0.6"],"registrationTime":1730048276,"signatures":["cache.nixos.org-1:cTdhK6hnpPwtMXFX43CYb7v+CbpAusVI/MORZ3v5aHvpBYNg1MfBHVVeoexMBpNtHA8uFAn0aEsJaLXYIDhJDg=="],"valid":true}]"#;
+
+        let deserialized: BTreeSet<ExportedPathInfo> =
+            serde_json::from_str(pathinfos_str_json).expect("must deserialize");
+
+        assert_eq!(
+            &ExportedPathInfo {
+                closure_size: 10756176,
+                nar_sha256: hex!(
+                    "13bdcdb74340286c429ecc8114351a0806c0fb0885e6a8ead4ef49ed6ae74f41"
+                ),
+                nar_size: 7020664,
+                path: StorePathRef::from_bytes(
+                    b"z6r3bn5l51679pwkvh9nalp6c317z34m-libcxx-16.0.6-dev"
+                )
+                .expect("must parse"),
+                deriver: Some(
+                    StorePathRef::from_bytes(
+                        b"vs9976cyyxpykvdnlv7x85fpp3shn6ij-libcxx-16.0.6.drv"
+                    )
+                    .expect("must parse")
+                ),
+                references:
BTreeSet::from_iter([StorePathRef::from_bytes( + b"lzzd5jgybnpfj86xkcpnd54xgwc4m457-libcxx-16.0.6" + ) + .unwrap()]), + signatures: vec![SignatureRef::parse("cache.nixos.org-1:cTdhK6hnpPwtMXFX43CYb7v+CbpAusVI/MORZ3v5aHvpBYNg1MfBHVVeoexMBpNtHA8uFAn0aEsJaLXYIDhJDg==").expect("must parse")], }, deserialized.first().unwrap() ); diff --git a/tvix/nix-compat/src/store_path/mod.rs b/tvix/nix-compat/src/store_path/mod.rs index ff7ede77e1da..13265048641b 100644 --- a/tvix/nix-compat/src/store_path/mod.rs +++ b/tvix/nix-compat/src/store_path/mod.rs @@ -3,14 +3,11 @@ use data_encoding::{DecodeError, BASE64}; use serde::{Deserialize, Serialize}; use std::{ fmt, - path::PathBuf, + path::Path, str::{self, FromStr}, }; use thiserror; -#[cfg(target_family = "unix")] -use std::os::unix::ffi::OsStringExt; - mod utils; pub use utils::*; @@ -53,181 +50,111 @@ pub enum Error { /// /// A [StorePath] does not encode any additional subpath "inside" the store /// path. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct StorePath { +#[derive(Clone, Debug)] +pub struct StorePath<S> { digest: [u8; DIGEST_SIZE], - name: String, + name: S, } -impl StorePath { - pub fn digest(&self) -> &[u8; DIGEST_SIZE] { - &self.digest - } - - pub fn name(&self) -> &str { - self.name.as_ref() - } - - pub fn as_ref(&self) -> StorePathRef<'_> { - StorePathRef { - digest: self.digest, - name: &self.name, - } +impl<S> PartialEq for StorePath<S> +where + S: AsRef<str>, +{ + fn eq(&self, other: &Self) -> bool { + self.digest() == other.digest() && self.name().as_ref() == other.name().as_ref() } } -impl PartialOrd for StorePath { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - Some(self.cmp(other)) - } -} +impl<S> Eq for StorePath<S> where S: AsRef<str> {} -/// `StorePath`s are sorted by their reverse digest to match the sorting order -/// of the nixbase32-encoded string. -impl Ord for StorePath { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.as_ref().cmp(&other.as_ref()) +impl<S> std::hash::Hash for StorePath<S> +where + S: AsRef<str>, +{ + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + state.write(&self.digest); + state.write(self.name.as_ref().as_bytes()); } } -impl FromStr for StorePath { - type Err = Error; +/// Like [StorePath], but without a heap allocation for the name. +/// Used by [StorePath] for parsing. +pub type StorePathRef<'a> = StorePath<&'a str>; - /// Construct a [StorePath] by passing the `$digest-$name` string - /// that comes after [STORE_DIR_WITH_SLASH]. - fn from_str(s: &str) -> Result<Self, Self::Err> { - Self::from_bytes(s.as_bytes()) +impl<S> StorePath<S> +where + S: AsRef<str>, +{ + pub fn digest(&self) -> &[u8; DIGEST_SIZE] { + &self.digest } -} -impl StorePath { - /// Construct a [StorePath] by passing the `$digest-$name` string - /// that comes after [STORE_DIR_WITH_SLASH]. - pub fn from_bytes(s: &[u8]) -> Result<StorePath, Error> { - Ok(StorePathRef::from_bytes(s)?.to_owned()) + pub fn name(&self) -> &S { + &self.name } - /// Decompose a string into a [StorePath] and a [PathBuf] containing the - /// rest of the path, or an error. 
- #[cfg(target_family = "unix")] - pub fn from_absolute_path_full(s: &str) -> Result<(StorePath, PathBuf), Error> { - // strip [STORE_DIR_WITH_SLASH] from s - match s.strip_prefix(STORE_DIR_WITH_SLASH) { - None => Err(Error::MissingStoreDir), - Some(rest) => { - // put rest in a PathBuf - let mut p = PathBuf::new(); - p.push(rest); - - let mut it = p.components(); - - // The first component of the rest must be parse-able as a [StorePath] - if let Some(first_component) = it.next() { - // convert first component to StorePath - let first_component_bytes = first_component.as_os_str().to_owned().into_vec(); - let store_path = StorePath::from_bytes(&first_component_bytes)?; - // collect rest - let rest_buf: PathBuf = it.collect(); - Ok((store_path, rest_buf)) - } else { - Err(Error::InvalidLength) // Well, or missing "/"? - } - } + pub fn as_ref(&self) -> StorePathRef<'_> { + StorePathRef { + digest: self.digest, + name: self.name.as_ref(), } } - /// Returns an absolute store path string. - /// That is just the string representation, prefixed with the store prefix - /// ([STORE_DIR_WITH_SLASH]), - pub fn to_absolute_path(&self) -> String { - let sp_ref: StorePathRef = self.into(); - sp_ref.to_absolute_path() - } -} - -impl<'de> Deserialize<'de> for StorePath { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - let r = <StorePathRef<'de> as Deserialize<'de>>::deserialize(deserializer)?; - Ok(r.to_owned()) + pub fn to_owned(&self) -> StorePath<String> { + StorePath { + digest: self.digest, + name: self.name.as_ref().to_string(), + } } -} -impl Serialize for StorePath { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + /// Construct a [StorePath] by passing the `$digest-$name` string + /// that comes after [STORE_DIR_WITH_SLASH]. + pub fn from_bytes<'a>(s: &'a [u8]) -> Result<Self, Error> where - S: serde::Serializer, + S: From<&'a str>, { - let r: StorePathRef = self.into(); - r.serialize(serializer) - } -} - -/// Like [StorePath], but without a heap allocation for the name. -/// Used by [StorePath] for parsing. -/// -#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash)] -pub struct StorePathRef<'a> { - digest: [u8; DIGEST_SIZE], - name: &'a str, -} - -impl<'a> From<&'a StorePath> for StorePathRef<'a> { - fn from(&StorePath { digest, ref name }: &'a StorePath) -> Self { - StorePathRef { - digest, - name: name.as_ref(), + // the whole string needs to be at least: + // + // - 32 characters (encoded hash) + // - 1 dash + // - 1 character for the name + if s.len() < ENCODED_DIGEST_SIZE + 2 { + Err(Error::InvalidLength)? } - } -} -impl<'a> PartialOrd for StorePathRef<'a> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - Some(self.cmp(other)) - } -} - -/// `StorePathRef`s are sorted by their reverse digest to match the sorting order -/// of the nixbase32-encoded string. 
-impl<'a> Ord for StorePathRef<'a> {
-    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        self.digest.iter().rev().cmp(other.digest.iter().rev())
-    }
-}
-
-impl<'a> StorePathRef<'a> {
-    pub fn digest(&self) -> &[u8; DIGEST_SIZE] {
-        &self.digest
-    }
-
-    pub fn name(&self) -> &'a str {
-        self.name
-    }
+        let digest = nixbase32::decode_fixed(&s[..ENCODED_DIGEST_SIZE])?;
 
-    pub fn to_owned(&self) -> StorePath {
-        StorePath {
-            digest: self.digest,
-            name: self.name.to_owned(),
+        if s[ENCODED_DIGEST_SIZE] != b'-' {
+            return Err(Error::MissingDash);
         }
+
+        Ok(StorePath {
+            digest,
+            name: validate_name(&s[ENCODED_DIGEST_SIZE + 1..])?.into(),
+        })
     }
 
     /// Construct a [StorePathRef] from a name and digest.
     /// The name is validated, and the digest checked for size.
-    pub fn from_name_and_digest(name: &'a str, digest: &[u8]) -> Result<Self, Error> {
+    pub fn from_name_and_digest<'a>(name: &'a str, digest: &[u8]) -> Result<Self, Error>
+    where
+        S: From<&'a str>,
+    {
         let digest_fixed = digest.try_into().map_err(|_| Error::InvalidLength)?;
         Self::from_name_and_digest_fixed(name, digest_fixed)
     }
 
     /// Construct a [StorePathRef] from a name and digest of correct length.
     /// The name is validated.
-    pub fn from_name_and_digest_fixed(
+    pub fn from_name_and_digest_fixed<'a>(
         name: &'a str,
         digest: [u8; DIGEST_SIZE],
-    ) -> Result<Self, Error> {
+    ) -> Result<Self, Error>
+    where
+        S: From<&'a str>,
+    {
         Ok(Self {
-            name: validate_name(name.as_bytes())?,
+            name: validate_name(name)?.into(),
             digest,
         })
     }
@@ -235,35 +162,40 @@ impl<'a> StorePathRef<'a> {
     /// Construct a [StorePathRef] from an absolute store path string.
     /// This is equivalent to calling [StorePathRef::from_bytes], but stripping
     /// the [STORE_DIR_WITH_SLASH] prefix before.
-    pub fn from_absolute_path(s: &'a [u8]) -> Result<Self, Error> {
+    pub fn from_absolute_path<'a>(s: &'a [u8]) -> Result<Self, Error>
+    where
+        S: From<&'a str>,
+    {
         match s.strip_prefix(STORE_DIR_WITH_SLASH.as_bytes()) {
             Some(s_stripped) => Self::from_bytes(s_stripped),
             None => Err(Error::MissingStoreDir),
         }
     }
 
-    /// Construct a [StorePathRef] by passing the `$digest-$name` string
-    /// that comes after [STORE_DIR_WITH_SLASH].
-    pub fn from_bytes(s: &'a [u8]) -> Result<Self, Error> {
-        // the whole string needs to be at least:
-        //
-        // - 32 characters (encoded hash)
-        // - 1 dash
-        // - 1 character for the name
-        if s.len() < ENCODED_DIGEST_SIZE + 2 {
-            Err(Error::InvalidLength)?
-        }
+    /// Decompose an absolute path into a [StorePath] and a [Path] containing
+    /// the rest of the path, or an error.
+    #[cfg(target_family = "unix")]
+    pub fn from_absolute_path_full<'a, P>(path: &'a P) -> Result<(Self, &'a Path), Error>
+    where
+        S: From<&'a str>,
+        P: AsRef<std::path::Path> + ?Sized,
+    {
+        // strip [STORE_DIR_WITH_SLASH] from the path
+        let p = path
+            .as_ref()
+            .strip_prefix(STORE_DIR_WITH_SLASH)
+            .map_err(|_e| Error::MissingStoreDir)?;
 
-        let digest = nixbase32::decode_fixed(&s[..ENCODED_DIGEST_SIZE])?;
+        let mut it = Path::new(p).components();
 
-        if s[ENCODED_DIGEST_SIZE] != b'-' {
-            return Err(Error::MissingDash);
-        }
+        // The first component of the rest must be parse-able as a [StorePath]
+        let first_component = it.next().ok_or(Error::InvalidLength)?;
+        let store_path = StorePath::from_bytes(first_component.as_os_str().as_encoded_bytes())?;
 
-        Ok(StorePathRef {
-            digest,
-            name: validate_name(&s[ENCODED_DIGEST_SIZE + 1..])?,
-        })
+        // collect rest
+        let rest_buf = it.as_path();
+
+        Ok((store_path, rest_buf))
     }
 
     /// Returns an absolute store path string.
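
The rework above folds the previously separate owned and borrowed store-path types into one generic `StorePath<S>`, with `StorePathRef<'a>` left as a plain type alias. A short sketch of how parsing and ownership now compose (the path literal is taken from the tests further down; the exact re-export paths are assumed):

```rust
use nix_compat::store_path::{StorePath, StorePathRef};

fn demo() {
    // Parse borrowed: the name is a &str pointing into the input…
    let p: StorePathRef = StorePath::from_bytes(
        b"00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432",
    )
    .expect("must parse");

    // …and detach into a heap-allocating StorePath<String> only when needed.
    let owned: StorePath<String> = p.to_owned();
    assert_eq!(p.to_string(), owned.to_string());
}
```
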
@@ -274,7 +206,40 @@ impl<'a> StorePathRef<'a> {
     }
 }
 
-impl<'de: 'a, 'a> Deserialize<'de> for StorePathRef<'a> {
+impl<S> PartialOrd for StorePath<S>
+where
+    S: AsRef<str>,
+{
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+/// `StorePath`s are sorted by their reverse digest to match the sorting order
+/// of the nixbase32-encoded string.
+impl<S> Ord for StorePath<S>
+where
+    S: AsRef<str>,
+{
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.digest.iter().rev().cmp(other.digest.iter().rev())
+    }
+}
+
+impl FromStr for StorePath<String> {
+    type Err = Error;
+
+    /// Construct a [StorePath] by passing the `$digest-$name` string
+    /// that comes after [STORE_DIR_WITH_SLASH].
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        StorePath::<String>::from_bytes(s.as_bytes())
+    }
+}
+
+impl<'a, 'de: 'a, S> Deserialize<'de> for StorePath<S>
+where
+    S: AsRef<str> + From<&'a str>,
+{
     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
         D: serde::Deserializer<'de>,
@@ -287,16 +252,19 @@ impl<'de: 'a, 'a> Deserialize<'de> for StorePathRef<'a> {
                 &"store path prefix",
             )
         })?;
-        StorePathRef::from_bytes(stripped.as_bytes()).map_err(|_| {
+        StorePath::from_bytes(stripped.as_bytes()).map_err(|_| {
             serde::de::Error::invalid_value(serde::de::Unexpected::Str(string), &"StorePath")
         })
     }
 }
 
-impl Serialize for StorePathRef<'_> {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+impl<S> Serialize for StorePath<S>
+where
+    S: AsRef<str>,
+{
+    fn serialize<SR>(&self, serializer: SR) -> Result<SR::Ok, SR::Error>
     where
-        S: serde::Serializer,
+        SR: serde::Serializer,
     {
         let string: String = self.to_absolute_path();
         string.serialize(serializer)
@@ -350,18 +318,20 @@ pub(crate) fn validate_name(s: &(impl AsRef<[u8]> + ?Sized)) -> Result<&str, Err
     Ok(unsafe { str::from_utf8_unchecked(s) })
 }
 
-impl fmt::Display for StorePath {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        StorePathRef::from(self).fmt(f)
-    }
-}
-
-impl fmt::Display for StorePathRef<'_> {
+impl<S> fmt::Display for StorePath<S>
+where
+    S: AsRef<str>,
+{
     /// The string representation of a store path starts with a digest (20
     /// bytes), [crate::nixbase32]-encoded, followed by a `-`,
     /// and ends with the name.
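    /// For example: `00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432`.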
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}-{}", nixbase32::encode(&self.digest), self.name) + write!( + f, + "{}-{}", + nixbase32::encode(&self.digest), + self.name.as_ref() + ) } } @@ -389,12 +359,12 @@ mod tests { fn happy_path() { let example_nix_path_str = "00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432"; - let nixpath = StorePath::from_bytes(example_nix_path_str.as_bytes()) + let nixpath = StorePathRef::from_bytes(example_nix_path_str.as_bytes()) .expect("Error parsing example string"); let expected_digest: [u8; DIGEST_SIZE] = hex!("8a12321522fd91efbd60ebb2481af88580f61600"); - assert_eq!("net-tools-1.60_p20170221182432", nixpath.name); + assert_eq!("net-tools-1.60_p20170221182432", *nixpath.name()); assert_eq!(nixpath.digest, expected_digest); assert_eq!(example_nix_path_str, nixpath.to_string()) @@ -429,8 +399,8 @@ mod tests { if w.len() < 2 { continue; } - let (pa, _) = StorePath::from_absolute_path_full(w[0]).expect("parseable"); - let (pb, _) = StorePath::from_absolute_path_full(w[1]).expect("parseable"); + let (pa, _) = StorePathRef::from_absolute_path_full(w[0]).expect("parseable"); + let (pb, _) = StorePathRef::from_absolute_path_full(w[1]).expect("parseable"); assert_eq!( Ordering::Less, pa.cmp(&pb), @@ -451,36 +421,38 @@ mod tests { /// https://github.com/NixOS/nix/pull/9867 (revert-of-revert) #[test] fn starts_with_dot() { - StorePath::from_bytes(b"fli4bwscgna7lpm7v5xgnjxrxh0yc7ra-.gitignore") + StorePathRef::from_bytes(b"fli4bwscgna7lpm7v5xgnjxrxh0yc7ra-.gitignore") .expect("must succeed"); } #[test] fn empty_name() { - StorePath::from_bytes(b"00bgd045z0d4icpbc2yy-").expect_err("must fail"); + StorePathRef::from_bytes(b"00bgd045z0d4icpbc2yy-").expect_err("must fail"); } #[test] fn excessive_length() { - StorePath::from_bytes(b"00bgd045z0d4icpbc2yy-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + StorePathRef::from_bytes(b"00bgd045z0d4icpbc2yy-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") .expect_err("must fail"); } #[test] fn invalid_hash_length() { - StorePath::from_bytes(b"00bgd045z0d4icpbc2yy-net-tools-1.60_p20170221182432") + StorePathRef::from_bytes(b"00bgd045z0d4icpbc2yy-net-tools-1.60_p20170221182432") .expect_err("must fail"); } #[test] fn invalid_encoding_hash() { - StorePath::from_bytes(b"00bgd045z0d4icpbc2yyz4gx48aku4la-net-tools-1.60_p20170221182432") - .expect_err("must fail"); + StorePathRef::from_bytes( + b"00bgd045z0d4icpbc2yyz4gx48aku4la-net-tools-1.60_p20170221182432", + ) + .expect_err("must fail"); } #[test] fn more_than_just_the_bare_nix_store_path() { - StorePath::from_bytes( + StorePathRef::from_bytes( b"00bgd045z0d4icpbc2yyz4gx48aku4la-net-tools-1.60_p20170221182432/bin/arp", ) .expect_err("must fail"); @@ -488,7 +460,7 @@ mod tests { #[test] fn no_dash_between_hash_and_name() { - StorePath::from_bytes(b"00bgd045z0d4icpbc2yyz4gx48ak44lanet-tools-1.60_p20170221182432") + StorePathRef::from_bytes(b"00bgd045z0d4icpbc2yyz4gx48ak44lanet-tools-1.60_p20170221182432") .expect_err("must fail"); } @@ -537,7 +509,7 @@ mod tests { #[test] fn serialize_owned() { - let nixpath_actual = StorePath::from_bytes( + let nixpath_actual = StorePathRef::from_bytes( 
b"00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432", ) .expect("can parse"); @@ -581,7 +553,8 @@ mod tests { let store_path_str_json = "\"/nix/store/00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432\""; - let store_path: StorePath = serde_json::from_str(store_path_str_json).expect("valid json"); + let store_path: StorePath<String> = + serde_json::from_str(store_path_str_json).expect("valid json"); assert_eq!( "/nix/store/00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432", @@ -604,7 +577,7 @@ mod tests { StorePath::from_bytes(b"00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432").unwrap(), PathBuf::from("bin/arp/"))] fn from_absolute_path_full( #[case] s: &str, - #[case] exp_store_path: StorePath, + #[case] exp_store_path: StorePath<&str>, #[case] exp_path: PathBuf, ) { let (actual_store_path, actual_path) = @@ -618,15 +591,15 @@ mod tests { fn from_absolute_path_errors() { assert_eq!( Error::InvalidLength, - StorePath::from_absolute_path_full("/nix/store/").expect_err("must fail") + StorePathRef::from_absolute_path_full("/nix/store/").expect_err("must fail") ); assert_eq!( Error::InvalidLength, - StorePath::from_absolute_path_full("/nix/store/foo").expect_err("must fail") + StorePathRef::from_absolute_path_full("/nix/store/foo").expect_err("must fail") ); assert_eq!( Error::MissingStoreDir, - StorePath::from_absolute_path_full( + StorePathRef::from_absolute_path_full( "00bgd045z0d4icpbc2yyz4gx48ak44la-net-tools-1.60_p20170221182432" ) .expect_err("must fail") diff --git a/tvix/nix-compat/src/store_path/utils.rs b/tvix/nix-compat/src/store_path/utils.rs index d6f390db85c2..63b6969464d8 100644 --- a/tvix/nix-compat/src/store_path/utils.rs +++ b/tvix/nix-compat/src/store_path/utils.rs @@ -1,6 +1,6 @@ use crate::nixbase32; use crate::nixhash::{CAHash, NixHash}; -use crate::store_path::{Error, StorePathRef, STORE_DIR}; +use crate::store_path::{Error, StorePath, STORE_DIR}; use data_encoding::HEXLOWER; use sha2::{Digest, Sha256}; use thiserror; @@ -43,11 +43,17 @@ pub fn compress_hash<const OUTPUT_SIZE: usize>(input: &[u8]) -> [u8; OUTPUT_SIZE /// derivation or a literal text file that may contain references. /// If you don't want to have to pass the entire contents, you might want to use /// [build_ca_path] instead. -pub fn build_text_path<S: AsRef<str>, I: IntoIterator<Item = S>, C: AsRef<[u8]>>( - name: &str, +pub fn build_text_path<'a, S, SP, I, C>( + name: &'a str, content: C, references: I, -) -> Result<StorePathRef<'_>, BuildStorePathError> { +) -> Result<StorePath<SP>, BuildStorePathError> +where + S: AsRef<str>, + SP: AsRef<str> + std::convert::From<&'a str>, + I: IntoIterator<Item = S>, + C: AsRef<[u8]>, +{ // produce the sha256 digest of the contents let content_digest = Sha256::new_with_prefix(content).finalize().into(); @@ -55,12 +61,17 @@ pub fn build_text_path<S: AsRef<str>, I: IntoIterator<Item = S>, C: AsRef<[u8]>> } /// This builds a store path from a [CAHash] and a list of references. -pub fn build_ca_path<'a, S: AsRef<str>, I: IntoIterator<Item = S>>( +pub fn build_ca_path<'a, S, SP, I>( name: &'a str, ca_hash: &CAHash, references: I, self_reference: bool, -) -> Result<StorePathRef<'a>, BuildStorePathError> { +) -> Result<StorePath<SP>, BuildStorePathError> +where + S: AsRef<str>, + SP: AsRef<str> + std::convert::From<&'a str>, + I: IntoIterator<Item = S>, +{ // self references are only allowed for CAHash::Nar(NixHash::Sha256(_)). 
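    // i.e. a self reference combined with any other hash mode or algo is rejected.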
    if self_reference && !matches!(ca_hash, CAHash::Nar(NixHash::Sha256(_))) {
        return Err(BuildStorePathError::InvalidReference());
    }
@@ -108,26 +119,18 @@ pub fn build_ca_path<'a, S: AsRef<str>, I: IntoIterator<Item = S>>(
     .map_err(BuildStorePathError::InvalidStorePath)
 }
 
-/// For given NAR sha256 digest and name, return the new [StorePathRef] this
-/// would have, or an error, in case the name is invalid.
-pub fn build_nar_based_store_path<'a>(
-    nar_sha256_digest: &[u8; 32],
-    name: &'a str,
-) -> Result<StorePathRef<'a>, BuildStorePathError> {
-    let nar_hash_with_mode = CAHash::Nar(NixHash::Sha256(nar_sha256_digest.to_owned()));
-
-    build_ca_path(name, &nar_hash_with_mode, Vec::<String>::new(), false)
-}
-
 /// This builds an input-addressed store path.
 ///
 /// Input-addressed store paths are always derivation outputs; the "input" in question is the
 /// derivation and its closure.
-pub fn build_output_path<'a>(
+pub fn build_output_path<'a, SP>(
     drv_sha256: &[u8; 32],
     output_name: &str,
     output_path_name: &'a str,
-) -> Result<StorePathRef<'a>, Error> {
+) -> Result<StorePath<SP>, Error>
+where
+    SP: AsRef<str> + std::convert::From<&'a str>,
+{
     build_store_path_from_fingerprint_parts(
         &(String::from("output:") + output_name),
         drv_sha256,
@@ -145,17 +148,20 @@
 /// bytes.
 /// Inside a StorePath, that digest is printed nixbase32-encoded
 /// (32 characters).
-fn build_store_path_from_fingerprint_parts<'a>(
+fn build_store_path_from_fingerprint_parts<'a, SP>(
     ty: &str,
     inner_digest: &[u8; 32],
     name: &'a str,
-) -> Result<StorePathRef<'a>, Error> {
+) -> Result<StorePath<SP>, Error>
+where
+    SP: AsRef<str> + std::convert::From<&'a str>,
+{
     let fingerprint = format!(
         "{ty}:sha256:{}:{STORE_DIR}:{name}",
         HEXLOWER.encode(inner_digest)
     );
     // name validation happens in here.
- StorePathRef::from_name_and_digest_fixed( + StorePath::from_name_and_digest_fixed( name, compress_hash(&Sha256::new_with_prefix(fingerprint).finalize()), ) @@ -207,7 +213,10 @@ mod test { use hex_literal::hex; use super::*; - use crate::nixhash::{CAHash, NixHash}; + use crate::{ + nixhash::{CAHash, NixHash}, + store_path::StorePathRef, + }; #[test] fn build_text_path_with_zero_references() { @@ -216,7 +225,7 @@ mod test { // nix-repl> builtins.toFile "foo" "bar" // "/nix/store/vxjiwkjkn7x4079qvh1jkl5pn05j2aw0-foo" - let store_path = build_text_path("foo", "bar", Vec::<String>::new()) + let store_path: StorePathRef = build_text_path("foo", "bar", Vec::<String>::new()) .expect("build_store_path() should succeed"); assert_eq!( @@ -232,11 +241,11 @@ mod test { // nix-repl> builtins.toFile "baz" "${builtins.toFile "foo" "bar"}" // "/nix/store/5xd714cbfnkz02h2vbsj4fm03x3f15nf-baz" - let inner = build_text_path("foo", "bar", Vec::<String>::new()) + let inner: StorePathRef = build_text_path("foo", "bar", Vec::<String>::new()) .expect("path_with_references() should succeed"); let inner_path = inner.to_absolute_path(); - let outer = build_text_path("baz", &inner_path, vec![inner_path.as_str()]) + let outer: StorePathRef = build_text_path("baz", &inner_path, vec![inner_path.as_str()]) .expect("path_with_references() should succeed"); assert_eq!( @@ -247,7 +256,7 @@ mod test { #[test] fn build_sha1_path() { - let outer = build_ca_path( + let outer: StorePathRef = build_ca_path( "bar", &CAHash::Nar(NixHash::Sha1(hex!( "0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33" @@ -272,7 +281,7 @@ mod test { // // $ nix store make-content-addressed /nix/store/5xd714cbfnkz02h2vbsj4fm03x3f15nf-baz // rewrote '/nix/store/5xd714cbfnkz02h2vbsj4fm03x3f15nf-baz' to '/nix/store/s89y431zzhmdn3k8r96rvakryddkpv2v-baz' - let outer = build_ca_path( + let outer: StorePathRef = build_ca_path( "baz", &CAHash::Nar(NixHash::Sha256( nixbase32::decode(b"1xqkzcb3909fp07qngljr4wcdnrh1gdam1m2n29i6hhrxlmkgkv1") diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index ef0b59def8b9..9b981fbbd2c0 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -1,8 +1,12 @@ +#[cfg(feature = "async")] +use std::mem::MaybeUninit; use std::{ io::{Error, ErrorKind}, ops::RangeInclusive, }; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +#[cfg(feature = "async")] +use tokio::io::ReadBuf; +use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; pub(crate) mod reader; pub use reader::BytesReader; @@ -10,12 +14,11 @@ mod writer; pub use writer::BytesWriter; /// 8 null bytes, used to write out padding. -const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; +pub(crate) const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; /// The length of the size field, in bytes is always 8. const LEN_SIZE: usize = 8; -#[allow(dead_code)] /// Read a "bytes wire packet" from the AsyncRead. /// Rejects reading more than `allowed_size` bytes of payload. /// @@ -33,12 +36,9 @@ const LEN_SIZE: usize = 8; /// /// This buffers the entire payload into memory, /// a streaming version is available at [crate::wire::bytes::BytesReader]. 
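 ///
 /// A rough usage sketch (marked `ignore`: whether this helper is visible
 /// outside the crate is an assumption here):
 ///
 /// ```ignore
 /// // 1-byte payload 0xff: 8-byte LE length prefix, then payload + 7 bytes padding
 /// let mut r: &[u8] = &[1, 0, 0, 0, 0, 0, 0, 0, 0xff, 0, 0, 0, 0, 0, 0, 0];
 /// let payload = read_bytes(&mut r, 0..=1024).await?;
 /// assert_eq!(payload, vec![0xff]);
 /// ```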
-pub async fn read_bytes<R>(
-    r: &mut R,
-    allowed_size: RangeInclusive<usize>,
-) -> std::io::Result<Vec<u8>>
+pub async fn read_bytes<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<Vec<u8>>
 where
-    R: AsyncReadExt + Unpin,
+    R: AsyncReadExt + Unpin + ?Sized,
 {
     // read the length field
     let len = r.read_u64_le().await?;
@@ -47,8 +47,8 @@ where
         .ok()
         .filter(|len| allowed_size.contains(len))
         .ok_or_else(|| {
-            std::io::Error::new(
-                std::io::ErrorKind::InvalidData,
+            io::Error::new(
+                io::ErrorKind::InvalidData,
                 "signalled package size not in allowed range",
             )
         })?;
@@ -64,15 +64,15 @@ where
 
     // make sure we got exactly the number of bytes, and not less.
     if s as u64 != padded_len {
-        return Err(std::io::ErrorKind::UnexpectedEof.into());
+        return Err(io::ErrorKind::UnexpectedEof.into());
     }
 
     let (_content, padding) = buf.split_at(len);
 
     // ensure the padding is all zeroes.
-    if !padding.iter().all(|e| *e == b'\0') {
-        return Err(std::io::Error::new(
-            std::io::ErrorKind::InvalidData,
+    if padding.iter().any(|&b| b != 0) {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
             "padding is not all zeroes",
         ));
     }
@@ -82,13 +82,69 @@ where
     Ok(buf)
 }
 
+#[cfg(feature = "async")]
+pub(crate) async fn read_bytes_buf<'a, const N: usize, R>(
+    reader: &mut R,
+    buf: &'a mut [MaybeUninit<u8>; N],
+    allowed_size: RangeInclusive<usize>,
+) -> io::Result<&'a [u8]>
+where
+    R: AsyncReadExt + Unpin + ?Sized,
+{
+    assert_eq!(N % 8, 0);
+    assert!(*allowed_size.end() <= N);
+
+    let len = reader.read_u64_le().await?;
+    let len: usize = len
+        .try_into()
+        .ok()
+        .filter(|len| allowed_size.contains(len))
+        .ok_or_else(|| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                "signalled package size not in allowed range",
+            )
+        })?;
+
+    let buf_len = (len + 7) & !7;
+    let buf = {
+        let mut read_buf = ReadBuf::uninit(&mut buf[..buf_len]);
+
+        while read_buf.filled().len() < buf_len {
+            reader.read_buf(&mut read_buf).await?;
+        }
+
+        // ReadBuf::filled does not pass the underlying buffer's lifetime through,
+        // so we must make a trip to hell.
+        //
+        // SAFETY: `read_buf` is filled up to `buf_len`, and we verify that it is
+        // still pointing at the same underlying buffer.
+        unsafe {
+            assert_eq!(read_buf.filled().as_ptr(), buf.as_ptr() as *const u8);
+            assume_init_bytes(&buf[..buf_len])
+        }
+    };
+
+    if buf[len..buf_len].iter().any(|&b| b != 0) {
+        return Err(io::Error::new(
+            io::ErrorKind::InvalidData,
+            "padding is not all zeroes",
+        ));
+    }
+
+    Ok(&buf[..len])
+}
+
+/// SAFETY: The bytes have to actually be initialized.
+#[cfg(feature = "async")]
+unsafe fn assume_init_bytes(slice: &[MaybeUninit<u8>]) -> &[u8] {
+    &*(slice as *const [MaybeUninit<u8>] as *const [u8])
+}
+
 /// Read a "bytes wire packet" from the AsyncRead and try to parse it as a string.
 /// Internally uses [read_bytes].
 /// Rejects reading more than `allowed_size` bytes of payload.
-pub async fn read_string<R>(
-    r: &mut R,
-    allowed_size: RangeInclusive<usize>,
-) -> std::io::Result<String>
+pub async fn read_string<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<String>
 where
     R: AsyncReadExt + Unpin,
 {
@@ -108,7 +164,7 @@ where
 pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
     w: &mut W,
     b: B,
-) -> std::io::Result<()> {
+) -> io::Result<()> {
     // write the size packet.
w.write_u64_le(b.as_ref().len() as u64).await?; @@ -125,7 +181,7 @@ pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>( /// Computes the number of bytes we should add to len (a length in /// bytes) to be aligned on 64 bits (8 bytes). -fn padding_len(len: u64) -> u8 { +pub(crate) fn padding_len(len: u64) -> u8 { let aligned = len.wrapping_add(7) & !7; aligned.wrapping_sub(len) as u8 } diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index cd45f78a0c84..a6209a6e6dad 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -1,11 +1,12 @@ use std::{ future::Future, io, + num::NonZeroU64, ops::RangeBounds, pin::Pin, task::{self, ready, Poll}, }; -use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf}; use trailer::{read_trailer, ReadTrailer, Trailer}; @@ -33,14 +34,26 @@ pub struct BytesReader<R, T: Tag = Pad> { state: State<R, T>, } +/// Split the `user_len` into `body_len` and `tail_len`, which are respectively +/// the non-terminal 8-byte blocks, and the ≤8 bytes of user data contained in +/// the trailer block. +#[inline(always)] +fn split_user_len(user_len: NonZeroU64) -> (u64, u8) { + let n = user_len.get() - 1; + let body_len = n & !7; + let tail_len = (n & 7) as u8 + 1; + (body_len, tail_len) +} + #[derive(Debug)] enum State<R, T: Tag> { /// Full 8-byte blocks are being read and released to the caller. + /// NOTE: The final 8-byte block is *always* part of the trailer. Body { reader: Option<R>, consumed: u64, /// The total length of all user data contained in both the body and trailer. - user_len: u64, + user_len: NonZeroU64, }, /// The trailer is in the process of being read. ReadTrailer(ReadTrailer<R, T>), @@ -76,10 +89,16 @@ where } Ok(Self { - state: State::Body { - reader: Some(reader), - consumed: 0, - user_len: size, + state: match NonZeroU64::new(size) { + Some(size) => State::Body { + reader: Some(reader), + consumed: 0, + user_len: size, + }, + None => State::ReleaseTrailer { + consumed: 0, + data: read_trailer::<R, T>(reader, 0).await?, + }, }, }) } @@ -90,13 +109,11 @@ where } /// Remaining data length, ie not including data already read. - /// - /// If the size has not been read yet, this is [None]. pub fn len(&self) -> u64 { match self.state { State::Body { consumed, user_len, .. 
- } => user_len - consumed, + } => user_len.get() - consumed, State::ReadTrailer(ref fut) => fut.len() as u64, State::ReleaseTrailer { consumed, ref data } => data.len() as u64 - consumed as u64, } @@ -119,22 +136,21 @@ impl<R: AsyncRead + Unpin, T: Tag> AsyncRead for BytesReader<R, T> { consumed, user_len, } => { - let body_len = *user_len & !7; + let (body_len, tail_len) = split_user_len(*user_len); let remaining = body_len - *consumed; let reader = if remaining == 0 { let reader = reader.take().unwrap(); - let user_len = (*user_len & 7) as u8; - *this = State::ReadTrailer(read_trailer(reader, user_len)); + *this = State::ReadTrailer(read_trailer(reader, tail_len)); continue; } else { - reader.as_mut().unwrap() + Pin::new(reader.as_mut().unwrap()) }; let mut bytes_read = 0; ready!(with_limited(buf, remaining, |buf| { - let ret = Pin::new(reader).poll_read(cx, buf); - bytes_read = buf.initialized().len(); + let ret = reader.poll_read(cx, buf); + bytes_read = buf.filled().len(); ret }))?; @@ -167,6 +183,96 @@ impl<R: AsyncRead + Unpin, T: Tag> AsyncRead for BytesReader<R, T> { } } +#[allow(private_bounds)] +impl<R: AsyncBufRead + Unpin, T: Tag> AsyncBufRead for BytesReader<R, T> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<&[u8]>> { + let this = &mut self.get_mut().state; + + loop { + match this { + // This state comes *after* the following case, + // but we can't keep it in logical order because + // that would lengthen the borrow lifetime. + State::Body { + reader, + consumed, + user_len, + } if { + let (body_len, _) = split_user_len(*user_len); + let remaining = body_len - *consumed; + + remaining == 0 + } => + { + let reader = reader.take().unwrap(); + let (_, tail_len) = split_user_len(*user_len); + + *this = State::ReadTrailer(read_trailer(reader, tail_len)); + } + State::Body { + reader, + consumed, + user_len, + } => { + let (body_len, _) = split_user_len(*user_len); + let remaining = body_len - *consumed; + + let reader = Pin::new(reader.as_mut().unwrap()); + + match ready!(reader.poll_fill_buf(cx))? { + &[] => { + return Err(io::ErrorKind::UnexpectedEof.into()).into(); + } + mut buf => { + if buf.len() as u64 > remaining { + buf = &buf[..remaining as usize]; + } + + return Ok(buf).into(); + } + } + } + State::ReadTrailer(fut) => { + *this = State::ReleaseTrailer { + consumed: 0, + data: ready!(Pin::new(fut).poll(cx))?, + }; + } + State::ReleaseTrailer { consumed, data } => { + return Ok(&data[*consumed as usize..]).into(); + } + } + } + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + match &mut self.state { + State::Body { + reader, + consumed, + user_len, + } => { + let reader = Pin::new(reader.as_mut().unwrap()); + let (body_len, _) = split_user_len(*user_len); + + *consumed = consumed + .checked_add(amt as u64) + .filter(|&consumed| consumed <= body_len) + .expect("consumed out of bounds"); + + reader.consume(amt); + } + State::ReadTrailer(_) => unreachable!(), + State::ReleaseTrailer { consumed, data } => { + *consumed = amt + .checked_add(*consumed as usize) + .filter(|&consumed| consumed <= data.len()) + .expect("consumed out of bounds") as u8; + } + } + } +} + /// Make a limited version of `buf`, consisting only of up to `n` bytes of the unfilled section, and call `f` with it. /// After `f` returns, we propagate the filled cursor advancement back to `buf`. 
fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R { @@ -193,13 +299,13 @@ fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) #[cfg(test)] mod tests { + use std::sync::LazyLock; use std::time::Duration; use crate::wire::bytes::{padding_len, write_bytes}; use hex_literal::hex; - use lazy_static::lazy_static; use rstest::rstest; - use tokio::io::AsyncReadExt; + use tokio::io::{AsyncReadExt, BufReader}; use tokio_test::io::Builder; use super::*; @@ -208,9 +314,8 @@ mod tests { /// cases. const MAX_LEN: u64 = 1024; - lazy_static! { - pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024); - } + pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> = + LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024)); /// Helper function, calling the (simpler) write_bytes with the payload. /// We use this to create data we want to read from the wire. @@ -243,6 +348,34 @@ mod tests { assert_eq!(payload, &buf[..]); } + /// Read bytes packets of various length, and ensure copy_buf reads the + /// expected payload. + #[rstest] + #[case::empty(&[])] // empty bytes packet + #[case::size_1b(&[0xff])] // 1 bytes payload + #[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding) + #[case::size_9b(&hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding) + #[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet + #[tokio::test] + async fn read_payload_correct_readbuf(#[case] payload: &[u8]) { + let mut mock = BufReader::new( + Builder::new() + .read(&produce_packet_bytes(payload).await) + .build(), + ); + + let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64) + .await + .unwrap(); + + let mut buf = Vec::new(); + tokio::io::copy_buf(&mut r, &mut buf) + .await + .expect("copy_buf must succeed"); + + assert_eq!(payload, &buf[..]); + } + /// Fail if the bytes packet is larger than allowed #[tokio::test] async fn read_bigger_than_allowed_fail() { @@ -277,6 +410,47 @@ mod tests { ); } + /// Read the trailer immediately if there is no payload. + #[cfg(feature = "async")] + #[tokio::test] + async fn read_trailer_immediately() { + use crate::nar::wire::PadPar; + + let mut mock = Builder::new() + .read(&[0; 8]) + .read(&PadPar::PATTERN[8..]) + .build(); + + BytesReader::<_, PadPar>::new_internal(&mut mock, ..) + .await + .unwrap(); + + // The mock reader will panic if dropped without reading all data. + } + + /// Read the trailer even if we only read the exact payload size. + #[cfg(feature = "async")] + #[tokio::test] + async fn read_exact_trailer() { + use crate::nar::wire::PadPar; + + let mut mock = Builder::new() + .read(&16u64.to_le_bytes()) + .read(&[0x55; 16]) + .read(&PadPar::PATTERN[8..]) + .build(); + + let mut reader = BytesReader::<_, PadPar>::new_internal(&mut mock, ..) + .await + .unwrap(); + + let mut buf = [0; 16]; + reader.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, [0x55; 16]); + + // The mock reader will panic if dropped without reading all data. + } + /// Fail if the padding is not all zeroes #[tokio::test] async fn read_fail_if_nonzero_padding() { @@ -402,6 +576,48 @@ mod tests { ); } + /// Start a 9 bytes payload packet, but return an error after a certain position. 
+    /// Ensure that error is propagated (AsyncBufRead case)
+    #[rstest]
+    #[case::during_size(4)]
+    #[case::before_payload(8)]
+    #[case::during_payload(8 + 4)]
+    #[case::before_padding(8 + 4)]
+    #[case::during_padding(8 + 9 + 2)]
+    #[tokio::test]
+    async fn propagate_error_from_reader_buffered(#[case] offset: usize) {
+        let payload = &hex!("FF0102030405060708");
+        let mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await[..offset])
+            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
+            .build();
+        let mut mock = BufReader::new(mock);
+
+        // Either length reading or data reading can fail, depending on which test case we're in.
+        let err: io::Error = async {
+            let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await?;
+            let mut buf = Vec::new();
+
+            tokio::io::copy_buf(&mut r, &mut buf).await?;
+
+            Ok(())
+        }
+        .await
+        .expect_err("must fail");
+
+        assert_eq!(
+            err.kind(),
+            std::io::ErrorKind::Other,
+            "error kind must match"
+        );
+
+        assert_eq!(
+            err.into_inner().unwrap().to_string(),
+            "foo",
+            "error payload must contain foo"
+        );
+    }
+
     /// If there's an error right after the padding, we don't propagate it, as
     /// we're done reading. We just return EOF.
     #[tokio::test]
@@ -419,6 +635,26 @@
         assert_eq!(buf.as_slice(), payload);
     }
 
+    /// If there's an error right after the padding, we don't propagate it, as
+    /// we're done reading. We just return EOF.
+    #[tokio::test]
+    async fn no_error_after_eof_buffered() {
+        let payload = &hex!("FF0102030405060708");
+        let mock = Builder::new()
+            .read(&produce_packet_bytes(payload).await)
+            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
+            .build();
+        let mut mock = BufReader::new(mock);
+
+        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
+        let mut buf = Vec::new();
+
+        tokio::io::copy_buf(&mut r, &mut buf)
+            .await
+            .expect("must succeed");
+        assert_eq!(buf.as_slice(), payload);
+    }
+
     /// Introduce various stalls in various places of the packet, to ensure we
     /// handle these cases properly, too.
     #[rstest]
diff --git a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
index 0b0c7b13554d..3a5bb75e7103 100644
--- a/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
+++ b/tvix/nix-compat/src/wire/bytes/reader/trailer.rs
@@ -9,11 +9,11 @@ use std::{
 
 use tokio::io::{self, AsyncRead, ReadBuf};
 
-/// Trailer represents up to 7 bytes of data read as part of the trailer block(s)
+/// Trailer represents up to 8 bytes of data read as part of the trailer block(s)
 #[derive(Debug)]
 pub(crate) struct Trailer {
     data_len: u8,
-    buf: [u8; 7],
+    buf: [u8; 8],
 }
 
 impl Deref for Trailer {
@@ -28,7 +28,7 @@ impl Deref for Trailer {
 pub(crate) trait Tag {
     /// The expected suffix
     ///
-    /// The first 7 bytes may be ignored, and it must be an 8-byte aligned size.
+    /// The first 8 bytes may be ignored, and it must be an 8-byte aligned size.
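+    ///
+    /// For the default [Pad] tag this is presumably plain zero padding; NAR
+    /// parsing uses the longer `PadPar` pattern (see the reader tests).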
const PATTERN: &'static [u8]; /// Suitably sized buffer for reading [Self::PATTERN] @@ -67,7 +67,7 @@ pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>( reader: R, data_len: u8, ) -> ReadTrailer<R, T> { - assert!(data_len < 8, "payload in trailer must be less than 8 bytes"); + assert!(data_len <= 8, "payload in trailer must be <= 8 bytes"); let buf = T::make_buf(); assert_eq!(buf.as_ref().len(), T::PATTERN.len()); @@ -108,8 +108,8 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> { } if this.filled as usize == T::PATTERN.len() { - let mut buf = [0; 7]; - buf.copy_from_slice(&this.buf.as_ref()[..7]); + let mut buf = [0; 8]; + buf.copy_from_slice(&this.buf.as_ref()[..8]); return Ok(Trailer { data_len: this.data_len, @@ -124,10 +124,9 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> { ready!(Pin::new(&mut this.reader).poll_read(cx, &mut buf))?; this.filled = { - let prev_filled = this.filled; let filled = buf.filled().len() as u8; - if filled == prev_filled { + if filled == this.filled { return Err(io::ErrorKind::UnexpectedEof.into()).into(); } diff --git a/tvix/nix-compat/src/wire/bytes/writer.rs b/tvix/nix-compat/src/wire/bytes/writer.rs index f5632771e961..8b9b59aa1b85 100644 --- a/tvix/nix-compat/src/wire/bytes/writer.rs +++ b/tvix/nix-compat/src/wire/bytes/writer.rs @@ -232,19 +232,18 @@ where #[cfg(test)] mod tests { + use std::sync::LazyLock; use std::time::Duration; use crate::wire::bytes::write_bytes; use hex_literal::hex; - use lazy_static::lazy_static; use tokio::io::AsyncWriteExt; use tokio_test::{assert_err, assert_ok, io::Builder}; use super::*; - lazy_static! { - pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024); - } + pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> = + LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024)); /// Helper function, calling the (simpler) write_bytes with the payload. /// We use this to create data we want to see on the wire. diff --git a/tvix/nix-compat/src/wire/de/bytes.rs b/tvix/nix-compat/src/wire/de/bytes.rs new file mode 100644 index 000000000000..4c64247f7051 --- /dev/null +++ b/tvix/nix-compat/src/wire/de/bytes.rs @@ -0,0 +1,70 @@ +use bytes::Bytes; + +use super::{Error, NixDeserialize, NixRead}; + +impl NixDeserialize for Bytes { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + reader.try_read_bytes().await + } +} + +impl NixDeserialize for String { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + if let Some(buf) = reader.try_read_bytes().await? 
{ + String::from_utf8(buf.to_vec()) + .map_err(R::Error::invalid_data) + .map(Some) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod test { + use std::io; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::wire::de::{NixRead, NixReader}; + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_read_string(#[case] expected: &str, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: String = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_read_string_invalid() { + let mock = Builder::new() + .read(&hex!("0300 0000 0000 0000 EDA0 8000 0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + assert_eq!( + io::ErrorKind::InvalidData, + reader.read_value::<String>().await.unwrap_err().kind() + ); + } +} diff --git a/tvix/nix-compat/src/wire/de/collections.rs b/tvix/nix-compat/src/wire/de/collections.rs new file mode 100644 index 000000000000..e1271635e4e6 --- /dev/null +++ b/tvix/nix-compat/src/wire/de/collections.rs @@ -0,0 +1,105 @@ +use std::{collections::BTreeMap, future::Future}; + +use super::{NixDeserialize, NixRead}; + +#[allow(clippy::manual_async_fn)] +impl<T> NixDeserialize for Vec<T> +where + T: NixDeserialize + Send, +{ + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + if let Some(len) = reader.try_read_value::<usize>().await? { + let mut ret = Vec::with_capacity(len); + for _ in 0..len { + ret.push(reader.read_value().await?); + } + Ok(Some(ret)) + } else { + Ok(None) + } + } + } +} + +#[allow(clippy::manual_async_fn)] +impl<K, V> NixDeserialize for BTreeMap<K, V> +where + K: NixDeserialize + Ord + Send, + V: NixDeserialize + Send, +{ + fn try_deserialize<R>( + reader: &mut R, + ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_ + where + R: ?Sized + NixRead + Send, + { + async move { + if let Some(len) = reader.try_read_value::<usize>().await? 
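+            // wire format: a u64 length prefix, then `len` (key, value) pairs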
{ + let mut ret = BTreeMap::new(); + for _ in 0..len { + let key = reader.read_value().await?; + let value = reader.read_value().await?; + ret.insert(key, value); + } + Ok(Some(ret)) + } else { + Ok(None) + } + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::wire::de::{NixDeserialize, NixRead, NixReader}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_read_small_vec(#[case] expected: Vec<usize>, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: Vec<usize> = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + fn empty_map() -> BTreeMap<usize, u64> { + BTreeMap::new() + } + macro_rules! map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_read_small_btree_map<E>(#[case] expected: E, #[case] data: &[u8]) + where + E: NixDeserialize + PartialEq + fmt::Debug, + { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: E = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } +} diff --git a/tvix/nix-compat/src/wire/de/int.rs b/tvix/nix-compat/src/wire/de/int.rs new file mode 100644 index 000000000000..d505de9b1b24 --- /dev/null +++ b/tvix/nix-compat/src/wire/de/int.rs @@ -0,0 +1,100 @@ +use super::{Error, NixDeserialize, NixRead}; + +impl NixDeserialize for u64 { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + reader.try_read_number().await + } +} + +impl NixDeserialize for usize { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + if let Some(value) = reader.try_read_number().await? 
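+        // the wire encodes this as a u64; on 16/32-bit targets it may not fit
+        // into a usize, in which case we surface an invalid_data error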
{ + value.try_into().map_err(R::Error::invalid_data).map(Some) + } else { + Ok(None) + } + } +} + +impl NixDeserialize for bool { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + Ok(reader.try_read_number().await?.map(|v| v != 0)) + } +} +impl NixDeserialize for i64 { + async fn try_deserialize<R>(reader: &mut R) -> Result<Option<Self>, R::Error> + where + R: ?Sized + NixRead + Send, + { + Ok(reader.try_read_number().await?.map(|v| v as i64)) + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use crate::wire::de::{NixRead, NixReader}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[case::other_true(true, &hex!("1234 5600 0000 0000"))] + #[case::max_true(true, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_read_bool(#[case] expected: bool, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: bool = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_read_u64(#[case] expected: u64, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: u64 = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] + #[tokio::test] + async fn test_read_usize(#[case] expected: usize, #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual: usize = reader.read_value().await.unwrap(); + assert_eq!(actual, expected); + } + + // FUTUREWORK: Test this on supported hardware + #[tokio::test] + #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))] + async fn test_read_usize_overflow() { + let mock = Builder::new().read(&u64::MAX.to_le_bytes()).build(); + let mut reader = NixReader::new(mock); + assert_eq!( + std::io::ErrorKind::InvalidData, + reader.read_value::<usize>().await.unwrap_err().kind() + ); + } +} diff --git a/tvix/nix-compat/src/wire/de/mock.rs b/tvix/nix-compat/src/wire/de/mock.rs new file mode 100644 index 000000000000..8a1fb817743c --- /dev/null +++ b/tvix/nix-compat/src/wire/de/mock.rs @@ -0,0 +1,261 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +use bytes::Bytes; +use thiserror::Error; + +use crate::wire::ProtocolVersion; + +use super::NixRead; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("invalid data '{0}'")] + InvalidData(String), + #[error("missing data '{0}'")] + MissingData(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong read: expected {0} got {1}")] + WrongRead(OperationType, OperationType), +} + +impl Error { + pub fn expected_read_number() -> Error { + Error::WrongRead(OperationType::ReadNumber, OperationType::ReadBytes) + } + + pub fn 
expected_read_bytes() -> Error {
+        Error::WrongRead(OperationType::ReadBytes, OperationType::ReadNumber)
+    }
+}
+
+impl super::Error for Error {
+    fn custom<T: fmt::Display>(msg: T) -> Self {
+        Self::Custom(msg.to_string())
+    }
+
+    fn io_error(err: std::io::Error) -> Self {
+        Self::IO(err.kind(), err.to_string())
+    }
+
+    fn invalid_data<T: fmt::Display>(msg: T) -> Self {
+        Self::InvalidData(msg.to_string())
+    }
+
+    fn missing_data<T: fmt::Display>(msg: T) -> Self {
+        Self::MissingData(msg.to_string())
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum OperationType {
+    ReadNumber,
+    ReadBytes,
+}
+
+impl fmt::Display for OperationType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::ReadNumber => write!(f, "read_number"),
+            Self::ReadBytes => write!(f, "read_bytes"),
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum Operation {
+    ReadNumber(Result<u64, Error>),
+    ReadBytes(Result<Bytes, Error>),
+}
+
+impl From<Operation> for OperationType {
+    fn from(value: Operation) -> Self {
+        match value {
+            Operation::ReadNumber(_) => OperationType::ReadNumber,
+            Operation::ReadBytes(_) => OperationType::ReadBytes,
+        }
+    }
+}
+
+pub struct Builder {
+    version: ProtocolVersion,
+    ops: VecDeque<Operation>,
+}
+
+impl Builder {
+    pub fn new() -> Builder {
+        Builder {
+            version: Default::default(),
+            ops: VecDeque::new(),
+        }
+    }
+
+    pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self {
+        self.version = version.into();
+        self
+    }
+
+    pub fn read_number(&mut self, value: u64) -> &mut Self {
+        self.ops.push_back(Operation::ReadNumber(Ok(value)));
+        self
+    }
+
+    pub fn read_number_error(&mut self, err: Error) -> &mut Self {
+        self.ops.push_back(Operation::ReadNumber(Err(err)));
+        self
+    }
+
+    pub fn read_bytes(&mut self, value: Bytes) -> &mut Self {
+        self.ops.push_back(Operation::ReadBytes(Ok(value)));
+        self
+    }
+
+    pub fn read_slice(&mut self, data: &[u8]) -> &mut Self {
+        let value = Bytes::copy_from_slice(data);
+        self.ops.push_back(Operation::ReadBytes(Ok(value)));
+        self
+    }
+
+    pub fn read_bytes_error(&mut self, err: Error) -> &mut Self {
+        self.ops.push_back(Operation::ReadBytes(Err(err)));
+        self
+    }
+
+    pub fn build(&mut self) -> Mock {
+        Mock {
+            version: self.version,
+            ops: self.ops.clone(),
+        }
+    }
+}
+
+impl Default for Builder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+pub struct Mock {
+    version: ProtocolVersion,
+    ops: VecDeque<Operation>,
+}
+
+impl NixRead for Mock {
+    type Error = Error;
+
+    fn version(&self) -> ProtocolVersion {
+        self.version
+    }
+
+    async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> {
+        match self.ops.pop_front() {
+            Some(Operation::ReadNumber(ret)) => ret.map(Some),
+            Some(Operation::ReadBytes(_)) => Err(Error::expected_read_bytes()),
+            None => Ok(None),
+        }
+    }
+
+    async fn try_read_bytes_limited(
+        &mut self,
+        _limit: std::ops::RangeInclusive<usize>,
+    ) -> Result<Option<Bytes>, Self::Error> {
+        match self.ops.pop_front() {
+            Some(Operation::ReadBytes(ret)) => ret.map(Some),
+            Some(Operation::ReadNumber(_)) => Err(Error::expected_read_number()),
+            None => Ok(None),
+        }
+    }
+}
+
+impl Drop for Mock {
+    fn drop(&mut self) {
+        // No need to panic again
+        if thread::panicking() {
+            return;
+        }
+        if let Some(op) = self.ops.front() {
+            panic!("reader dropped with {op:?} operation still unread")
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use bytes::Bytes;
+    use hex_literal::hex;
+
+    use crate::wire::de::NixRead;
+
+    use super::{Builder, Error};
+
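+    // A minimal usage sketch (added for illustration, not part of the change):
+    // queued operations are replayed strictly in FIFO order, and an exhausted
+    // mock yields `None` to signal a clean end-of-stream.
+    #[tokio::test]
+    async fn read_number_fifo_order() {
+        let mut mock = Builder::new().read_number(1).read_number(2).build();
+        assert_eq!(1, mock.read_number().await.unwrap());
+        assert_eq!(2, mock.read_number().await.unwrap());
+        assert_eq!(None, mock.try_read_number().await.unwrap());
+    }
+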
+    #[tokio::test]
+    async fn read_slice() {
+        let mut mock = Builder::new()
+            .read_number(10)
+            .read_slice(&[])
+            .read_slice(&hex!("0000 1234 5678 9ABC DEFF"))
+            .build();
+        assert_eq!(10, mock.read_number().await.unwrap());
+        assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]);
+        assert_eq!(
+            &hex!("0000 1234 5678 9ABC DEFF"),
+            &mock.read_bytes().await.unwrap()[..]
+        );
+        assert_eq!(None, mock.try_read_number().await.unwrap());
+        assert_eq!(None, mock.try_read_bytes().await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn read_bytes() {
+        let mut mock = Builder::new()
+            .read_number(10)
+            .read_bytes(Bytes::from_static(&[]))
+            .read_bytes(Bytes::from_static(&hex!("0000 1234 5678 9ABC DEFF")))
+            .build();
+        assert_eq!(10, mock.read_number().await.unwrap());
+        assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]);
+        assert_eq!(
+            &hex!("0000 1234 5678 9ABC DEFF"),
+            &mock.read_bytes().await.unwrap()[..]
+        );
+        assert_eq!(None, mock.try_read_number().await.unwrap());
+        assert_eq!(None, mock.try_read_bytes().await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn read_number() {
+        let mut mock = Builder::new().read_number(10).build();
+        assert_eq!(10, mock.read_number().await.unwrap());
+        assert_eq!(None, mock.try_read_number().await.unwrap());
+        assert_eq!(None, mock.try_read_bytes().await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn expect_number() {
+        let mut mock = Builder::new().read_number(10).build();
+        assert_eq!(
+            Error::expected_read_number(),
+            mock.read_bytes().await.unwrap_err()
+        );
+    }
+
+    #[tokio::test]
+    async fn expect_bytes() {
+        let mut mock = Builder::new().read_slice(&[]).build();
+        assert_eq!(
+            Error::expected_read_bytes(),
+            mock.read_number().await.unwrap_err()
+        );
+    }
+
+    #[test]
+    #[should_panic]
+    fn operations_left() {
+        let _ = Builder::new().read_number(10).build();
+    }
+}
diff --git a/tvix/nix-compat/src/wire/de/mod.rs b/tvix/nix-compat/src/wire/de/mod.rs
new file mode 100644
index 000000000000..f85ccd8fea0e
--- /dev/null
+++ b/tvix/nix-compat/src/wire/de/mod.rs
@@ -0,0 +1,225 @@
+use std::error::Error as StdError;
+use std::future::Future;
+use std::ops::RangeInclusive;
+use std::{fmt, io};
+
+use ::bytes::Bytes;
+
+use super::ProtocolVersion;
+
+mod bytes;
+mod collections;
+mod int;
+#[cfg(any(test, feature = "test"))]
+pub mod mock;
+mod reader;
+
+pub use reader::{NixReader, NixReaderBuilder};
+
+/// Like serde, the `Error` trait allows `NixRead` implementations to add
+/// custom error handling for `NixDeserialize`.
+pub trait Error: Sized + StdError {
+    /// A totally custom non-specific error.
+    fn custom<T: fmt::Display>(msg: T) -> Self;
+
+    /// Some kind of std::io::Error occurred.
+    fn io_error(err: std::io::Error) -> Self {
+        Self::custom(format_args!("There was an I/O error {}", err))
+    }
+
+    /// The data read from `NixRead` is invalid.
+    /// This could be that some bytes were supposed to be valid UTF-8 but weren't.
+    fn invalid_data<T: fmt::Display>(msg: T) -> Self {
+        Self::custom(msg)
+    }
+
+    /// Required data is missing. This mostly behaves like an EOF.
+    fn missing_data<T: fmt::Display>(msg: T) -> Self {
+        Self::custom(msg)
+    }
+}
+
+impl Error for io::Error {
+    fn custom<T: fmt::Display>(msg: T) -> Self {
+        io::Error::new(io::ErrorKind::Other, msg.to_string())
+    }
+
+    fn io_error(err: std::io::Error) -> Self {
+        err
+    }
+
+    fn invalid_data<T: fmt::Display>(msg: T) -> Self {
+        io::Error::new(io::ErrorKind::InvalidData, msg.to_string())
+    }
+
+    fn missing_data<T: fmt::Display>(msg: T) -> Self {
+        io::Error::new(io::ErrorKind::UnexpectedEof, msg.to_string())
+    }
+}
+
+/// A reader of data from the Nix daemon protocol.
+/// There are two basic types in the Nix daemon protocol:
+/// u64 and a bytes buffer. Everything else is more or less built on
+/// top of these two types.
+pub trait NixRead: Send {
+    type Error: Error + Send;
+
+    /// Some types are serialized differently depending on the version
+    /// of the protocol and so this can be used for implementing that.
+    fn version(&self) -> ProtocolVersion;
+
+    /// Read a single u64 from the protocol.
+    /// This returns an Option to support graceful shutdown.
+    fn try_read_number(
+        &mut self,
+    ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_;
+
+    /// Read bytes from the protocol.
+    /// A size limit on the returned bytes has to be specified.
+    /// This returns an Option to support graceful shutdown.
+    fn try_read_bytes_limited(
+        &mut self,
+        limit: RangeInclusive<usize>,
+    ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_;
+
+    /// Read bytes from the protocol without a limit.
+    /// The default implementation just calls `try_read_bytes_limited` with a
+    /// limit of `0..=usize::MAX` but other implementations are free to have a
+    /// reader-wide limit.
+    /// This returns an Option to support graceful shutdown.
+    fn try_read_bytes(
+        &mut self,
+    ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ {
+        self.try_read_bytes_limited(0..=usize::MAX)
+    }
+
+    /// Read a single u64 from the protocol.
+    /// This will return an error if the number could not be read.
+    fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ {
+        async move {
+            match self.try_read_number().await? {
+                Some(v) => Ok(v),
+                None => Err(Self::Error::missing_data("unexpected end-of-file")),
+            }
+        }
+    }
+
+    /// Read bytes from the protocol.
+    /// A size limit on the returned bytes has to be specified.
+    /// This will return an error if the bytes could not be read.
+    fn read_bytes_limited(
+        &mut self,
+        limit: RangeInclusive<usize>,
+    ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ {
+        async move {
+            match self.try_read_bytes_limited(limit).await? {
+                Some(v) => Ok(v),
+                None => Err(Self::Error::missing_data("unexpected end-of-file")),
+            }
+        }
+    }
+
+    /// Read bytes from the protocol.
+    /// The default implementation just calls `read_bytes_limited` with a
+    /// limit of `0..=usize::MAX` but other implementations are free to have a
+    /// reader-wide limit.
+    /// This will return an error if the bytes could not be read.
+    fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ {
+        self.read_bytes_limited(0..=usize::MAX)
+    }
+
+    /// Read a value from the protocol.
+    /// Uses `NixDeserialize::deserialize` to read a value.
+    fn read_value<V: NixDeserialize>(
+        &mut self,
+    ) -> impl Future<Output = Result<V, Self::Error>> + Send + '_ {
+        V::deserialize(self)
+    }
+
+    /// Read a value from the protocol.
+    /// Uses `NixDeserialize::try_deserialize` to read a value.
+    /// This returns an Option to support graceful shutdown.
+    fn try_read_value<V: NixDeserialize>(
+        &mut self,
+    ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ {
+        V::try_deserialize(self)
+    }
+}
+
+impl<T: ?Sized + NixRead> NixRead for &mut T {
+    type Error = T::Error;
+
+    fn version(&self) -> ProtocolVersion {
+        (**self).version()
+    }
+
+    fn try_read_number(
+        &mut self,
+    ) -> impl Future<Output = Result<Option<u64>, Self::Error>> + Send + '_ {
+        (**self).try_read_number()
+    }
+
+    fn try_read_bytes_limited(
+        &mut self,
+        limit: RangeInclusive<usize>,
+    ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ {
+        (**self).try_read_bytes_limited(limit)
+    }
+
+    fn try_read_bytes(
+        &mut self,
+    ) -> impl Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ {
+        (**self).try_read_bytes()
+    }
+
+    fn read_number(&mut self) -> impl Future<Output = Result<u64, Self::Error>> + Send + '_ {
+        (**self).read_number()
+    }
+
+    fn read_bytes_limited(
+        &mut self,
+        limit: RangeInclusive<usize>,
+    ) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ {
+        (**self).read_bytes_limited(limit)
+    }
+
+    fn read_bytes(&mut self) -> impl Future<Output = Result<Bytes, Self::Error>> + Send + '_ {
+        (**self).read_bytes()
+    }
+
+    fn try_read_value<V: NixDeserialize>(
+        &mut self,
+    ) -> impl Future<Output = Result<Option<V>, Self::Error>> + Send + '_ {
+        (**self).try_read_value()
+    }
+
+    fn read_value<V: NixDeserialize>(
+        &mut self,
+    ) -> impl Future<Output = Result<V, Self::Error>> + Send + '_ {
+        (**self).read_value()
+    }
+}
+
+/// A data structure that can be deserialized from the Nix daemon
+/// worker protocol.
+pub trait NixDeserialize: Sized {
+    /// Read a value from the reader.
+    /// This returns an Option to support graceful shutdown.
+    fn try_deserialize<R>(
+        reader: &mut R,
+    ) -> impl Future<Output = Result<Option<Self>, R::Error>> + Send + '_
+    where
+        R: ?Sized + NixRead + Send;
+
+    fn deserialize<R>(reader: &mut R) -> impl Future<Output = Result<Self, R::Error>> + Send + '_
+    where
+        R: ?Sized + NixRead + Send,
+    {
+        async move {
+            match Self::try_deserialize(reader).await?
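+            // A `None` here means the stream ended before the value started;
+            // since a value is required at this point, that is promoted to a
+            // `missing_data` error.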
{ + Some(v) => Ok(v), + None => Err(R::Error::missing_data("unexpected end-of-file")), + } + } + } +} diff --git a/tvix/nix-compat/src/wire/de/reader.rs b/tvix/nix-compat/src/wire/de/reader.rs new file mode 100644 index 000000000000..b7825f393c4e --- /dev/null +++ b/tvix/nix-compat/src/wire/de/reader.rs @@ -0,0 +1,526 @@ +use std::future::poll_fn; +use std::io::{self, Cursor}; +use std::ops::RangeInclusive; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf}; + +use crate::wire::{ProtocolVersion, EMPTY_BYTES}; + +use super::{Error, NixRead}; + +pub struct NixReaderBuilder { + buf: Option<BytesMut>, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, +} + +impl Default for NixReaderBuilder { + fn default() -> Self { + Self { + buf: Default::default(), + reserved_buf_size: 8192, + max_buf_size: 8192, + version: Default::default(), + } + } +} + +impl NixReaderBuilder { + pub fn set_buffer(mut self, buf: BytesMut) -> Self { + self.buf = Some(buf); + self + } + + pub fn set_reserved_buf_size(mut self, size: usize) -> Self { + self.reserved_buf_size = size; + self + } + + pub fn set_max_buf_size(mut self, size: usize) -> Self { + self.max_buf_size = size; + self + } + + pub fn set_version(mut self, version: ProtocolVersion) -> Self { + self.version = version; + self + } + + pub fn build<R>(self, reader: R) -> NixReader<R> { + let buf = self.buf.unwrap_or_else(|| BytesMut::with_capacity(0)); + NixReader { + buf, + inner: reader, + reserved_buf_size: self.reserved_buf_size, + max_buf_size: self.max_buf_size, + version: self.version, + } + } +} + +pin_project! { + pub struct NixReader<R> { + #[pin] + inner: R, + buf: BytesMut, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, + } +} + +impl NixReader<Cursor<Vec<u8>>> { + pub fn builder() -> NixReaderBuilder { + NixReaderBuilder::default() + } +} + +impl<R> NixReader<R> +where + R: AsyncReadExt, +{ + pub fn new(reader: R) -> NixReader<R> { + NixReader::builder().build(reader) + } + + pub fn buffer(&self) -> &[u8] { + &self.buf[..] + } + + #[cfg(test)] + pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut { + &mut self.buf + } + + /// Remaining capacity in internal buffer + pub fn remaining_mut(&self) -> usize { + self.buf.capacity() - self.buf.len() + } + + fn poll_force_fill_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<usize>> { + // Ensure that buffer has space for at least reserved_buf_size bytes + if self.remaining_mut() < self.reserved_buf_size { + let me = self.as_mut().project(); + me.buf.reserve(*me.reserved_buf_size); + } + let me = self.project(); + let n = { + let dst = me.buf.spare_capacity_mut(); + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(me.inner.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // SAFETY: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. 
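+        // `advance_mut` is unsafe because it trusts that those `n` bytes of
+        // spare capacity are now initialized; the `ReadBuf` above did exactly
+        // that.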
+ unsafe { + me.buf.advance_mut(n); + } + Poll::Ready(Ok(n)) + } +} + +impl<R> NixReader<R> +where + R: AsyncReadExt + Unpin, +{ + async fn force_fill(&mut self) -> io::Result<usize> { + let mut p = Pin::new(self); + let read = poll_fn(|cx| p.as_mut().poll_force_fill_buf(cx)).await?; + Ok(read) + } +} + +impl<R> NixRead for NixReader<R> +where + R: AsyncReadExt + Send + Unpin, +{ + type Error = io::Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> { + let mut buf = [0u8; 8]; + let read = self.read_buf(&mut &mut buf[..]).await?; + if read == 0 { + return Ok(None); + } + if read < 8 { + self.read_exact(&mut buf[read..]).await?; + } + let num = Buf::get_u64_le(&mut &buf[..]); + Ok(Some(num)) + } + + async fn try_read_bytes_limited( + &mut self, + limit: RangeInclusive<usize>, + ) -> Result<Option<Bytes>, Self::Error> { + assert!( + *limit.end() <= self.max_buf_size, + "The limit must be smaller than {}", + self.max_buf_size + ); + match self.try_read_number().await? { + Some(raw_len) => { + // Check that length is in range and convert to usize + let len = raw_len + .try_into() + .ok() + .filter(|v| limit.contains(v)) + .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?; + + // Calculate 64bit aligned length and convert to usize + let aligned: usize = raw_len + .checked_add(7) + .map(|v| v & !7) + .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))? + .try_into() + .map_err(Self::Error::invalid_data)?; + + // Ensure that there is enough space in buffer for contents + if self.buf.len() + self.remaining_mut() < aligned { + self.buf.reserve(aligned - self.buf.len()); + } + while self.buf.len() < aligned { + if self.force_fill().await? == 0 { + return Err(Self::Error::missing_data( + "unexpected end-of-file reading bytes", + )); + } + } + let mut contents = self.buf.split_to(aligned); + + let padding = aligned - len; + // Ensure padding is all zeros + if contents[len..] 
!= EMPTY_BYTES[..padding] { + return Err(Self::Error::invalid_data("non-zero padding")); + } + + contents.truncate(len); + Ok(Some(contents.freeze())) + } + None => Ok(None), + } + } + + fn try_read_bytes( + &mut self, + ) -> impl std::future::Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ { + self.try_read_bytes_limited(0..=self.max_buf_size) + } + + fn read_bytes( + &mut self, + ) -> impl std::future::Future<Output = Result<Bytes, Self::Error>> + Send + '_ { + self.read_bytes_limited(0..=self.max_buf_size) + } +} + +impl<R: AsyncRead> AsyncRead for NixReader<R> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[0..amt]); + self.consume(amt); + Poll::Ready(Ok(())) + } +} + +impl<R: AsyncRead> AsyncBufRead for NixReader<R> { + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + if self.as_ref().project_ref().buf.is_empty() { + ready!(self.as_mut().poll_force_fill_buf(cx))?; + } + let me = self.project(); + Poll::Ready(Ok(&me.buf[..])) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let me = self.project(); + me.buf.advance(amt) + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use hex_literal::hex; + use rstest::rstest; + use tokio_test::io::Builder; + + use super::*; + use crate::wire::de::NixRead; + + #[tokio::test] + async fn test_read_u64() { + let mock = Builder::new().read(&hex!("0100 0000 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!(""), reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!(""), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_rest() { + let mock = Builder::new() + .read(&hex!("0100 0000 0000 0000 0123 4567 89AB CDEF")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!("0123 4567 89AB CDEF"), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_partial() { + let mock = Builder::new() + .read(&hex!("0100 0000")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000 0123 4567 89AB CDEF")) + .wait(Duration::ZERO) + .read(&hex!("0100 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!(1, reader.read_number().await.unwrap()); + assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer()); + + let mut buf = Vec::new(); + reader.read_to_end(&mut buf).await.unwrap(); + assert_eq!(hex!("0123 4567 89AB CDEF 0100 0000"), &buf[..]); + } + + #[tokio::test] + async fn test_read_u64_eof() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.read_number().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_u64_none() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!(None, reader.try_read_number().await.unwrap()); + } + + #[tokio::test] + async fn test_try_read_u64_eof() { + let mock = Builder::new().read(&hex!("0100 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + 
reader.try_read_number().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_u64_eof2() { + let mock = Builder::new() + .read(&hex!("0100")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_number().await.unwrap_err().kind() + ); + } + + #[rstest] + #[case::empty(b"", &hex!("0000 0000 0000 0000"))] + #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_read_bytes(#[case] expected: &[u8], #[case] data: &[u8]) { + let mock = Builder::new().read(data).build(); + let mut reader = NixReader::new(mock); + let actual = reader.read_bytes().await.unwrap(); + assert_eq!(&actual[..], expected); + } + + #[tokio::test] + async fn test_read_bytes_empty() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.read_bytes().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_bytes_none() { + let mock = Builder::new().build(); + let mut reader = NixReader::new(mock); + + assert_eq!(None, reader.try_read_bytes().await.unwrap()); + } + + #[tokio::test] + async fn test_try_read_bytes_missing_data() { + let mock = Builder::new() + .read(&hex!("0500")) + .wait(Duration::ZERO) + .read(&hex!("0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_bytes().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_try_read_bytes_missing_padding() { + let mock = Builder::new() + .read(&hex!("0200 0000 0000 0000")) + .wait(Duration::ZERO) + .read(&hex!("1234")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::UnexpectedEof, + reader.try_read_bytes().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_read_bytes_bad_padding() { + let mock = Builder::new() + .read(&hex!("0200 0000 0000 0000")) + .wait(Duration::ZERO) + .read(&hex!("1234 0100 0000 0000")) + .build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::InvalidData, + reader.read_bytes().await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_read_bytes_limited_out_of_range() { + let mock = Builder::new().read(&hex!("FFFF 0000 0000 0000")).build(); + let mut reader = NixReader::new(mock); + + assert_eq!( + io::ErrorKind::InvalidData, + reader.read_bytes_limited(0..=50).await.unwrap_err().kind() + ); + } + + #[tokio::test] + async fn test_read_bytes_length_overflow() { + let mock = Builder::new().read(&hex!("F9FF FFFF FFFF FFFF")).build(); + let mut reader = NixReader::builder() + .set_max_buf_size(usize::MAX) + .build(mock); + + assert_eq!( + io::ErrorKind::InvalidData, + reader + .read_bytes_limited(0..=usize::MAX) + .await + .unwrap_err() + .kind() 
+        );
+    }
+
+    // FUTUREWORK: Test this on supported hardware
+    #[tokio::test]
+    #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
+    async fn test_bytes_length_conversion_overflow() {
+        let len = (usize::MAX as u64) + 1;
+        let mock = Builder::new().read(&len.to_le_bytes()).build();
+        let mut reader = NixReader::new(mock);
+        assert_eq!(
+            std::io::ErrorKind::InvalidData,
+            reader.read_value::<usize>().await.unwrap_err().kind()
+        );
+    }
+
+    // FUTUREWORK: Test this on supported hardware
+    #[tokio::test]
+    #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
+    async fn test_bytes_aligned_length_conversion_overflow() {
+        let len = (usize::MAX - 6) as u64;
+        let mock = Builder::new().read(&len.to_le_bytes()).build();
+        let mut reader = NixReader::new(mock);
+        assert_eq!(
+            std::io::ErrorKind::InvalidData,
+            reader.read_value::<usize>().await.unwrap_err().kind()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_buffer_resize() {
+        let mock = Builder::new()
+            .read(&hex!("0100"))
+            .read(&hex!("0000 0000 0000"))
+            .build();
+        let mut reader = NixReader::builder().set_reserved_buf_size(8).build(mock);
+        // buffer has no capacity initially
+        assert_eq!(0, reader.buffer_mut().capacity());
+
+        assert_eq!(2, reader.force_fill().await.unwrap());
+
+        // After first read buffer should have capacity we chose
+        assert_eq!(8, reader.buffer_mut().capacity());
+
+        // Because only 6 bytes of capacity remained in the buffer, which is
+        // enough for the last 6 bytes of input but less than the 8 bytes we
+        // always reserve, the buffer doubled its capacity
+        assert_eq!(6, reader.force_fill().await.unwrap());
+        assert_eq!(16, reader.buffer_mut().capacity());
+
+        assert_eq!(1, reader.read_number().await.unwrap());
+    }
+}
diff --git a/tvix/nix-compat/src/wire/mod.rs b/tvix/nix-compat/src/wire/mod.rs
index a197e3a1f451..c3e88dda05ec 100644
--- a/tvix/nix-compat/src/wire/mod.rs
+++ b/tvix/nix-compat/src/wire/mod.rs
@@ -3,3 +3,9 @@
 mod bytes;
 
 pub use bytes::*;
+
+mod protocol_version;
+pub use protocol_version::ProtocolVersion;
+
+pub mod de;
+pub mod ser;
diff --git a/tvix/nix-compat/src/nix_daemon/protocol_version.rs b/tvix/nix-compat/src/wire/protocol_version.rs
index 8fd2b085c962..19da28d484dd 100644
--- a/tvix/nix-compat/src/nix_daemon/protocol_version.rs
+++ b/tvix/nix-compat/src/wire/protocol_version.rs
@@ -1,3 +1,6 @@
+/// The latest version that is currently supported by nix-compat.
+static DEFAULT_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37);
+
 /// Protocol versions are represented as a u16.
 /// The upper 8 bits are the major version, the lower bits the minor.
/// This is not aware of any endianness, use [crate::wire::read_u64] to get an @@ -20,6 +23,12 @@ impl ProtocolVersion { } } +impl Default for ProtocolVersion { + fn default() -> Self { + DEFAULT_PROTOCOL_VERSION + } +} + impl PartialOrd for ProtocolVersion { fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { Some(self.cmp(other)) @@ -45,6 +54,13 @@ impl From<u16> for ProtocolVersion { } } +#[cfg(any(test, feature = "test"))] +impl From<(u8, u8)> for ProtocolVersion { + fn from((major, minor): (u8, u8)) -> Self { + Self::from_parts(major, minor) + } +} + impl TryFrom<u64> for ProtocolVersion { type Error = &'static str; diff --git a/tvix/nix-compat/src/wire/ser/bytes.rs b/tvix/nix-compat/src/wire/ser/bytes.rs new file mode 100644 index 000000000000..737edb059b5b --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/bytes.rs @@ -0,0 +1,98 @@ +use bytes::Bytes; + +use super::{NixSerialize, NixWrite}; + +impl NixSerialize for Bytes { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl<'a> NixSerialize for &'a [u8] { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl NixSerialize for String { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +impl NixSerialize for str { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +impl NixSerialize for &str { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_str(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", 
&hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_string(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value.to_string()).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/collections.rs b/tvix/nix-compat/src/wire/ser/collections.rs new file mode 100644 index 000000000000..478e1d04d809 --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/collections.rs @@ -0,0 +1,94 @@ +use std::collections::BTreeMap; +use std::future::Future; + +use super::{NixSerialize, NixWrite}; + +impl<T> NixSerialize for Vec<T> +where + T: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for value in self.iter() { + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +impl<K, V> NixSerialize for BTreeMap<K, V> +where + K: NixSerialize + Ord + Send + Sync, + V: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for (key, value) in self.iter() { + writer.write_value(key).await?; + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixSerialize, NixWrite, NixWriter}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_vec(#[case] value: Vec<usize>, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + fn empty_map() -> BTreeMap<usize, u64> { + BTreeMap::new() + } + macro_rules! 
map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_btree_map<E>(#[case] value: E, #[case] data: &[u8]) + where + E: NixSerialize + Send + PartialEq + fmt::Debug, + { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/wire/ser/display.rs b/tvix/nix-compat/src/wire/ser/display.rs new file mode 100644 index 000000000000..a3438d50d8ff --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/display.rs @@ -0,0 +1,8 @@ +use nix_compat_derive::nix_serialize_remote; + +use crate::nixhash; + +nix_serialize_remote!( + #[nix(display)] + nixhash::HashAlgo +); diff --git a/tvix/nix-compat/src/wire/ser/int.rs b/tvix/nix-compat/src/wire/ser/int.rs new file mode 100644 index 000000000000..e68179c71dc7 --- /dev/null +++ b/tvix/nix-compat/src/wire/ser/int.rs @@ -0,0 +1,108 @@ +#[cfg(feature = "nix-compat-derive")] +use nix_compat_derive::nix_serialize_remote; + +use super::{Error, NixSerialize, NixWrite}; + +impl NixSerialize for u64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self).await + } +} + +impl NixSerialize for usize { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + let v = (*self).try_into().map_err(W::Error::unsupported_data)?; + writer.write_number(v).await + } +} + +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u8 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u16 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u32 +); + +impl NixSerialize for bool { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + if *self { + writer.write_number(1).await + } else { + writer.write_number(0).await + } + } +} + +impl NixSerialize for i64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self as u64).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::wire::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[tokio::test] + async fn test_write_bool(#[case] value: bool, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_u64(#[case] value: u64, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + 
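+    // Sketch (added for illustration, not part of the original change):
+    // i64 values are written as their two's-complement u64 bit pattern,
+    // so -1 serializes to eight 0xFF bytes.
+    #[tokio::test]
+    async fn test_write_i64_minus_one_sketch() {
+        let mock = Builder::new().write(&hex!("FFFF FFFF FFFF FFFF")).build();
+        let mut writer = NixWriter::new(mock);
+        writer.write_value(&(-1i64)).await.unwrap();
+        writer.flush().await.unwrap();
+    }
+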
+    #[rstest]
+    #[case::zero(0, &hex!("0000 0000 0000 0000"))]
+    #[case::one(1, &hex!("0100 0000 0000 0000"))]
+    #[case::other(0x563412, &hex!("1234 5600 0000 0000"))]
+    #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())]
+    #[tokio::test]
+    async fn test_write_usize(#[case] value: usize, #[case] expected: &[u8]) {
+        let mock = Builder::new().write(expected).build();
+        let mut writer = NixWriter::new(mock);
+        writer.write_value(&value).await.unwrap();
+        writer.flush().await.unwrap();
+    }
+}
diff --git a/tvix/nix-compat/src/wire/ser/mock.rs b/tvix/nix-compat/src/wire/ser/mock.rs
new file mode 100644
index 000000000000..7104a94238ff
--- /dev/null
+++ b/tvix/nix-compat/src/wire/ser/mock.rs
@@ -0,0 +1,672 @@
+use std::collections::VecDeque;
+use std::fmt;
+use std::io;
+use std::thread;
+
+#[cfg(test)]
+use ::proptest::prelude::TestCaseError;
+use thiserror::Error;
+
+use crate::wire::ProtocolVersion;
+
+use super::NixWrite;
+
+#[derive(Debug, Error, PartialEq, Eq, Clone)]
+pub enum Error {
+    #[error("custom error '{0}'")]
+    Custom(String),
+    #[error("unsupported data error '{0}'")]
+    UnsupportedData(String),
+    #[error("Invalid enum: {0}")]
+    InvalidEnum(String),
+    #[error("IO error {0} '{1}'")]
+    IO(io::ErrorKind, String),
+    #[error("wrong write: expected {0} got {1}")]
+    WrongWrite(OperationType, OperationType),
+    #[error("unexpected write: got an extra {0}")]
+    ExtraWrite(OperationType),
+    #[error("got an unexpected number {0} in write_number")]
+    UnexpectedNumber(u64),
+    #[error("got an unexpected slice '{0:?}' in write_slice")]
+    UnexpectedSlice(Vec<u8>),
+    #[error("got an unexpected display '{0:?}' in write_display")]
+    UnexpectedDisplay(String),
+}
+
+impl Error {
+    pub fn unexpected_write_number(expected: OperationType) -> Error {
+        Error::WrongWrite(expected, OperationType::WriteNumber)
+    }
+
+    pub fn extra_write_number() -> Error {
+        Error::ExtraWrite(OperationType::WriteNumber)
+    }
+
+    pub fn unexpected_write_slice(expected: OperationType) -> Error {
+        Error::WrongWrite(expected, OperationType::WriteSlice)
+    }
+
+    pub fn unexpected_write_display(expected: OperationType) -> Error {
+        Error::WrongWrite(expected, OperationType::WriteDisplay)
+    }
+}
+
+impl super::Error for Error {
+    fn custom<T: fmt::Display>(msg: T) -> Self {
+        Self::Custom(msg.to_string())
+    }
+
+    fn io_error(err: std::io::Error) -> Self {
+        Self::IO(err.kind(), err.to_string())
+    }
+
+    fn unsupported_data<T: fmt::Display>(msg: T) -> Self {
+        Self::UnsupportedData(msg.to_string())
+    }
+
+    fn invalid_enum<T: fmt::Display>(msg: T) -> Self {
+        Self::InvalidEnum(msg.to_string())
+    }
+}
+
+#[allow(clippy::enum_variant_names)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum OperationType {
+    WriteNumber,
+    WriteSlice,
+    WriteDisplay,
+}
+
+impl fmt::Display for OperationType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::WriteNumber => write!(f, "write_number"),
+            Self::WriteSlice => write!(f, "write_slice"),
+            Self::WriteDisplay => write!(f, "write_display"),
+        }
+    }
+}
+
+#[allow(clippy::enum_variant_names)]
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum Operation {
+    WriteNumber(u64, Result<(), Error>),
+    WriteSlice(Vec<u8>, Result<(), Error>),
+    WriteDisplay(String, Result<(), Error>),
+}
+
+impl From<Operation> for OperationType {
+    fn from(value: Operation) -> Self {
+        match value {
+            Operation::WriteNumber(_, _) => OperationType::WriteNumber,
+            Operation::WriteSlice(_, _) => OperationType::WriteSlice,
+            Operation::WriteDisplay(_, _) => OperationType::WriteDisplay,
} + } +} + +pub struct Builder { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + version: Default::default(), + ops: VecDeque::new(), + } + } + + pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { + self.version = version.into(); + self + } + + pub fn write_number(&mut self, value: u64) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Ok(()))); + self + } + + pub fn write_number_error(&mut self, value: u64, err: Error) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Err(err))); + self + } + + pub fn write_slice(&mut self, value: &[u8]) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Ok(()))); + self + } + + pub fn write_slice_error(&mut self, value: &[u8], err: Error) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Err(err))); + self + } + + pub fn write_display<D>(&mut self, value: D) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Ok(()))); + self + } + + pub fn write_display_error<D>(&mut self, value: D, err: Error) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Err(err))); + self + } + + #[cfg(test)] + fn write_operation_type(&mut self, op: OperationType) -> &mut Self { + match op { + OperationType::WriteNumber => self.write_number(10), + OperationType::WriteSlice => self.write_slice(b"testing"), + OperationType::WriteDisplay => self.write_display("testing"), + } + } + + #[cfg(test)] + fn write_operation(&mut self, op: &Operation) -> &mut Self { + match op { + Operation::WriteNumber(value, Ok(_)) => self.write_number(*value), + Operation::WriteNumber(value, Err(Error::UnexpectedNumber(_))) => { + self.write_number(*value) + } + Operation::WriteNumber(_, Err(Error::ExtraWrite(OperationType::WriteNumber))) => self, + Operation::WriteNumber(_, Err(Error::WrongWrite(op, OperationType::WriteNumber))) => { + self.write_operation_type(*op) + } + Operation::WriteNumber(value, Err(Error::Custom(msg))) => { + self.write_number_error(*value, Error::Custom(msg.clone())) + } + Operation::WriteNumber(value, Err(Error::IO(kind, msg))) => { + self.write_number_error(*value, Error::IO(*kind, msg.clone())) + } + Operation::WriteSlice(value, Ok(_)) => self.write_slice(value), + Operation::WriteSlice(value, Err(Error::UnexpectedSlice(_))) => self.write_slice(value), + Operation::WriteSlice(_, Err(Error::ExtraWrite(OperationType::WriteSlice))) => self, + Operation::WriteSlice(_, Err(Error::WrongWrite(op, OperationType::WriteSlice))) => { + self.write_operation_type(*op) + } + Operation::WriteSlice(value, Err(Error::Custom(msg))) => { + self.write_slice_error(value, Error::Custom(msg.clone())) + } + Operation::WriteSlice(value, Err(Error::IO(kind, msg))) => { + self.write_slice_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Ok(_)) => self.write_display(value), + Operation::WriteDisplay(value, Err(Error::Custom(msg))) => { + self.write_display_error(value, Error::Custom(msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::IO(kind, msg))) => { + self.write_display_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::UnexpectedDisplay(_))) => { + self.write_display(value) + } + Operation::WriteDisplay(_, Err(Error::ExtraWrite(OperationType::WriteDisplay))) => self, + 
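+            // A queued WrongWrite error is provoked by enqueueing an operation
+            // of the other type, so replay then sees a mismatched write.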
Operation::WriteDisplay(_, Err(Error::WrongWrite(op, OperationType::WriteDisplay))) => { + self.write_operation_type(*op) + } + s => panic!("Invalid operation {:?}", s), + } + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct Mock { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Mock { + #[cfg(test)] + #[allow(dead_code)] + async fn assert_operation(&mut self, op: Operation) { + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + assert_eq!(self.write_display(value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + assert_eq!(self.write_display(value).await, res); + } + } + } + + #[cfg(test)] + async fn prop_assert_operation(&mut self, op: Operation) -> Result<(), TestCaseError> { + use ::proptest::prop_assert_eq; + + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + prop_assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + prop_assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + prop_assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + prop_assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + prop_assert_eq!(self.write_display(&value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + prop_assert_eq!(self.write_display(&value).await, res); + } + } + Ok(()) + } +} + +impl NixWrite for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteNumber(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedNumber(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_number(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteNumber)), + } + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteSlice(expected, ret)) => { + if buf != expected { + return Err(Error::UnexpectedSlice(buf.to_vec())); + } + ret + } + Some(op) => Err(Error::unexpected_write_slice(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteSlice)), + } + } + + async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + let value = msg.to_string(); + match self.ops.pop_front() { + Some(Operation::WriteDisplay(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedDisplay(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_display(op.into())), + _ => 
Err(Error::ExtraWrite(OperationType::WriteDisplay)),
+        }
+    }
+}
+
+impl Drop for Mock {
+    fn drop(&mut self) {
+        // No need to panic again
+        if thread::panicking() {
+            return;
+        }
+        if let Some(op) = self.ops.front() {
+            panic!("writer dropped with {op:?} operation still unwritten")
+        }
+    }
+}
+
+#[cfg(test)]
+mod proptest {
+    use std::io;
+
+    use proptest::{
+        prelude::{any, Arbitrary, BoxedStrategy, Just, Strategy},
+        prop_oneof,
+    };
+
+    use super::{Error, Operation, OperationType};
+
+    pub fn arb_write_number_operation() -> impl Strategy<Value = Operation> {
+        (
+            any::<u64>(),
+            prop_oneof![
+                Just(Ok(())),
+                any::<u64>().prop_map(|v| Err(Error::UnexpectedNumber(v))),
+                Just(Err(Error::WrongWrite(
+                    OperationType::WriteSlice,
+                    OperationType::WriteNumber
+                ))),
+                Just(Err(Error::WrongWrite(
+                    OperationType::WriteDisplay,
+                    OperationType::WriteNumber
+                ))),
+                any::<String>().prop_map(|s| Err(Error::Custom(s))),
+                (any::<io::ErrorKind>(), any::<String>())
+                    .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))),
+            ],
+        )
+            .prop_filter("same number", |(v, res)| match res {
+                Err(Error::UnexpectedNumber(exp_v)) => v != exp_v,
+                _ => true,
+            })
+            .prop_map(|(v, res)| Operation::WriteNumber(v, res))
+    }
+
+    pub fn arb_write_slice_operation() -> impl Strategy<Value = Operation> {
+        (
+            any::<Vec<u8>>(),
+            prop_oneof![
+                Just(Ok(())),
+                any::<Vec<u8>>().prop_map(|v| Err(Error::UnexpectedSlice(v))),
+                Just(Err(Error::WrongWrite(
+                    OperationType::WriteNumber,
+                    OperationType::WriteSlice
+                ))),
+                Just(Err(Error::WrongWrite(
+                    OperationType::WriteDisplay,
+                    OperationType::WriteSlice
+                ))),
+                any::<String>().prop_map(|s| Err(Error::Custom(s))),
+                (any::<io::ErrorKind>(), any::<String>())
+                    .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))),
+            ],
+        )
+            .prop_filter("same slice", |(v, res)| match res {
+                Err(Error::UnexpectedSlice(exp_v)) => v != exp_v,
+                _ => true,
+            })
+            .prop_map(|(v, res)| Operation::WriteSlice(v, res))
+    }
+
+    #[allow(dead_code)]
+    pub fn arb_extra_write() -> impl Strategy<Value = Operation> {
+        prop_oneof![
+            any::<u64>().prop_map(|msg| {
+                Operation::WriteNumber(msg, Err(Error::ExtraWrite(OperationType::WriteNumber)))
+            }),
+            any::<Vec<u8>>().prop_map(|msg| {
+                Operation::WriteSlice(msg, Err(Error::ExtraWrite(OperationType::WriteSlice)))
+            }),
+            any::<String>().prop_map(|msg| {
+                Operation::WriteDisplay(msg, Err(Error::ExtraWrite(OperationType::WriteDisplay)))
+            }),
+        ]
+    }
+
+    pub fn arb_write_display_operation() -> impl Strategy<Value = Operation> {
+        (
+            any::<String>(),
+            prop_oneof![
+                Just(Ok(())),
+                any::<String>().prop_map(|v| Err(Error::UnexpectedDisplay(v))),
+                Just(Err(Error::WrongWrite(
+                    OperationType::WriteNumber,
+                    OperationType::WriteDisplay
+                ))),
+                Just(Err(Error::WrongWrite(
+                    OperationType::WriteSlice,
+                    OperationType::WriteDisplay
+                ))),
+                any::<String>().prop_map(|s| Err(Error::Custom(s))),
+                (any::<io::ErrorKind>(), any::<String>())
+                    .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))),
+            ],
+        )
+            .prop_filter("same string", |(v, res)| match res {
+                Err(Error::UnexpectedDisplay(exp_v)) => v != exp_v,
+                _ => true,
+            })
+            .prop_map(|(v, res)| Operation::WriteDisplay(v, res))
+    }
+
+    pub fn arb_operation() -> impl Strategy<Value = Operation> {
+        prop_oneof![
+            arb_write_number_operation(),
+            arb_write_slice_operation(),
+            arb_write_display_operation(),
+        ]
+    }
+
+    impl Arbitrary for Operation {
+        type Parameters = ();
+        type Strategy = BoxedStrategy<Operation>;
+
+        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+            arb_operation().boxed()
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use hex_literal::hex;
+    use proptest::prelude::any;
+    use proptest::prelude::TestCaseError;
+    use proptest::proptest;
+
+    use crate::wire::ser::mock::proptest::arb_extra_write;
+    use crate::wire::ser::mock::Operation;
+    use crate::wire::ser::mock::OperationType;
+    use crate::wire::ser::Error as _;
+    use crate::wire::ser::NixWrite;
+
+    use super::{Builder, Error};
+
+    #[tokio::test]
+    async fn write_number() {
+        let mut mock = Builder::new().write_number(10).build();
+        mock.write_number(10).await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn write_number_error() {
+        let mut mock = Builder::new()
+            .write_number_error(10, Error::custom("bad number"))
+            .build();
+        assert_eq!(
+            Err(Error::custom("bad number")),
+            mock.write_number(10).await
+        );
+    }
+
+    #[tokio::test]
+    async fn write_number_unexpected() {
+        let mut mock = Builder::new().write_slice(b"").build();
+        assert_eq!(
+            Err(Error::unexpected_write_number(OperationType::WriteSlice)),
+            mock.write_number(11).await
+        );
+    }
+
+    #[tokio::test]
+    async fn write_number_unexpected_number() {
+        let mut mock = Builder::new().write_number(10).build();
+        assert_eq!(
+            Err(Error::UnexpectedNumber(11)),
+            mock.write_number(11).await
+        );
+    }
+
+    #[tokio::test]
+    async fn extra_write_number() {
+        let mut mock = Builder::new().build();
+        assert_eq!(
+            Err(Error::ExtraWrite(OperationType::WriteNumber)),
+            mock.write_number(11).await
+        );
+    }
+
+    #[tokio::test]
+    async fn write_slice() {
+        let mut mock = Builder::new()
+            .write_slice(&[])
+            .write_slice(&hex!("0000 1234 5678 9ABC DEFF"))
+            .build();
+        mock.write_slice(&[]).await.expect("write_slice empty");
+        mock.write_slice(&hex!("0000 1234 5678 9ABC DEFF"))
+            .await
+            .expect("write_slice");
+    }
+
+    #[tokio::test]
+    async fn write_slice_error() {
+        let mut mock = Builder::new()
+            .write_slice_error(&[], Error::custom("bad slice"))
+            .build();
+        assert_eq!(Err(Error::custom("bad slice")), mock.write_slice(&[]).await);
+    }
+
+    #[tokio::test]
+    async fn write_slice_unexpected() {
+        let mut mock = Builder::new().write_number(10).build();
+        assert_eq!(
+            Err(Error::unexpected_write_slice(OperationType::WriteNumber)),
+            mock.write_slice(b"").await
+        );
+    }
+
+    #[tokio::test]
+    async fn write_slice_unexpected_slice() {
+        let mut mock = Builder::new().write_slice(b"").build();
+        assert_eq!(
+            Err(Error::UnexpectedSlice(b"bad slice".to_vec())),
+            mock.write_slice(b"bad slice").await
+        );
+    }
+
+    #[tokio::test]
+    async fn extra_write_slice() {
+        let mut mock = Builder::new().build();
+        assert_eq!(
+            Err(Error::ExtraWrite(OperationType::WriteSlice)),
+            mock.write_slice(b"extra slice").await
+        );
+    }
+
+    #[tokio::test]
+    async fn write_display() {
+        let mut mock = Builder::new().write_display("testing").build();
+        mock.write_display("testing").await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn write_display_error() {
+        let mut mock = Builder::new()
+            .write_display_error("testing", Error::custom("bad number"))
+            .build();
+        assert_eq!(
+            Err(Error::custom("bad number")),
+            mock.write_display("testing").await
+        );
+    }
+
+    #[tokio::test]
+    async fn write_display_unexpected() {
+        let mut mock = Builder::new().write_number(10).build();
+        assert_eq!(
+            Err(Error::unexpected_write_display(OperationType::WriteNumber)),
+            mock.write_display("").await
+        );
+    }
+
+    #[tokio::test]
+    async fn write_display_unexpected_display() {
+        let mut mock = Builder::new().write_display("").build();
+        assert_eq!(
+            Err(Error::UnexpectedDisplay("bad display".to_string())),
+            mock.write_display("bad display").await
+        );
+    }
+
+    #[tokio::test]
+    async fn extra_write_display() {
+        let mut mock = Builder::new().build();
+        assert_eq!(
+            Err(Error::ExtraWrite(OperationType::WriteDisplay)),
+            mock.write_display("extra slice").await
+        );
+    }
+
+    #[test]
+    #[should_panic]
+    fn operations_left() {
+        let _ = Builder::new().write_number(10).build();
+    }
+
+    #[test]
+    fn proptest_mock() {
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .unwrap();
+        proptest!(|(
+            operations in any::<Vec<Operation>>(),
+            extra_operations in proptest::collection::vec(arb_extra_write(), 0..3)
+        )| {
+            rt.block_on(async {
+                let mut builder = Builder::new();
+                for op in operations.iter() {
+                    builder.write_operation(op);
+                }
+                for op in extra_operations.iter() {
+                    builder.write_operation(op);
+                }
+                let mut mock = builder.build();
+                for op in operations {
+                    mock.prop_assert_operation(op).await?;
+                }
+                for op in extra_operations {
+                    mock.prop_assert_operation(op).await?;
+                }
+                Ok(()) as Result<(), TestCaseError>
+            })?;
+        });
+    }
+}
diff --git a/tvix/nix-compat/src/wire/ser/mod.rs b/tvix/nix-compat/src/wire/ser/mod.rs
new file mode 100644
index 000000000000..ef3c6e2e372f
--- /dev/null
+++ b/tvix/nix-compat/src/wire/ser/mod.rs
@@ -0,0 +1,134 @@
+use std::error::Error as StdError;
+use std::future::Future;
+use std::{fmt, io};
+
+use super::ProtocolVersion;
+
+mod bytes;
+mod collections;
+#[cfg(feature = "nix-compat-derive")]
+mod display;
+mod int;
+#[cfg(any(test, feature = "test"))]
+pub mod mock;
+mod writer;
+
+pub use writer::{NixWriter, NixWriterBuilder};
+
+pub trait Error: Sized + StdError {
+    fn custom<T: fmt::Display>(msg: T) -> Self;
+
+    fn io_error(err: std::io::Error) -> Self {
+        Self::custom(format_args!("There was an I/O error: {}", err))
+    }
+
+    fn unsupported_data<T: fmt::Display>(msg: T) -> Self {
+        Self::custom(msg)
+    }
+
+    fn invalid_enum<T: fmt::Display>(msg: T) -> Self {
+        Self::custom(msg)
+    }
+}
+
+impl Error for io::Error {
+    fn custom<T: fmt::Display>(msg: T) -> Self {
+        io::Error::new(io::ErrorKind::Other, msg.to_string())
+    }
+
+    fn io_error(err: std::io::Error) -> Self {
+        err
+    }
+
+    fn unsupported_data<T: fmt::Display>(msg: T) -> Self {
+        io::Error::new(io::ErrorKind::InvalidData, msg.to_string())
+    }
+}
+
+pub trait NixWrite: Send {
+    type Error: Error;
+
+    /// Some types are serialized differently depending on the version
+    /// of the protocol, so implementations can use this to decide which
+    /// format to emit.
+    fn version(&self) -> ProtocolVersion;
+
+    /// Write a single u64 to the protocol.
+    fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send;
+
+    /// Write a slice of bytes to the protocol.
+    fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send;
+
+    /// Write a value that implements `std::fmt::Display` to the protocol.
+    /// The protocol uses many small string formats; instead of allocating
+    /// a `String` each time one is written, an implementation of `NixWrite`
+    /// can use `Display` to render these formats into a reusable buffer.
+    fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        D: fmt::Display + Send,
+        Self: Sized,
+    {
+        async move {
+            let s = msg.to_string();
+            self.write_slice(s.as_bytes()).await
+        }
+    }
+
+    /// Write a value to the protocol.
+    /// Uses `NixSerialize::serialize` to write the value.
+    fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        V: NixSerialize + Send + ?Sized,
+        Self: Sized,
+    {
+        value.serialize(self)
+    }
+}
+
+impl<T: NixWrite> NixWrite for &mut T {
+    type Error = T::Error;
+
+    fn version(&self) -> ProtocolVersion {
+        (**self).version()
+    }
+
+    fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send {
+        (**self).write_number(value)
+    }
+
+    fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send {
+        (**self).write_slice(buf)
+    }
+
+    fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        D: fmt::Display + Send,
+        Self: Sized,
+    {
+        (**self).write_display(msg)
+    }
+
+    fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send
+    where
+        V: NixSerialize + Send + ?Sized,
+        Self: Sized,
+    {
+        (**self).write_value(value)
+    }
+}
+
+pub trait NixSerialize {
+    /// Write a value to the writer.
+    fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send
+    where
+        W: NixWrite;
+}
+
+// Noop
+impl NixSerialize for () {
+    async fn serialize<W>(&self, _writer: &mut W) -> Result<(), W::Error>
+    where
+        W: NixWrite,
+    {
+        Ok(())
+    }
+}
diff --git a/tvix/nix-compat/src/wire/ser/writer.rs b/tvix/nix-compat/src/wire/ser/writer.rs
new file mode 100644
index 000000000000..da1c2b18c5e2
--- /dev/null
+++ b/tvix/nix-compat/src/wire/ser/writer.rs
@@ -0,0 +1,306 @@
+use std::fmt::{self, Write as _};
+use std::future::poll_fn;
+use std::io::{self, Cursor};
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+
+use bytes::{Buf, BufMut, BytesMut};
+use pin_project_lite::pin_project;
+use tokio::io::{AsyncWrite, AsyncWriteExt};
+
+use crate::wire::{padding_len, ProtocolVersion, EMPTY_BYTES};
+
+use super::{Error, NixWrite};
+
+pub struct NixWriterBuilder {
+    buf: Option<BytesMut>,
+    reserved_buf_size: usize,
+    max_buf_size: usize,
+    version: ProtocolVersion,
+}
+
+impl Default for NixWriterBuilder {
+    fn default() -> Self {
+        Self {
+            buf: Default::default(),
+            reserved_buf_size: 8192,
+            max_buf_size: 8192,
+            version: Default::default(),
+        }
+    }
+}
+
+impl NixWriterBuilder {
+    pub fn set_buffer(mut self, buf: BytesMut) -> Self {
+        self.buf = Some(buf);
+        self
+    }
+
+    pub fn set_reserved_buf_size(mut self, size: usize) -> Self {
+        self.reserved_buf_size = size;
+        self
+    }
+
+    pub fn set_max_buf_size(mut self, size: usize) -> Self {
+        self.max_buf_size = size;
+        self
+    }
+
+    pub fn set_version(mut self, version: ProtocolVersion) -> Self {
+        self.version = version;
+        self
+    }
+
+    pub fn build<W>(self, writer: W) -> NixWriter<W> {
+        let buf = self
+            .buf
+            .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size));
+        NixWriter {
+            buf,
+            inner: writer,
+            reserved_buf_size: self.reserved_buf_size,
+            max_buf_size: self.max_buf_size,
+            version: self.version,
+        }
+    }
+}
+
+pin_project! {
+    pub struct NixWriter<W> {
+        #[pin]
+        inner: W,
+        buf: BytesMut,
+        reserved_buf_size: usize,
+        max_buf_size: usize,
+        version: ProtocolVersion,
+    }
+}
+
+impl NixWriter<Cursor<Vec<u8>>> {
+    pub fn builder() -> NixWriterBuilder {
+        NixWriterBuilder::default()
+    }
+}
+
+impl<W> NixWriter<W>
+where
+    W: AsyncWriteExt,
+{
+    pub fn new(writer: W) -> NixWriter<W> {
+        NixWriter::builder().build(writer)
+    }
+
+    pub fn buffer(&self) -> &[u8] {
+        &self.buf[..]
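+        // (Editor's note) These are the bytes buffered but not yet flushed
+        // to the inner writer; the tests below use this to observe flushing.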
+    }
+
+    pub fn set_version(&mut self, version: ProtocolVersion) {
+        self.version = version;
+    }
+
+    /// Remaining capacity in the internal buffer.
+    pub fn remaining_mut(&self) -> usize {
+        self.buf.capacity() - self.buf.len()
+    }
+
+    fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        let mut this = self.project();
+        while !this.buf.is_empty() {
+            let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?;
+            if n == 0 {
+                return Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::WriteZero,
+                    "failed to write the buffer",
+                )));
+            }
+            this.buf.advance(n);
+        }
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl<W> NixWriter<W>
+where
+    W: AsyncWriteExt + Unpin,
+{
+    async fn flush_buf(&mut self) -> Result<(), io::Error> {
+        let mut s = Pin::new(self);
+        poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await
+    }
+}
+
+impl<W> AsyncWrite for NixWriter<W>
+where
+    W: AsyncWrite,
+{
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, io::Error>> {
+        // Flush the internal buffer if the incoming data does not fit into
+        // the remaining capacity.
+        if self.remaining_mut() < buf.len() {
+            ready!(self.as_mut().poll_flush_buf(cx))?;
+        }
+        let this = self.project();
+        if buf.len() > this.buf.capacity() {
+            this.inner.poll_write(cx, buf)
+        } else {
+            this.buf.put_slice(buf);
+            Poll::Ready(Ok(buf.len()))
+        }
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        ready!(self.as_mut().poll_flush_buf(cx))?;
+        self.project().inner.poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        ready!(self.as_mut().poll_flush_buf(cx))?;
+        self.project().inner.poll_shutdown(cx)
+    }
+}
+
+impl<W> NixWrite for NixWriter<W>
+where
+    W: AsyncWrite + Send + Unpin,
+{
+    type Error = io::Error;
+
+    fn version(&self) -> ProtocolVersion {
+        self.version
+    }
+
+    async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> {
+        let mut buf = [0u8; 8];
+        BufMut::put_u64_le(&mut &mut buf[..], value);
+        self.write_all(&buf).await
+    }
+
+    async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
+        let padding = padding_len(buf.len() as u64) as usize;
+        self.write_value(&buf.len()).await?;
+        self.write_all(buf).await?;
+        if padding > 0 {
+            self.write_all(&EMPTY_BYTES[..padding]).await
+        } else {
+            Ok(())
+        }
+    }
+
+    async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error>
+    where
+        D: fmt::Display + Send,
+        Self: Sized,
+    {
+        // Ensure that the buffer has space for at least reserved_buf_size bytes
+        if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() {
+            self.flush_buf().await?;
+        }
+        let offset = self.buf.len();
+        self.buf.put_u64_le(0);
+        if let Err(err) = write!(self.buf, "{}", msg) {
+            self.buf.truncate(offset);
+            return Err(Self::Error::unsupported_data(err));
+        }
+        let len = self.buf.len() - offset - 8;
+        BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64);
+        let padding = padding_len(len as u64) as usize;
+        self.write_all(&EMPTY_BYTES[..padding]).await
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::time::Duration;
+
+    use hex_literal::hex;
+    use rstest::rstest;
+    use tokio::io::AsyncWriteExt as _;
+    use tokio_test::io::Builder;
+
+    use crate::wire::ser::NixWrite;
+
+    use super::NixWriter;
+
+    #[rstest]
+    #[case(1, &hex!("0100 0000 0000 0000"))]
+    #[case::evil(666, &hex!("9A02 0000 0000 0000"))]
+    #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))]
+    #[tokio::test]
+    async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) {
+        let mock = Builder::new().write(buf).build();
+        let mut writer = NixWriter::new(mock);
+
+        writer.write_number(number).await.unwrap();
+        assert_eq!(writer.buffer(), buf);
+        writer.flush().await.unwrap();
+        assert_eq!(writer.buffer(), b"");
+    }
+
+    #[rstest]
+    #[case::empty(b"", &hex!("0000 0000 0000 0000"))]
+    #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
+    #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
+    #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
+    #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
+    #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
+    #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
+    #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
+    #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
+    #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
+    #[tokio::test]
+    async fn test_write_slice(
+        #[case] value: &[u8],
+        #[case] buf: &[u8],
+        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
+        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize,
+    ) {
+        let mut builder = Builder::new();
+        for chunk in buf.chunks(chunks_size) {
+            builder.write(chunk);
+            builder.wait(Duration::ZERO);
+        }
+        let mock = builder.build();
+        let mut writer = NixWriter::builder().set_max_buf_size(buf_size).build(mock);
+
+        writer.write_slice(value).await.unwrap();
+        writer.flush().await.unwrap();
+        assert_eq!(writer.buffer(), b"");
+    }
+
+    #[rstest]
+    #[case::empty("", &hex!("0000 0000 0000 0000"))]
+    #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
+    #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
+    #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
+    #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
+    #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
+    #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
+    #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
+    #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
+    #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
+    #[tokio::test]
+    async fn test_write_display(
+        #[case] value: &str,
+        #[case] buf: &[u8],
+        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
+    ) {
+        let mut builder = Builder::new();
+        for chunk in buf.chunks(chunks_size) {
+            builder.write(chunk);
+            builder.wait(Duration::ZERO);
+        }
+        let mock = builder.build();
+        let mut writer = NixWriter::builder().build(mock);
+
+        writer.write_display(value).await.unwrap();
+        assert_eq!(writer.buffer(), buf);
+        writer.flush().await.unwrap();
+        assert_eq!(writer.buffer(), b"");
+    }
+}
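Taken together, the serializer can also be exercised end to end without a mock. A minimal sketch (editor's example, not part of the patch; it assumes the module is exported as `nix_compat::wire::ser` and that the `int` module provides `NixSerialize` for `usize`, as the length prefix in `write_slice` suggests):

use std::io::Cursor;

use tokio::io::AsyncWriteExt as _;

use nix_compat::wire::ser::{NixWrite, NixWriter};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Tokio implements AsyncWrite for Cursor<Vec<u8>>, so this writes to memory.
    let mut writer = NixWriter::new(Cursor::new(Vec::new()));

    writer.write_number(1).await?; // 8 bytes, little endian
    writer.write_slice(b"tea").await?; // 8-byte length + 3 data bytes + 5 padding bytes

    // Writes may still sit in the internal buffer; flush drains it to the cursor.
    writer.flush().await?;
    assert_eq!(writer.buffer(), b"");

    Ok(())
}

The byte counts mirror the `tea` test vector above: 8 bytes of length prefix, 3 bytes of payload, and 5 zero bytes of padding to reach the next 8-byte boundary.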