diff options
Diffstat (limited to 'tvix')
20 files changed, 2339 insertions, 15 deletions
diff --git a/tvix/Cargo.lock b/tvix/Cargo.lock index 2984479c18ef..73badc2af496 100644 --- a/tvix/Cargo.lock +++ b/tvix/Cargo.lock @@ -2302,6 +2302,7 @@ dependencies = [ "num-traits", "pin-project-lite", "pretty_assertions", + "proptest", "rstest", "serde", "serde_json", @@ -2322,6 +2323,7 @@ dependencies = [ "nix-compat", "pretty_assertions", "proc-macro2", + "proptest", "quote", "rstest", "syn 2.0.79", diff --git a/tvix/Cargo.nix b/tvix/Cargo.nix index 749f39fb93cc..a09cbc016a23 100644 --- a/tvix/Cargo.nix +++ b/tvix/Cargo.nix @@ -7299,6 +7299,12 @@ rec { packageId = "pretty_assertions"; } { + name = "proptest"; + packageId = "proptest"; + usesDefaultFeatures = false; + features = [ "std" "alloc" "tempfile" ]; + } + { name = "rstest"; packageId = "rstest"; } @@ -7370,6 +7376,12 @@ rec { packageId = "pretty_assertions"; } { + name = "proptest"; + packageId = "proptest"; + usesDefaultFeatures = false; + features = [ "std" "alloc" "tempfile" ]; + } + { name = "rstest"; packageId = "rstest"; } diff --git a/tvix/nix-compat-derive-tests/tests/write_derive.rs b/tvix/nix-compat-derive-tests/tests/write_derive.rs new file mode 100644 index 000000000000..1ed5dbf7e735 --- /dev/null +++ b/tvix/nix-compat-derive-tests/tests/write_derive.rs @@ -0,0 +1,370 @@ +use std::fmt; + +use nix_compat::nix_daemon::ser::{ + mock::{Builder, Error}, + NixWrite as _, +}; +use nix_compat_derive::NixSerialize; + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct UnitTest; + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct EmptyTupleTest(); + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct StructTest { + first: u64, + second: String, +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct TupleTest(u64, String); + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct StructVersionTest { + test: u64, + #[nix(version = "20..")] + hello: String, +} + +fn default_test() -> StructVersionTest { + StructVersionTest { + test: 89, + hello: String::from("klomp"), + } +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct TupleVersionTest(u64, #[nix(version = "25..")] String); + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct TupleVersionDefaultTest(u64, #[nix(version = "..25")] StructVersionTest); + +#[tokio::test] +async fn write_unit() { + let mut mock = Builder::new().build(); + mock.write_value(&UnitTest).await.unwrap(); +} + +#[tokio::test] +async fn write_empty_tuple() { + let mut mock = Builder::new().build(); + mock.write_value(&EmptyTupleTest()).await.unwrap(); +} + +#[tokio::test] +async fn write_struct() { + let mut mock = Builder::new() + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&StructTest { + first: 89, + second: String::from("klomp"), + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_tuple() { + let mut mock = Builder::new() + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&TupleTest(89, String::from("klomp"))) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_struct_version() { + let mut mock = Builder::new() + .version((1, 20)) + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&default_test()).await.unwrap(); +} + +#[tokio::test] +async fn write_struct_without_version() { + let mut mock = Builder::new().version((1, 19)).write_number(89).build(); + mock.write_value(&StructVersionTest { + test: 89, + hello: String::new(), + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_tuple_version() { + let mut mock = Builder::new() + .version((1, 26)) + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&TupleVersionTest(89, "klomp".into())) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_tuple_without_version() { + let mut mock = Builder::new().version((1, 19)).write_number(89).build(); + mock.write_value(&TupleVersionTest(89, String::new())) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_1() { + let mut mock = Builder::new() + .version((1, 19)) + .write_number(999) + .write_number(666) + .build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 666, + hello: String::new(), + }, + )) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_2() { + let mut mock = Builder::new() + .version((1, 20)) + .write_number(999) + .write_number(666) + .write_slice(b"The quick brown \xF0\x9F\xA6\x8A jumps over 13 lazy \xF0\x9F\x90\xB6.") + .build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 666, + hello: String::from("The quick brown 🦊 jumps over 13 lazy 🐶."), + }, + )) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_3() { + let mut mock = Builder::new().version((1, 25)).write_number(999).build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 89, + hello: String::from("klomp"), + }, + )) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_4() { + let mut mock = Builder::new().version((1, 26)).write_number(999).build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 89, + hello: String::from("klomp"), + }, + )) + .await + .unwrap(); +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +#[nix(display)] +struct TestFromStr; + +impl fmt::Display for TestFromStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "test") + } +} + +#[tokio::test] +async fn write_display() { + let mut mock = Builder::new().write_display("test").build(); + mock.write_value(&TestFromStr).await.unwrap(); +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +#[nix(display = "TestFromStr2::display")] +struct TestFromStr2; +struct TestFromStrDisplay; + +impl fmt::Display for TestFromStrDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "test") + } +} +impl TestFromStr2 { + fn display(&self) -> TestFromStrDisplay { + TestFromStrDisplay + } +} + +#[tokio::test] +async fn write_display_path() { + let mut mock = Builder::new().write_display("test").build(); + mock.write_value(&TestFromStr2).await.unwrap(); +} + +#[derive(Clone, Debug, PartialEq, Eq, NixSerialize)] +#[nix(try_into = "u64")] +struct TestTryFromU64(u64); + +impl TryFrom<TestTryFromU64> for u64 { + type Error = u64; + + fn try_from(value: TestTryFromU64) -> Result<Self, Self::Error> { + if value.0 != 42 { + Ok(value.0) + } else { + Err(value.0) + } + } +} + +#[tokio::test] +async fn write_try_into_u64() { + let mut mock = Builder::new().write_number(666).build(); + mock.write_value(&TestTryFromU64(666)).await.unwrap(); +} + +#[tokio::test] +async fn write_try_into_u64_invalid_data() { + let mut mock = Builder::new().build(); + let err = mock.write_value(&TestTryFromU64(42)).await.unwrap_err(); + assert_eq!(Error::UnsupportedData("42".into()), err); +} + +#[derive(Clone, Debug, PartialEq, Eq, NixSerialize)] +#[nix(into = "u64")] +struct TestFromU64; + +impl From<TestFromU64> for u64 { + fn from(_value: TestFromU64) -> u64 { + 42 + } +} + +#[tokio::test] +async fn write_into_u64() { + let mut mock = Builder::new().write_number(42).build(); + mock.write_value(&TestFromU64).await.unwrap(); +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +enum TestEnum { + #[nix(version = "..=19")] + Pre20(TestFromU64, #[nix(version = "10..")] u64), + #[nix(version = "20..=29")] + Post20(StructVersionTest), + #[nix(version = "30..=39")] + Post30, + #[nix(version = "40..")] + Post40 { + msg: String, + #[nix(version = "45..")] + level: u64, + }, +} + +#[tokio::test] +async fn write_enum_9() { + let mut mock = Builder::new().version((1, 9)).write_number(42).build(); + mock.write_value(&TestEnum::Pre20(TestFromU64, 666)) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_19() { + let mut mock = Builder::new() + .version((1, 19)) + .write_number(42) + .write_number(666) + .build(); + mock.write_value(&TestEnum::Pre20(TestFromU64, 666)) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_20() { + let mut mock = Builder::new() + .version((1, 20)) + .write_number(666) + .write_slice(b"klomp") + .build(); + mock.write_value(&TestEnum::Post20(StructVersionTest { + test: 666, + hello: "klomp".into(), + })) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_30() { + let mut mock = Builder::new().version((1, 30)).build(); + mock.write_value(&TestEnum::Post30).await.unwrap(); +} + +#[tokio::test] +async fn write_enum_40() { + let mut mock = Builder::new() + .version((1, 40)) + .write_slice(b"hello world") + .build(); + mock.write_value(&TestEnum::Post40 { + msg: "hello world".into(), + level: 9001, + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_45() { + let mut mock = Builder::new() + .version((1, 45)) + .write_slice(b"hello world") + .write_number(9001) + .build(); + mock.write_value(&TestEnum::Post40 { + msg: "hello world".into(), + level: 9001, + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_wrong_enum() { + let mut mock = Builder::new().version((1, 30)).build(); + let err = mock + .write_value(&TestEnum::Post40 { + msg: "hello world".into(), + level: 9001, + }) + .await + .unwrap_err(); + assert_eq!( + err, + Error::InvalidEnum("Post40 is not valid for version 1.30".into()) + ) +} diff --git a/tvix/nix-compat-derive/Cargo.toml b/tvix/nix-compat-derive/Cargo.toml index da6d6744e650..bc656b42842c 100644 --- a/tvix/nix-compat-derive/Cargo.toml +++ b/tvix/nix-compat-derive/Cargo.toml @@ -14,6 +14,7 @@ syn = { version = "2.0.76", features = ["full", "extra-traits"] } [dev-dependencies] hex-literal = { workspace = true } pretty_assertions = { workspace = true } +proptest = { workspace = true, features = ["std", "alloc", "tempfile"] } rstest = { workspace = true } tokio-test = { workspace = true } tokio = { workspace = true, features = ["io-util", "macros"] } diff --git a/tvix/nix-compat-derive/src/de.rs b/tvix/nix-compat-derive/src/de.rs index 2214254e2b32..e678b50b0533 100644 --- a/tvix/nix-compat-derive/src/de.rs +++ b/tvix/nix-compat-derive/src/de.rs @@ -34,6 +34,11 @@ pub fn expand_nix_deserialize_remote( ) -> syn::Result<TokenStream> { let cx = Context::new(); let remote = Remote::from_ast(&cx, crate_path, input); + if let Some(attrs) = remote.as_ref().map(|r| &r.attrs) { + if attrs.from_str.is_none() && attrs.type_from.is_none() && attrs.type_try_from.is_none() { + cx.error_spanned(input, "Missing from_str, from or try_from attribute"); + } + } cx.check()?; let remote = remote.unwrap(); diff --git a/tvix/nix-compat-derive/src/internal/attrs.rs b/tvix/nix-compat-derive/src/internal/attrs.rs index 9ed84aaf8745..d0fa3b008e22 100644 --- a/tvix/nix-compat-derive/src/internal/attrs.rs +++ b/tvix/nix-compat-derive/src/internal/attrs.rs @@ -3,7 +3,9 @@ use syn::meta::ParseNestedMeta; use syn::parse::Parse; use syn::{parse_quote, Attribute, Expr, ExprLit, ExprPath, Lit, Token}; -use super::symbol::{Symbol, CRATE, DEFAULT, FROM, FROM_STR, NIX, TRY_FROM, VERSION}; +use super::symbol::{ + Symbol, CRATE, DEFAULT, DISPLAY, FROM, FROM_STR, INTO, NIX, TRY_FROM, TRY_INTO, VERSION, +}; use super::Context; #[derive(Debug, PartialEq, Eq)] @@ -104,6 +106,9 @@ pub struct Container { pub from_str: Option<syn::Path>, pub type_from: Option<syn::Type>, pub type_try_from: Option<syn::Type>, + pub type_into: Option<syn::Type>, + pub type_try_into: Option<syn::Type>, + pub display: Default, pub crate_path: Option<syn::Path>, } @@ -113,6 +118,9 @@ impl Container { let mut type_try_from = None; let mut crate_path = None; let mut from_str = None; + let mut type_into = None; + let mut type_try_into = None; + let mut display = Default::None; for attr in attrs { if attr.path() != NIX { @@ -125,6 +133,18 @@ impl Container { type_try_from = parse_lit(ctx, &meta, TRY_FROM)?; } else if meta.path == FROM_STR { from_str = Some(meta.path); + } else if meta.path == INTO { + type_into = parse_lit(ctx, &meta, INTO)?; + } else if meta.path == TRY_INTO { + type_try_into = parse_lit(ctx, &meta, TRY_INTO)?; + } else if meta.path == DISPLAY { + if meta.input.peek(Token![=]) { + if let Some(path) = parse_lit(ctx, &meta, DISPLAY)? { + display = Default::Path(path); + } + } else { + display = Default::Default(meta.path); + } } else if meta.path == CRATE { crate_path = parse_lit(ctx, &meta, CRATE)?; } else { @@ -144,6 +164,9 @@ impl Container { from_str, type_from, type_try_from, + type_into, + type_try_into, + display, crate_path, } } @@ -342,6 +365,46 @@ mod test { } #[test] + fn parse_container_from_str() { + let attrs: Vec<Attribute> = vec![parse_quote!(#[nix(from_str)])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: Some(parse_quote!(from_str)), + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_from() { + let attrs: Vec<Attribute> = vec![parse_quote!(#[nix(from="u64")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: Some(parse_quote!(u64)), + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + + #[test] fn parse_container_try_from() { let attrs: Vec<Attribute> = vec![parse_quote!(#[nix(try_from="u64")])]; let ctx = Context::new(); @@ -353,6 +416,89 @@ mod test { from_str: None, type_from: None, type_try_from: Some(parse_quote!(u64)), + type_into: None, + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_into() { + let attrs: Vec<Attribute> = vec![parse_quote!(#[nix(into="u64")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: Some(parse_quote!(u64)), + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_try_into() { + let attrs: Vec<Attribute> = vec![parse_quote!(#[nix(try_into="u64")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: Some(parse_quote!(u64)), + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_display() { + let attrs: Vec<Attribute> = vec![parse_quote!(#[nix(display)])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::Default(parse_quote!(display)), + crate_path: None, + } + ); + } + + #[test] + fn parse_container_display_path() { + let attrs: Vec<Attribute> = vec![parse_quote!(#[nix(display="Path::display")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::Path(parse_quote!(Path::display)), crate_path: None, } ); diff --git a/tvix/nix-compat-derive/src/internal/mod.rs b/tvix/nix-compat-derive/src/internal/mod.rs index 07ef43b6e0bb..aa42d904718d 100644 --- a/tvix/nix-compat-derive/src/internal/mod.rs +++ b/tvix/nix-compat-derive/src/internal/mod.rs @@ -154,10 +154,6 @@ impl<'a> Remote<'a> { input: &'a inputs::RemoteInput, ) -> Option<Remote<'a>> { let attrs = attrs::Container::from_ast(ctx, &input.attrs); - if attrs.from_str.is_none() && attrs.type_from.is_none() && attrs.type_try_from.is_none() { - ctx.error_spanned(input, "Missing from_str, from or try_from attribute"); - return None; - } Some(Remote { ty: &input.ident, attrs, diff --git a/tvix/nix-compat-derive/src/internal/symbol.rs b/tvix/nix-compat-derive/src/internal/symbol.rs index ed3fe304eb5d..2bbdc069aa0f 100644 --- a/tvix/nix-compat-derive/src/internal/symbol.rs +++ b/tvix/nix-compat-derive/src/internal/symbol.rs @@ -11,6 +11,9 @@ pub const DEFAULT: Symbol = Symbol("default"); pub const FROM: Symbol = Symbol("from"); pub const TRY_FROM: Symbol = Symbol("try_from"); pub const FROM_STR: Symbol = Symbol("from_str"); +pub const INTO: Symbol = Symbol("into"); +pub const TRY_INTO: Symbol = Symbol("try_into"); +pub const DISPLAY: Symbol = Symbol("display"); pub const CRATE: Symbol = Symbol("crate"); impl PartialEq<Symbol> for Path { diff --git a/tvix/nix-compat-derive/src/lib.rs b/tvix/nix-compat-derive/src/lib.rs index 89735cadf315..394473b1cbf8 100644 --- a/tvix/nix-compat-derive/src/lib.rs +++ b/tvix/nix-compat-derive/src/lib.rs @@ -6,7 +6,11 @@ //! 1. [`#[nix(from_str)]`](#nixfrom_str) //! 2. [`#[nix(from = "FromType")]`](#nixfrom--fromtype) //! 3. [`#[nix(try_from = "FromType")]`](#nixtry_from--fromtype) -//! 4. [`#[nix(crate = "...")]`](#nixcrate--) +//! 4. [`#[nix(into = "IntoType")]`](#nixinto--intotype) +//! 5. [`#[nix(try_into = "IntoType")]`](#nixtry_into--intotype) +//! 6. [`#[nix(display)]`](#nixdisplay) +//! 7. [`#[nix(display = "path")]`](#nixdisplay--path) +//! 8. [`#[nix(crate = "...")]`](#nixcrate--) //! 2. [Variant attributes](#variant-attributes) //! 1. [`#[nix(version = "range")]`](#nixversion--range) //! 3. [Field attributes](#field-attributes) @@ -17,20 +21,21 @@ //! ## Overview //! //! This crate contains derive macros and function-like macros for implementing -//! `NixDeserialize` with less boilerplate. +//! `NixDeserialize` and `NixSerialize` with less boilerplate. //! //! ### Examples +//! //! ```rust -//! # use nix_compat_derive::NixDeserialize; +//! # use nix_compat_derive::{NixDeserialize, NixSerialize}; //! # -//! #[derive(NixDeserialize)] +//! #[derive(NixDeserialize, NixSerialize)] //! struct Unnamed(u64, String); //! ``` //! //! ```rust -//! # use nix_compat_derive::NixDeserialize; +//! # use nix_compat_derive::{NixDeserialize, NixSerialize}; //! # -//! #[derive(NixDeserialize)] +//! #[derive(NixDeserialize, NixSerialize)] //! struct Fields { //! number: u64, //! message: String, @@ -38,9 +43,9 @@ //! ``` //! //! ```rust -//! # use nix_compat_derive::NixDeserialize; +//! # use nix_compat_derive::{NixDeserialize, NixSerialize}; //! # -//! #[derive(NixDeserialize)] +//! #[derive(NixDeserialize, NixSerialize)] //! struct Ignored; //! ``` //! @@ -64,7 +69,7 @@ //! #[derive(NixDeserialize)] //! #[nix(crate="nix_compat")] // <-- This is also a container attribute //! enum E { -//! #[nix(version="..=9")] // <-- This is a variant attribute +//! #[nix(version="..10")] // <-- This is a variant attribute //! A(u64), //! #[nix(version="10..")] // <-- This is also a variant attribute //! B(String), @@ -156,6 +161,114 @@ //! } //! ``` //! +//! ##### `#[nix(into = "IntoType")]` +//! +//! When `into` is specified the fields are all ignored and instead the +//! container type is converted to `IntoType` using `Into::into` and +//! `IntoType` is then serialized. Before converting `Clone::clone` is +//! called. +//! +//! This means that the container must implement `Into<IntoType>` and `Clone` +//! and `IntoType` must implement `NixSerialize`. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # +//! #[derive(Clone, NixSerialize)] +//! #[nix(into="usize")] +//! struct MyValue(usize); +//! impl From<MyValue> for usize { +//! fn from(val: MyValue) -> Self { +//! val.0 +//! } +//! } +//! ``` +//! +//! ##### `#[nix(try_into = "IntoType")]` +//! +//! When `try_into` is specified the fields are all ignored and instead the +//! container type is converted to `IntoType` using `TryInto::try_into` and +//! `IntoType` is then serialized. Before converting `Clone::clone` is +//! called. +//! +//! This means that the container must implement `TryInto<IntoType>` and +//! `Clone` and `IntoType` must implement `NixSerialize`. +//! The error returned from `try_into` also needs to implement `Display`. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # +//! #[derive(Clone, NixSerialize)] +//! #[nix(try_into="usize")] +//! struct WrongAnswer(usize); +//! impl TryFrom<WrongAnswer> for usize { +//! type Error = String; +//! fn try_from(val: WrongAnswer) -> Result<Self, Self::Error> { +//! if val.0 != 42 { +//! Ok(val.0) +//! } else { +//! Err("Got the answer to life the universe and everything".to_string()) +//! } +//! } +//! } +//! ``` +//! +//! ##### `#[nix(display)]` +//! +//! When `display` is specified the fields are all ignored and instead the +//! container must implement `Display` and `NixWrite::write_display` is used to +//! write the container. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # use std::fmt::{Display, Result, Formatter}; +//! # +//! #[derive(NixSerialize)] +//! #[nix(display)] +//! struct WrongAnswer(usize); +//! impl Display for WrongAnswer { +//! fn fmt(&self, f: &mut Formatter<'_>) -> Result { +//! write!(f, "Wrong Answer = {}", self.0) +//! } +//! } +//! ``` +//! +//! ##### `#[nix(display = "path")]` +//! +//! When `display` is specified the fields are all ignored and instead the +//! container the specified path must point to a function that is callable as +//! `fn(&T) -> impl Display`. The result from this call is then written with +//! `NixWrite::write_display`. +//! For example `default = "my_value"` would call `my_value(&self)` and `display = +//! "AType::empty"` would call `AType::empty(&self)`. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # use std::fmt::{Display, Result, Formatter}; +//! # +//! #[derive(NixSerialize)] +//! #[nix(display = "format_it")] +//! struct WrongAnswer(usize); +//! struct WrongDisplay<'a>(&'a WrongAnswer); +//! impl<'a> Display for WrongDisplay<'a> { +//! fn fmt(&self, f: &mut Formatter<'_>) -> Result { +//! write!(f, "Wrong Answer = {}", self.0.0) +//! } +//! } +//! +//! fn format_it(value: &WrongAnswer) -> impl Display + '_ { +//! WrongDisplay(value) +//! } +//! ``` +//! //! ##### `#[nix(crate = "...")]` //! //! Specify the path to the `nix-compat` crate instance to use when referring @@ -175,6 +288,7 @@ //! //! ```rust //! # use nix_compat_derive::NixDeserialize; +//! # //! #[derive(NixDeserialize)] //! enum Testing { //! #[nix(version="..=18")] @@ -260,6 +374,7 @@ use syn::{parse_quote, DeriveInput}; mod de; mod internal; +mod ser; #[proc_macro_derive(NixDeserialize, attributes(nix))] pub fn derive_nix_deserialize(item: TokenStream) -> TokenStream { @@ -270,6 +385,15 @@ pub fn derive_nix_deserialize(item: TokenStream) -> TokenStream { .into() } +#[proc_macro_derive(NixSerialize, attributes(nix))] +pub fn derive_nix_serialize(item: TokenStream) -> TokenStream { + let mut input = syn::parse_macro_input!(item as DeriveInput); + let crate_path: syn::Path = parse_quote!(::nix_compat); + ser::expand_nix_serialize(crate_path, &mut input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + /// Macro to implement `NixDeserialize` on a type. /// Sometimes you can't use the deriver to implement `NixDeserialize` /// (like when dealing with types in Rust standard library) but don't want @@ -301,3 +425,36 @@ pub fn nix_deserialize_remote(item: TokenStream) -> TokenStream { .unwrap_or_else(syn::Error::into_compile_error) .into() } + +/// Macro to implement `NixSerialize` on a type. +/// Sometimes you can't use the deriver to implement `NixSerialize` +/// (like when dealing with types in Rust standard library) but don't want +/// to implement it yourself. So this macro can be used for those situations +/// where you would derive using `#[nix(display)]`, `#[nix(display = "path")]`, +/// `#[nix(store_dir_display)]`, `#[nix(into = "IntoType")]` or +/// `#[nix(try_into = "IntoType")]` if you could. +/// +/// #### Example +/// +/// ```rust +/// # use nix_compat_derive::nix_serialize_remote; +/// # +/// #[derive(Clone)] +/// struct MyU64(u64); +/// +/// impl From<MyU64> for u64 { +/// fn from(value: MyU64) -> Self { +/// value.0 +/// } +/// } +/// +/// nix_serialize_remote!(#[nix(into="u64")] MyU64); +/// ``` +#[proc_macro] +pub fn nix_serialize_remote(item: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(item as RemoteInput); + let crate_path = parse_quote!(::nix_compat); + ser::expand_nix_serialize_remote(crate_path, &input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/tvix/nix-compat-derive/src/ser.rs b/tvix/nix-compat-derive/src/ser.rs new file mode 100644 index 000000000000..47ddfa39366d --- /dev/null +++ b/tvix/nix-compat-derive/src/ser.rs @@ -0,0 +1,227 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; +use syn::{DeriveInput, Generics, Path, Type}; + +use crate::internal::attrs::Default; +use crate::internal::inputs::RemoteInput; +use crate::internal::{attrs, Container, Context, Data, Field, Remote, Style, Variant}; + +pub fn expand_nix_serialize(crate_path: Path, input: &mut DeriveInput) -> syn::Result<TokenStream> { + let cx = Context::new(); + let cont = Container::from_ast(&cx, crate_path, input); + cx.check()?; + let cont = cont.unwrap(); + + let ty = cont.ident_type(); + let body = nix_serialize_body(&cont); + let crate_path = cont.crate_path(); + + Ok(nix_serialize_impl( + crate_path, + &ty, + &cont.original.generics, + body, + )) +} + +pub fn expand_nix_serialize_remote( + crate_path: Path, + input: &RemoteInput, +) -> syn::Result<TokenStream> { + let cx = Context::new(); + let remote = Remote::from_ast(&cx, crate_path, input); + if let Some(attrs) = remote.as_ref().map(|r| &r.attrs) { + if attrs.display.is_none() && attrs.type_into.is_none() && attrs.type_try_into.is_none() { + cx.error_spanned(input, "Missing into, try_into or display attribute"); + } + } + cx.check()?; + let remote = remote.unwrap(); + + let crate_path = remote.crate_path(); + let body = nix_serialize_body_into(crate_path, &remote.attrs).expect("From tokenstream"); + let generics = Generics::default(); + Ok(nix_serialize_impl(crate_path, remote.ty, &generics, body)) +} + +fn nix_serialize_impl( + crate_path: &Path, + ty: &Type, + generics: &Generics, + body: TokenStream, +) -> TokenStream { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + quote! { + #[automatically_derived] + impl #impl_generics #crate_path::nix_daemon::ser::NixSerialize for #ty #ty_generics + #where_clause + { + async fn serialize<W>(&self, writer: &mut W) -> std::result::Result<(), W::Error> + where W: #crate_path::nix_daemon::ser::NixWrite + { + use #crate_path::nix_daemon::ser::Error as _; + #body + } + } + } +} + +fn nix_serialize_body_into( + crate_path: &syn::Path, + attrs: &attrs::Container, +) -> Option<TokenStream> { + if let Default::Default(span) = &attrs.display { + Some(nix_serialize_display(span.span())) + } else if let Default::Path(path) = &attrs.display { + Some(nix_serialize_display_path(path)) + } else if let Some(type_into) = attrs.type_into.as_ref() { + Some(nix_serialize_into(type_into)) + } else { + attrs + .type_try_into + .as_ref() + .map(|type_try_into| nix_serialize_try_into(crate_path, type_try_into)) + } +} + +fn nix_serialize_body(cont: &Container) -> TokenStream { + if let Some(tokens) = nix_serialize_body_into(cont.crate_path(), &cont.attrs) { + tokens + } else { + match &cont.data { + Data::Struct(_style, fields) => nix_serialize_struct(fields), + Data::Enum(variants) => nix_serialize_enum(variants), + } + } +} + +fn nix_serialize_struct(fields: &[Field<'_>]) -> TokenStream { + let write_fields = fields.iter().map(|f| { + let field = &f.member; + let ty = f.ty; + let write_value = quote_spanned! { + ty.span()=> writer.write_value(&self.#field).await? + }; + if let Some(version) = f.attrs.version.as_ref() { + quote! { + if (#version).contains(&writer.version().minor()) { + #write_value; + } + } + } else { + quote! { + #write_value; + } + } + }); + + quote! { + #(#write_fields)* + Ok(()) + } +} + +fn nix_serialize_variant(variant: &Variant<'_>) -> TokenStream { + let ident = variant.ident; + let write_fields = variant.fields.iter().map(|f| { + let field = f.var_ident(); + let ty = f.ty; + let write_value = quote_spanned! { + ty.span()=> writer.write_value(#field).await? + }; + if let Some(version) = f.attrs.version.as_ref() { + quote! { + if (#version).contains(&writer.version().minor()) { + #write_value; + } + } + } else { + quote! { + #write_value; + } + } + }); + let field_names = variant.fields.iter().map(|f| f.var_ident()); + let destructure = match variant.style { + Style::Struct => { + quote! { + Self::#ident { #(#field_names),* } + } + } + Style::Tuple => { + quote! { + Self::#ident(#(#field_names),*) + } + } + Style::Unit => quote!(Self::#ident), + }; + let ignore = match variant.style { + Style::Struct => { + quote! { + Self::#ident { .. } + } + } + Style::Tuple => { + quote! { + Self::#ident(_, ..) + } + } + Style::Unit => quote!(Self::#ident), + }; + let version = &variant.attrs.version; + quote! { + #destructure if (#version).contains(&writer.version().minor()) => { + #(#write_fields)* + } + #ignore => { + return Err(W::Error::invalid_enum(format!("{} is not valid for version {}", stringify!(#ident), writer.version()))); + } + } +} + +fn nix_serialize_enum(variants: &[Variant<'_>]) -> TokenStream { + let match_variant = variants + .iter() + .map(|variant| nix_serialize_variant(variant)); + quote! { + match self { + #(#match_variant)* + } + Ok(()) + } +} + +fn nix_serialize_into(ty: &Type) -> TokenStream { + quote_spanned! { + ty.span() => + { + let other : #ty = <Self as Clone>::clone(self).into(); + writer.write_value(&other).await + } + } +} + +fn nix_serialize_try_into(crate_path: &Path, ty: &Type) -> TokenStream { + quote_spanned! { + ty.span() => + { + use #crate_path::nix_daemon::ser::Error; + let other : #ty = <Self as Clone>::clone(self).try_into().map_err(Error::unsupported_data)?; + writer.write_value(&other).await + } + } +} + +fn nix_serialize_display(span: Span) -> TokenStream { + quote_spanned! { + span => writer.write_display(self).await + } +} + +fn nix_serialize_display_path(path: &syn::ExprPath) -> TokenStream { + quote_spanned! { + path.span() => writer.write_display(#path(self)).await + } +} diff --git a/tvix/nix-compat/Cargo.toml b/tvix/nix-compat/Cargo.toml index f430a5461829..cbbf97175d14 100644 --- a/tvix/nix-compat/Cargo.toml +++ b/tvix/nix-compat/Cargo.toml @@ -43,6 +43,7 @@ futures = { workspace = true } hex-literal = { workspace = true } mimalloc = { workspace = true } pretty_assertions = { workspace = true } +proptest = { workspace = true, features = ["std", "alloc", "tempfile"] } rstest = { workspace = true } serde_json = { workspace = true } smol_str = { workspace = true } diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs index 11413e85fd1b..a943b279f891 100644 --- a/tvix/nix-compat/src/nix_daemon/mod.rs +++ b/tvix/nix-compat/src/nix_daemon/mod.rs @@ -4,3 +4,4 @@ mod protocol_version; pub use protocol_version::ProtocolVersion; pub mod de; +pub mod ser; diff --git a/tvix/nix-compat/src/nix_daemon/ser/bytes.rs b/tvix/nix-compat/src/nix_daemon/ser/bytes.rs new file mode 100644 index 000000000000..19494934ff32 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/bytes.rs @@ -0,0 +1,89 @@ +use bytes::Bytes; + +use super::{NixSerialize, NixWrite}; + +impl NixSerialize for Bytes { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl<'a> NixSerialize for &'a [u8] { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl NixSerialize for String { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +impl NixSerialize for str { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_str(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_string(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value.to_string()).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/collections.rs b/tvix/nix-compat/src/nix_daemon/ser/collections.rs new file mode 100644 index 000000000000..70c32e1c79ac --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/collections.rs @@ -0,0 +1,94 @@ +use std::collections::BTreeMap; +use std::future::Future; + +use super::{NixSerialize, NixWrite}; + +impl<T> NixSerialize for Vec<T> +where + T: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for value in self.iter() { + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +impl<K, V> NixSerialize for BTreeMap<K, V> +where + K: NixSerialize + Ord + Send + Sync, + V: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for (key, value) in self.iter() { + writer.write_value(key).await?; + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::{NixSerialize, NixWrite, NixWriter}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_vec(#[case] value: Vec<usize>, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + fn empty_map() -> BTreeMap<usize, u64> { + BTreeMap::new() + } + macro_rules! map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_btree_map<E>(#[case] value: E, #[case] data: &[u8]) + where + E: NixSerialize + Send + PartialEq + fmt::Debug, + { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/display.rs b/tvix/nix-compat/src/nix_daemon/ser/display.rs new file mode 100644 index 000000000000..a3438d50d8ff --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/display.rs @@ -0,0 +1,8 @@ +use nix_compat_derive::nix_serialize_remote; + +use crate::nixhash; + +nix_serialize_remote!( + #[nix(display)] + nixhash::HashAlgo +); diff --git a/tvix/nix-compat/src/nix_daemon/ser/int.rs b/tvix/nix-compat/src/nix_daemon/ser/int.rs new file mode 100644 index 000000000000..1be06442e322 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/int.rs @@ -0,0 +1,108 @@ +#[cfg(feature = "nix-compat-derive")] +use nix_compat_derive::nix_serialize_remote; + +use super::{Error, NixSerialize, NixWrite}; + +impl NixSerialize for u64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self).await + } +} + +impl NixSerialize for usize { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + let v = (*self).try_into().map_err(W::Error::unsupported_data)?; + writer.write_number(v).await + } +} + +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u8 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u16 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u32 +); + +impl NixSerialize for bool { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + if *self { + writer.write_number(1).await + } else { + writer.write_number(0).await + } + } +} + +impl NixSerialize for i64 { + async fn serialize<W>(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self as u64).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[tokio::test] + async fn test_write_bool(#[case] value: bool, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_u64(#[case] value: u64, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] + #[tokio::test] + async fn test_write_usize(#[case] value: usize, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/mock.rs b/tvix/nix-compat/src/nix_daemon/ser/mock.rs new file mode 100644 index 000000000000..1319a8da3228 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/mock.rs @@ -0,0 +1,672 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +#[cfg(test)] +use ::proptest::prelude::TestCaseError; +use thiserror::Error; + +use crate::nix_daemon::ProtocolVersion; + +use super::NixWrite; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("unsupported data error '{0}'")] + UnsupportedData(String), + #[error("Invalid enum: {0}")] + InvalidEnum(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong write: expected {0} got {1}")] + WrongWrite(OperationType, OperationType), + #[error("unexpected write: got an extra {0}")] + ExtraWrite(OperationType), + #[error("got an unexpected number {0} in write_number")] + UnexpectedNumber(u64), + #[error("got an unexpected slice '{0:?}' in write_slice")] + UnexpectedSlice(Vec<u8>), + #[error("got an unexpected display '{0:?}' in write_slice")] + UnexpectedDisplay(String), +} + +impl Error { + pub fn unexpected_write_number(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteNumber) + } + + pub fn extra_write_number() -> Error { + Error::ExtraWrite(OperationType::WriteNumber) + } + + pub fn unexpected_write_slice(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteSlice) + } + + pub fn unexpected_write_display(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteDisplay) + } +} + +impl super::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Self::Custom(msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + Self::IO(err.kind(), err.to_string()) + } + + fn unsupported_data<T: fmt::Display>(msg: T) -> Self { + Self::UnsupportedData(msg.to_string()) + } + + fn invalid_enum<T: fmt::Display>(msg: T) -> Self { + Self::InvalidEnum(msg.to_string()) + } +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OperationType { + WriteNumber, + WriteSlice, + WriteDisplay, +} + +impl fmt::Display for OperationType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::WriteNumber => write!(f, "write_number"), + Self::WriteSlice => write!(f, "write_slice"), + Self::WriteDisplay => write!(f, "write_display"), + } + } +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, PartialEq, Eq)] +enum Operation { + WriteNumber(u64, Result<(), Error>), + WriteSlice(Vec<u8>, Result<(), Error>), + WriteDisplay(String, Result<(), Error>), +} + +impl From<Operation> for OperationType { + fn from(value: Operation) -> Self { + match value { + Operation::WriteNumber(_, _) => OperationType::WriteNumber, + Operation::WriteSlice(_, _) => OperationType::WriteSlice, + Operation::WriteDisplay(_, _) => OperationType::WriteDisplay, + } + } +} + +pub struct Builder { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + version: Default::default(), + ops: VecDeque::new(), + } + } + + pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self { + self.version = version.into(); + self + } + + pub fn write_number(&mut self, value: u64) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Ok(()))); + self + } + + pub fn write_number_error(&mut self, value: u64, err: Error) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Err(err))); + self + } + + pub fn write_slice(&mut self, value: &[u8]) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Ok(()))); + self + } + + pub fn write_slice_error(&mut self, value: &[u8], err: Error) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Err(err))); + self + } + + pub fn write_display<D>(&mut self, value: D) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Ok(()))); + self + } + + pub fn write_display_error<D>(&mut self, value: D, err: Error) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Err(err))); + self + } + + #[cfg(test)] + fn write_operation_type(&mut self, op: OperationType) -> &mut Self { + match op { + OperationType::WriteNumber => self.write_number(10), + OperationType::WriteSlice => self.write_slice(b"testing"), + OperationType::WriteDisplay => self.write_display("testing"), + } + } + + #[cfg(test)] + fn write_operation(&mut self, op: &Operation) -> &mut Self { + match op { + Operation::WriteNumber(value, Ok(_)) => self.write_number(*value), + Operation::WriteNumber(value, Err(Error::UnexpectedNumber(_))) => { + self.write_number(*value) + } + Operation::WriteNumber(_, Err(Error::ExtraWrite(OperationType::WriteNumber))) => self, + Operation::WriteNumber(_, Err(Error::WrongWrite(op, OperationType::WriteNumber))) => { + self.write_operation_type(*op) + } + Operation::WriteNumber(value, Err(Error::Custom(msg))) => { + self.write_number_error(*value, Error::Custom(msg.clone())) + } + Operation::WriteNumber(value, Err(Error::IO(kind, msg))) => { + self.write_number_error(*value, Error::IO(*kind, msg.clone())) + } + Operation::WriteSlice(value, Ok(_)) => self.write_slice(value), + Operation::WriteSlice(value, Err(Error::UnexpectedSlice(_))) => self.write_slice(value), + Operation::WriteSlice(_, Err(Error::ExtraWrite(OperationType::WriteSlice))) => self, + Operation::WriteSlice(_, Err(Error::WrongWrite(op, OperationType::WriteSlice))) => { + self.write_operation_type(*op) + } + Operation::WriteSlice(value, Err(Error::Custom(msg))) => { + self.write_slice_error(value, Error::Custom(msg.clone())) + } + Operation::WriteSlice(value, Err(Error::IO(kind, msg))) => { + self.write_slice_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Ok(_)) => self.write_display(value), + Operation::WriteDisplay(value, Err(Error::Custom(msg))) => { + self.write_display_error(value, Error::Custom(msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::IO(kind, msg))) => { + self.write_display_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::UnexpectedDisplay(_))) => { + self.write_display(value) + } + Operation::WriteDisplay(_, Err(Error::ExtraWrite(OperationType::WriteDisplay))) => self, + Operation::WriteDisplay(_, Err(Error::WrongWrite(op, OperationType::WriteDisplay))) => { + self.write_operation_type(*op) + } + s => panic!("Invalid operation {:?}", s), + } + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct Mock { + version: ProtocolVersion, + ops: VecDeque<Operation>, +} + +impl Mock { + #[cfg(test)] + #[allow(dead_code)] + async fn assert_operation(&mut self, op: Operation) { + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + assert_eq!(self.write_display(value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + assert_eq!(self.write_display(value).await, res); + } + } + } + + #[cfg(test)] + async fn prop_assert_operation(&mut self, op: Operation) -> Result<(), TestCaseError> { + use ::proptest::prop_assert_eq; + + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + prop_assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + prop_assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + prop_assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + prop_assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + prop_assert_eq!(self.write_display(&value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + prop_assert_eq!(self.write_display(&value).await, res); + } + } + Ok(()) + } +} + +impl NixWrite for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteNumber(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedNumber(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_number(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteNumber)), + } + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteSlice(expected, ret)) => { + if buf != expected { + return Err(Error::UnexpectedSlice(buf.to_vec())); + } + ret + } + Some(op) => Err(Error::unexpected_write_slice(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteSlice)), + } + } + + async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + let value = msg.to_string(); + match self.ops.pop_front() { + Some(Operation::WriteDisplay(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedDisplay(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_display(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteDisplay)), + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + // No need to panic again + if thread::panicking() { + return; + } + if let Some(op) = self.ops.front() { + panic!("reader dropped with {op:?} operation still unread") + } + } +} + +#[cfg(test)] +mod proptest { + use std::io; + + use proptest::{ + prelude::{any, Arbitrary, BoxedStrategy, Just, Strategy}, + prop_oneof, + }; + + use super::{Error, Operation, OperationType}; + + pub fn arb_write_number_operation() -> impl Strategy<Value = Operation> { + ( + any::<u64>(), + prop_oneof![ + Just(Ok(())), + any::<u64>().prop_map(|v| Err(Error::UnexpectedNumber(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteNumber + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteNumber + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same number", |(v, res)| match res { + Err(Error::UnexpectedNumber(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteNumber(v, res)) + } + + pub fn arb_write_slice_operation() -> impl Strategy<Value = Operation> { + ( + any::<Vec<u8>>(), + prop_oneof![ + Just(Ok(())), + any::<Vec<u8>>().prop_map(|v| Err(Error::UnexpectedSlice(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteSlice + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteSlice + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same slice", |(v, res)| match res { + Err(Error::UnexpectedSlice(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteSlice(v, res)) + } + + #[allow(dead_code)] + pub fn arb_extra_write() -> impl Strategy<Value = Operation> { + prop_oneof![ + any::<u64>().prop_map(|msg| { + Operation::WriteNumber(msg, Err(Error::ExtraWrite(OperationType::WriteNumber))) + }), + any::<Vec<u8>>().prop_map(|msg| { + Operation::WriteSlice(msg, Err(Error::ExtraWrite(OperationType::WriteSlice))) + }), + any::<String>().prop_map(|msg| { + Operation::WriteDisplay(msg, Err(Error::ExtraWrite(OperationType::WriteDisplay))) + }), + ] + } + + pub fn arb_write_display_operation() -> impl Strategy<Value = Operation> { + ( + any::<String>(), + prop_oneof![ + Just(Ok(())), + any::<String>().prop_map(|v| Err(Error::UnexpectedDisplay(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteDisplay + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteDisplay + ))), + any::<String>().prop_map(|s| Err(Error::Custom(s))), + (any::<io::ErrorKind>(), any::<String>()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same string", |(v, res)| match res { + Err(Error::UnexpectedDisplay(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteDisplay(v, res)) + } + + pub fn arb_operation() -> impl Strategy<Value = Operation> { + prop_oneof![ + arb_write_number_operation(), + arb_write_slice_operation(), + arb_write_display_operation(), + ] + } + + impl Arbitrary for Operation { + type Parameters = (); + type Strategy = BoxedStrategy<Operation>; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + arb_operation().boxed() + } + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use proptest::prelude::any; + use proptest::prelude::TestCaseError; + use proptest::proptest; + + use crate::nix_daemon::ser::mock::proptest::arb_extra_write; + use crate::nix_daemon::ser::mock::Operation; + use crate::nix_daemon::ser::mock::OperationType; + use crate::nix_daemon::ser::Error as _; + use crate::nix_daemon::ser::NixWrite; + + use super::{Builder, Error}; + + #[tokio::test] + async fn write_number() { + let mut mock = Builder::new().write_number(10).build(); + mock.write_number(10).await.unwrap(); + } + + #[tokio::test] + async fn write_number_error() { + let mut mock = Builder::new() + .write_number_error(10, Error::custom("bad number")) + .build(); + assert_eq!( + Err(Error::custom("bad number")), + mock.write_number(10).await + ); + } + + #[tokio::test] + async fn write_number_unexpected() { + let mut mock = Builder::new().write_slice(b"").build(); + assert_eq!( + Err(Error::unexpected_write_number(OperationType::WriteSlice)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn write_number_unexpected_number() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::UnexpectedNumber(11)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn extra_write_number() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteNumber)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn write_slice() { + let mut mock = Builder::new() + .write_slice(&[]) + .write_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .build(); + mock.write_slice(&[]).await.expect("write_slice empty"); + mock.write_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .await + .expect("write_slice"); + } + + #[tokio::test] + async fn write_slice_error() { + let mut mock = Builder::new() + .write_slice_error(&[], Error::custom("bad slice")) + .build(); + assert_eq!(Err(Error::custom("bad slice")), mock.write_slice(&[]).await); + } + + #[tokio::test] + async fn write_slice_unexpected() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::unexpected_write_slice(OperationType::WriteNumber)), + mock.write_slice(b"").await + ); + } + + #[tokio::test] + async fn write_slice_unexpected_slice() { + let mut mock = Builder::new().write_slice(b"").build(); + assert_eq!( + Err(Error::UnexpectedSlice(b"bad slice".to_vec())), + mock.write_slice(b"bad slice").await + ); + } + + #[tokio::test] + async fn extra_write_slice() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteSlice)), + mock.write_slice(b"extra slice").await + ); + } + + #[tokio::test] + async fn write_display() { + let mut mock = Builder::new().write_display("testing").build(); + mock.write_display("testing").await.unwrap(); + } + + #[tokio::test] + async fn write_display_error() { + let mut mock = Builder::new() + .write_display_error("testing", Error::custom("bad number")) + .build(); + assert_eq!( + Err(Error::custom("bad number")), + mock.write_display("testing").await + ); + } + + #[tokio::test] + async fn write_display_unexpected() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::unexpected_write_display(OperationType::WriteNumber)), + mock.write_display("").await + ); + } + + #[tokio::test] + async fn write_display_unexpected_display() { + let mut mock = Builder::new().write_display("").build(); + assert_eq!( + Err(Error::UnexpectedDisplay("bad display".to_string())), + mock.write_display("bad display").await + ); + } + + #[tokio::test] + async fn extra_write_display() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteDisplay)), + mock.write_display("extra slice").await + ); + } + + #[test] + #[should_panic] + fn operations_left() { + let _ = Builder::new().write_number(10).build(); + } + + #[test] + fn proptest_mock() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + proptest!(|( + operations in any::<Vec<Operation>>(), + extra_operations in proptest::collection::vec(arb_extra_write(), 0..3) + )| { + rt.block_on(async { + let mut builder = Builder::new(); + for op in operations.iter() { + builder.write_operation(op); + } + for op in extra_operations.iter() { + builder.write_operation(op); + } + let mut mock = builder.build(); + for op in operations { + mock.prop_assert_operation(op).await?; + } + for op in extra_operations { + mock.prop_assert_operation(op).await?; + } + Ok(()) as Result<(), TestCaseError> + })?; + }); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/mod.rs b/tvix/nix-compat/src/nix_daemon/ser/mod.rs new file mode 100644 index 000000000000..5860226f39eb --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/mod.rs @@ -0,0 +1,124 @@ +use std::error::Error as StdError; +use std::future::Future; +use std::{fmt, io}; + +use super::ProtocolVersion; + +mod bytes; +mod collections; +#[cfg(feature = "nix-compat-derive")] +mod display; +mod int; +#[cfg(any(test, feature = "test"))] +pub mod mock; +mod writer; + +pub use writer::{NixWriter, NixWriterBuilder}; + +pub trait Error: Sized + StdError { + fn custom<T: fmt::Display>(msg: T) -> Self; + + fn io_error(err: std::io::Error) -> Self { + Self::custom(format_args!("There was an I/O error {}", err)) + } + + fn unsupported_data<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } + + fn invalid_enum<T: fmt::Display>(msg: T) -> Self { + Self::custom(msg) + } +} + +impl Error for io::Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::Other, msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + err + } + + fn unsupported_data<T: fmt::Display>(msg: T) -> Self { + io::Error::new(io::ErrorKind::InvalidData, msg.to_string()) + } +} + +pub trait NixWrite: Send { + type Error: Error; + + /// Some types are serialized differently depending on the version + /// of the protocol and so this can be used for implementing that. + fn version(&self) -> ProtocolVersion; + + /// Write a single u64 to the protocol. + fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send; + + /// Write a slice of bytes to the protocol. + fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send; + + /// Write a value that implements `std::fmt::Display` to the protocol. + /// The protocol uses many small string formats and instead of allocating + /// a `String` each time we want to write one an implementation of `NixWrite` + /// can instead use `Display` to dump these formats to a reusable buffer. + fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send + where + D: fmt::Display + Send, + Self: Sized, + { + async move { + let s = msg.to_string(); + self.write_slice(s.as_bytes()).await + } + } + + /// Write a value to the protocol. + /// Uses `NixSerialize::serialize` to write the value. + fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send + where + V: NixSerialize + Send + ?Sized, + Self: Sized, + { + value.serialize(self) + } +} + +impl<T: NixWrite> NixWrite for &mut T { + type Error = T::Error; + + fn version(&self) -> ProtocolVersion { + (**self).version() + } + + fn write_number(&mut self, value: u64) -> impl Future<Output = Result<(), Self::Error>> + Send { + (**self).write_number(value) + } + + fn write_slice(&mut self, buf: &[u8]) -> impl Future<Output = Result<(), Self::Error>> + Send { + (**self).write_slice(buf) + } + + fn write_display<D>(&mut self, msg: D) -> impl Future<Output = Result<(), Self::Error>> + Send + where + D: fmt::Display + Send, + Self: Sized, + { + (**self).write_display(msg) + } + + fn write_value<V>(&mut self, value: &V) -> impl Future<Output = Result<(), Self::Error>> + Send + where + V: NixSerialize + Send + ?Sized, + Self: Sized, + { + (**self).write_value(value) + } +} + +pub trait NixSerialize { + /// Write a value to the writer. + fn serialize<W>(&self, writer: &mut W) -> impl Future<Output = Result<(), W::Error>> + Send + where + W: NixWrite; +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/writer.rs b/tvix/nix-compat/src/nix_daemon/ser/writer.rs new file mode 100644 index 000000000000..87e30580af34 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/writer.rs @@ -0,0 +1,308 @@ +use std::fmt::{self, Write as _}; +use std::future::poll_fn; +use std::io::{self, Cursor}; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use bytes::{Buf, BufMut, BytesMut}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use crate::nix_daemon::ProtocolVersion; +use crate::wire::padding_len; +use crate::wire::EMPTY_BYTES; + +use super::{Error, NixWrite}; + +pub struct NixWriterBuilder { + buf: Option<BytesMut>, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, +} + +impl Default for NixWriterBuilder { + fn default() -> Self { + Self { + buf: Default::default(), + reserved_buf_size: 8192, + max_buf_size: 8192, + version: Default::default(), + } + } +} + +impl NixWriterBuilder { + pub fn set_buffer(mut self, buf: BytesMut) -> Self { + self.buf = Some(buf); + self + } + + pub fn set_reserved_buf_size(mut self, size: usize) -> Self { + self.reserved_buf_size = size; + self + } + + pub fn set_max_buf_size(mut self, size: usize) -> Self { + self.max_buf_size = size; + self + } + + pub fn set_version(mut self, version: ProtocolVersion) -> Self { + self.version = version; + self + } + + pub fn build<W>(self, writer: W) -> NixWriter<W> { + let buf = self + .buf + .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size)); + NixWriter { + buf, + inner: writer, + reserved_buf_size: self.reserved_buf_size, + max_buf_size: self.max_buf_size, + version: self.version, + } + } +} + +pin_project! { + pub struct NixWriter<W> { + #[pin] + inner: W, + buf: BytesMut, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, + } +} + +impl NixWriter<Cursor<Vec<u8>>> { + pub fn builder() -> NixWriterBuilder { + NixWriterBuilder::default() + } +} + +impl<W> NixWriter<W> +where + W: AsyncWriteExt, +{ + pub fn new(writer: W) -> NixWriter<W> { + NixWriter::builder().build(writer) + } + + pub fn buffer(&self) -> &[u8] { + &self.buf[..] + } + + pub fn set_version(&mut self, version: ProtocolVersion) { + self.version = version; + } + + /// Remaining capacity in internal buffer + pub fn remaining_mut(&self) -> usize { + self.buf.capacity() - self.buf.len() + } + + fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + let mut this = self.project(); + while !this.buf.is_empty() { + let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?; + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write the buffer", + ))); + } + this.buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl<W> NixWriter<W> +where + W: AsyncWriteExt + Unpin, +{ + async fn flush_buf(&mut self) -> Result<(), io::Error> { + let mut s = Pin::new(self); + poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await + } +} + +impl<W> AsyncWrite for NixWriter<W> +where + W: AsyncWrite, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + // Flush + if self.remaining_mut() < buf.len() { + ready!(self.as_mut().poll_flush_buf(cx))?; + } + let this = self.project(); + if buf.len() > this.buf.capacity() { + this.inner.poll_write(cx, buf) + } else { + this.buf.put_slice(buf); + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + ready!(self.as_mut().poll_flush_buf(cx))?; + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + ready!(self.as_mut().poll_flush_buf(cx))?; + self.project().inner.poll_shutdown(cx) + } +} + +impl<W> NixWrite for NixWriter<W> +where + W: AsyncWrite + Send + Unpin, +{ + type Error = io::Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + let mut buf = [0u8; 8]; + BufMut::put_u64_le(&mut &mut buf[..], value); + self.write_all(&buf).await + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + let padding = padding_len(buf.len() as u64) as usize; + self.write_value(&buf.len()).await?; + self.write_all(buf).await?; + if padding > 0 { + self.write_all(&EMPTY_BYTES[..padding]).await + } else { + Ok(()) + } + } + + async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + // Ensure that buffer has space for at least reserved_buf_size bytes + if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() { + self.flush_buf().await?; + } + let offset = self.buf.len(); + self.buf.put_u64_le(0); + if let Err(err) = write!(self.buf, "{}", msg) { + self.buf.truncate(offset); + return Err(Self::Error::unsupported_data(err)); + } + let len = self.buf.len() - offset - 8; + BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64); + let padding = padding_len(len as u64) as usize; + self.write_all(&EMPTY_BYTES[..padding]).await + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::NixWrite; + + use super::NixWriter; + + #[rstest] + #[case(1, &hex!("0100 0000 0000 0000"))] + #[case::evil(666, &hex!("9A02 0000 0000 0000"))] + #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) { + let mock = Builder::new().write(buf).build(); + let mut writer = NixWriter::new(mock); + + writer.write_number(number).await.unwrap(); + assert_eq!(writer.buffer(), buf); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } + + #[rstest] + #[case::empty(b"", &hex!("0000 0000 0000 0000"))] + #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_write_slice( + #[case] value: &[u8], + #[case] buf: &[u8], + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize, + ) { + let mut builder = Builder::new(); + for chunk in buf.chunks(chunks_size) { + builder.write(chunk); + builder.wait(Duration::ZERO); + } + let mock = builder.build(); + let mut writer = NixWriter::builder().set_max_buf_size(buf_size).build(mock); + + writer.write_slice(value).await.unwrap(); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_write_display( + #[case] value: &str, + #[case] buf: &[u8], + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, + ) { + let mut builder = Builder::new(); + for chunk in buf.chunks(chunks_size) { + builder.write(chunk); + builder.wait(Duration::ZERO); + } + let mock = builder.build(); + let mut writer = NixWriter::builder().build(mock); + + writer.write_display(value).await.unwrap(); + assert_eq!(writer.buffer(), buf); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } +} diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index 74adfb49b6a4..9b981fbbd2c0 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -181,7 +181,7 @@ pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>( /// Computes the number of bytes we should add to len (a length in /// bytes) to be aligned on 64 bits (8 bytes). -fn padding_len(len: u64) -> u8 { +pub(crate) fn padding_len(len: u64) -> u8 { let aligned = len.wrapping_add(7) & !7; aligned.wrapping_sub(len) as u8 } |