From b88579ade41244b09555bbb68296033fc300043f Mon Sep 17 00:00:00 2001 From: Brian Olsen Date: Sun, 3 Nov 2024 20:42:01 +0100 Subject: feat(tvix/nix-compat): Add nix serialization support This change implements the serialization part that is needed to implement the nix daemon protocol. Previously was add deserialization and derivers for that and this then adds the other part of that equation so that you can write types that can then be read using deserialization. Change-Id: I2917de634980a93822a4f5a8ad38897b9ce16d89 Reviewed-on: https://cl.tvl.fyi/c/depot/+/12729 Autosubmit: Brian Olsen Reviewed-by: flokli Tested-by: BuildkiteCI --- tvix/Cargo.lock | 2 + tvix/Cargo.nix | 12 + tvix/nix-compat-derive-tests/tests/write_derive.rs | 370 ++++++++++++ tvix/nix-compat-derive/Cargo.toml | 1 + tvix/nix-compat-derive/src/de.rs | 5 + tvix/nix-compat-derive/src/internal/attrs.rs | 148 ++++- tvix/nix-compat-derive/src/internal/mod.rs | 4 - tvix/nix-compat-derive/src/internal/symbol.rs | 3 + tvix/nix-compat-derive/src/lib.rs | 175 +++++- tvix/nix-compat-derive/src/ser.rs | 227 +++++++ tvix/nix-compat/Cargo.toml | 1 + tvix/nix-compat/src/nix_daemon/mod.rs | 1 + tvix/nix-compat/src/nix_daemon/ser/bytes.rs | 89 +++ tvix/nix-compat/src/nix_daemon/ser/collections.rs | 94 +++ tvix/nix-compat/src/nix_daemon/ser/display.rs | 8 + tvix/nix-compat/src/nix_daemon/ser/int.rs | 108 ++++ tvix/nix-compat/src/nix_daemon/ser/mock.rs | 672 +++++++++++++++++++++ tvix/nix-compat/src/nix_daemon/ser/mod.rs | 124 ++++ tvix/nix-compat/src/nix_daemon/ser/writer.rs | 308 ++++++++++ tvix/nix-compat/src/wire/bytes/mod.rs | 2 +- 20 files changed, 2339 insertions(+), 15 deletions(-) create mode 100644 tvix/nix-compat-derive-tests/tests/write_derive.rs create mode 100644 tvix/nix-compat-derive/src/ser.rs create mode 100644 tvix/nix-compat/src/nix_daemon/ser/bytes.rs create mode 100644 tvix/nix-compat/src/nix_daemon/ser/collections.rs create mode 100644 tvix/nix-compat/src/nix_daemon/ser/display.rs create mode 100644 tvix/nix-compat/src/nix_daemon/ser/int.rs create mode 100644 tvix/nix-compat/src/nix_daemon/ser/mock.rs create mode 100644 tvix/nix-compat/src/nix_daemon/ser/mod.rs create mode 100644 tvix/nix-compat/src/nix_daemon/ser/writer.rs diff --git a/tvix/Cargo.lock b/tvix/Cargo.lock index 2984479c18ef..73badc2af496 100644 --- a/tvix/Cargo.lock +++ b/tvix/Cargo.lock @@ -2302,6 +2302,7 @@ dependencies = [ "num-traits", "pin-project-lite", "pretty_assertions", + "proptest", "rstest", "serde", "serde_json", @@ -2322,6 +2323,7 @@ dependencies = [ "nix-compat", "pretty_assertions", "proc-macro2", + "proptest", "quote", "rstest", "syn 2.0.79", diff --git a/tvix/Cargo.nix b/tvix/Cargo.nix index 749f39fb93cc..a09cbc016a23 100644 --- a/tvix/Cargo.nix +++ b/tvix/Cargo.nix @@ -7298,6 +7298,12 @@ rec { name = "pretty_assertions"; packageId = "pretty_assertions"; } + { + name = "proptest"; + packageId = "proptest"; + usesDefaultFeatures = false; + features = [ "std" "alloc" "tempfile" ]; + } { name = "rstest"; packageId = "rstest"; @@ -7369,6 +7375,12 @@ rec { name = "pretty_assertions"; packageId = "pretty_assertions"; } + { + name = "proptest"; + packageId = "proptest"; + usesDefaultFeatures = false; + features = [ "std" "alloc" "tempfile" ]; + } { name = "rstest"; packageId = "rstest"; diff --git a/tvix/nix-compat-derive-tests/tests/write_derive.rs b/tvix/nix-compat-derive-tests/tests/write_derive.rs new file mode 100644 index 000000000000..1ed5dbf7e735 --- /dev/null +++ b/tvix/nix-compat-derive-tests/tests/write_derive.rs @@ -0,0 +1,370 @@ +use std::fmt; + +use nix_compat::nix_daemon::ser::{ + mock::{Builder, Error}, + NixWrite as _, +}; +use nix_compat_derive::NixSerialize; + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct UnitTest; + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct EmptyTupleTest(); + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct StructTest { + first: u64, + second: String, +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct TupleTest(u64, String); + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct StructVersionTest { + test: u64, + #[nix(version = "20..")] + hello: String, +} + +fn default_test() -> StructVersionTest { + StructVersionTest { + test: 89, + hello: String::from("klomp"), + } +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct TupleVersionTest(u64, #[nix(version = "25..")] String); + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +pub struct TupleVersionDefaultTest(u64, #[nix(version = "..25")] StructVersionTest); + +#[tokio::test] +async fn write_unit() { + let mut mock = Builder::new().build(); + mock.write_value(&UnitTest).await.unwrap(); +} + +#[tokio::test] +async fn write_empty_tuple() { + let mut mock = Builder::new().build(); + mock.write_value(&EmptyTupleTest()).await.unwrap(); +} + +#[tokio::test] +async fn write_struct() { + let mut mock = Builder::new() + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&StructTest { + first: 89, + second: String::from("klomp"), + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_tuple() { + let mut mock = Builder::new() + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&TupleTest(89, String::from("klomp"))) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_struct_version() { + let mut mock = Builder::new() + .version((1, 20)) + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&default_test()).await.unwrap(); +} + +#[tokio::test] +async fn write_struct_without_version() { + let mut mock = Builder::new().version((1, 19)).write_number(89).build(); + mock.write_value(&StructVersionTest { + test: 89, + hello: String::new(), + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_tuple_version() { + let mut mock = Builder::new() + .version((1, 26)) + .write_number(89) + .write_slice(b"klomp") + .build(); + mock.write_value(&TupleVersionTest(89, "klomp".into())) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_tuple_without_version() { + let mut mock = Builder::new().version((1, 19)).write_number(89).build(); + mock.write_value(&TupleVersionTest(89, String::new())) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_1() { + let mut mock = Builder::new() + .version((1, 19)) + .write_number(999) + .write_number(666) + .build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 666, + hello: String::new(), + }, + )) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_2() { + let mut mock = Builder::new() + .version((1, 20)) + .write_number(999) + .write_number(666) + .write_slice(b"The quick brown \xF0\x9F\xA6\x8A jumps over 13 lazy \xF0\x9F\x90\xB6.") + .build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 666, + hello: String::from("The quick brown 🦊 jumps over 13 lazy 🐶."), + }, + )) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_3() { + let mut mock = Builder::new().version((1, 25)).write_number(999).build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 89, + hello: String::from("klomp"), + }, + )) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_complex_4() { + let mut mock = Builder::new().version((1, 26)).write_number(999).build(); + mock.write_value(&TupleVersionDefaultTest( + 999, + StructVersionTest { + test: 89, + hello: String::from("klomp"), + }, + )) + .await + .unwrap(); +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +#[nix(display)] +struct TestFromStr; + +impl fmt::Display for TestFromStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "test") + } +} + +#[tokio::test] +async fn write_display() { + let mut mock = Builder::new().write_display("test").build(); + mock.write_value(&TestFromStr).await.unwrap(); +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +#[nix(display = "TestFromStr2::display")] +struct TestFromStr2; +struct TestFromStrDisplay; + +impl fmt::Display for TestFromStrDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "test") + } +} +impl TestFromStr2 { + fn display(&self) -> TestFromStrDisplay { + TestFromStrDisplay + } +} + +#[tokio::test] +async fn write_display_path() { + let mut mock = Builder::new().write_display("test").build(); + mock.write_value(&TestFromStr2).await.unwrap(); +} + +#[derive(Clone, Debug, PartialEq, Eq, NixSerialize)] +#[nix(try_into = "u64")] +struct TestTryFromU64(u64); + +impl TryFrom for u64 { + type Error = u64; + + fn try_from(value: TestTryFromU64) -> Result { + if value.0 != 42 { + Ok(value.0) + } else { + Err(value.0) + } + } +} + +#[tokio::test] +async fn write_try_into_u64() { + let mut mock = Builder::new().write_number(666).build(); + mock.write_value(&TestTryFromU64(666)).await.unwrap(); +} + +#[tokio::test] +async fn write_try_into_u64_invalid_data() { + let mut mock = Builder::new().build(); + let err = mock.write_value(&TestTryFromU64(42)).await.unwrap_err(); + assert_eq!(Error::UnsupportedData("42".into()), err); +} + +#[derive(Clone, Debug, PartialEq, Eq, NixSerialize)] +#[nix(into = "u64")] +struct TestFromU64; + +impl From for u64 { + fn from(_value: TestFromU64) -> u64 { + 42 + } +} + +#[tokio::test] +async fn write_into_u64() { + let mut mock = Builder::new().write_number(42).build(); + mock.write_value(&TestFromU64).await.unwrap(); +} + +#[derive(Debug, PartialEq, Eq, NixSerialize)] +enum TestEnum { + #[nix(version = "..=19")] + Pre20(TestFromU64, #[nix(version = "10..")] u64), + #[nix(version = "20..=29")] + Post20(StructVersionTest), + #[nix(version = "30..=39")] + Post30, + #[nix(version = "40..")] + Post40 { + msg: String, + #[nix(version = "45..")] + level: u64, + }, +} + +#[tokio::test] +async fn write_enum_9() { + let mut mock = Builder::new().version((1, 9)).write_number(42).build(); + mock.write_value(&TestEnum::Pre20(TestFromU64, 666)) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_19() { + let mut mock = Builder::new() + .version((1, 19)) + .write_number(42) + .write_number(666) + .build(); + mock.write_value(&TestEnum::Pre20(TestFromU64, 666)) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_20() { + let mut mock = Builder::new() + .version((1, 20)) + .write_number(666) + .write_slice(b"klomp") + .build(); + mock.write_value(&TestEnum::Post20(StructVersionTest { + test: 666, + hello: "klomp".into(), + })) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_30() { + let mut mock = Builder::new().version((1, 30)).build(); + mock.write_value(&TestEnum::Post30).await.unwrap(); +} + +#[tokio::test] +async fn write_enum_40() { + let mut mock = Builder::new() + .version((1, 40)) + .write_slice(b"hello world") + .build(); + mock.write_value(&TestEnum::Post40 { + msg: "hello world".into(), + level: 9001, + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_enum_45() { + let mut mock = Builder::new() + .version((1, 45)) + .write_slice(b"hello world") + .write_number(9001) + .build(); + mock.write_value(&TestEnum::Post40 { + msg: "hello world".into(), + level: 9001, + }) + .await + .unwrap(); +} + +#[tokio::test] +async fn write_wrong_enum() { + let mut mock = Builder::new().version((1, 30)).build(); + let err = mock + .write_value(&TestEnum::Post40 { + msg: "hello world".into(), + level: 9001, + }) + .await + .unwrap_err(); + assert_eq!( + err, + Error::InvalidEnum("Post40 is not valid for version 1.30".into()) + ) +} diff --git a/tvix/nix-compat-derive/Cargo.toml b/tvix/nix-compat-derive/Cargo.toml index da6d6744e650..bc656b42842c 100644 --- a/tvix/nix-compat-derive/Cargo.toml +++ b/tvix/nix-compat-derive/Cargo.toml @@ -14,6 +14,7 @@ syn = { version = "2.0.76", features = ["full", "extra-traits"] } [dev-dependencies] hex-literal = { workspace = true } pretty_assertions = { workspace = true } +proptest = { workspace = true, features = ["std", "alloc", "tempfile"] } rstest = { workspace = true } tokio-test = { workspace = true } tokio = { workspace = true, features = ["io-util", "macros"] } diff --git a/tvix/nix-compat-derive/src/de.rs b/tvix/nix-compat-derive/src/de.rs index 2214254e2b32..e678b50b0533 100644 --- a/tvix/nix-compat-derive/src/de.rs +++ b/tvix/nix-compat-derive/src/de.rs @@ -34,6 +34,11 @@ pub fn expand_nix_deserialize_remote( ) -> syn::Result { let cx = Context::new(); let remote = Remote::from_ast(&cx, crate_path, input); + if let Some(attrs) = remote.as_ref().map(|r| &r.attrs) { + if attrs.from_str.is_none() && attrs.type_from.is_none() && attrs.type_try_from.is_none() { + cx.error_spanned(input, "Missing from_str, from or try_from attribute"); + } + } cx.check()?; let remote = remote.unwrap(); diff --git a/tvix/nix-compat-derive/src/internal/attrs.rs b/tvix/nix-compat-derive/src/internal/attrs.rs index 9ed84aaf8745..d0fa3b008e22 100644 --- a/tvix/nix-compat-derive/src/internal/attrs.rs +++ b/tvix/nix-compat-derive/src/internal/attrs.rs @@ -3,7 +3,9 @@ use syn::meta::ParseNestedMeta; use syn::parse::Parse; use syn::{parse_quote, Attribute, Expr, ExprLit, ExprPath, Lit, Token}; -use super::symbol::{Symbol, CRATE, DEFAULT, FROM, FROM_STR, NIX, TRY_FROM, VERSION}; +use super::symbol::{ + Symbol, CRATE, DEFAULT, DISPLAY, FROM, FROM_STR, INTO, NIX, TRY_FROM, TRY_INTO, VERSION, +}; use super::Context; #[derive(Debug, PartialEq, Eq)] @@ -104,6 +106,9 @@ pub struct Container { pub from_str: Option, pub type_from: Option, pub type_try_from: Option, + pub type_into: Option, + pub type_try_into: Option, + pub display: Default, pub crate_path: Option, } @@ -113,6 +118,9 @@ impl Container { let mut type_try_from = None; let mut crate_path = None; let mut from_str = None; + let mut type_into = None; + let mut type_try_into = None; + let mut display = Default::None; for attr in attrs { if attr.path() != NIX { @@ -125,6 +133,18 @@ impl Container { type_try_from = parse_lit(ctx, &meta, TRY_FROM)?; } else if meta.path == FROM_STR { from_str = Some(meta.path); + } else if meta.path == INTO { + type_into = parse_lit(ctx, &meta, INTO)?; + } else if meta.path == TRY_INTO { + type_try_into = parse_lit(ctx, &meta, TRY_INTO)?; + } else if meta.path == DISPLAY { + if meta.input.peek(Token![=]) { + if let Some(path) = parse_lit(ctx, &meta, DISPLAY)? { + display = Default::Path(path); + } + } else { + display = Default::Default(meta.path); + } } else if meta.path == CRATE { crate_path = parse_lit(ctx, &meta, CRATE)?; } else { @@ -144,6 +164,9 @@ impl Container { from_str, type_from, type_try_from, + type_into, + type_try_into, + display, crate_path, } } @@ -341,6 +364,46 @@ mod test { ); } + #[test] + fn parse_container_from_str() { + let attrs: Vec = vec![parse_quote!(#[nix(from_str)])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: Some(parse_quote!(from_str)), + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_from() { + let attrs: Vec = vec![parse_quote!(#[nix(from="u64")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: Some(parse_quote!(u64)), + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + #[test] fn parse_container_try_from() { let attrs: Vec = vec![parse_quote!(#[nix(try_from="u64")])]; @@ -353,6 +416,89 @@ mod test { from_str: None, type_from: None, type_try_from: Some(parse_quote!(u64)), + type_into: None, + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_into() { + let attrs: Vec = vec![parse_quote!(#[nix(into="u64")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: Some(parse_quote!(u64)), + type_try_into: None, + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_try_into() { + let attrs: Vec = vec![parse_quote!(#[nix(try_into="u64")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: Some(parse_quote!(u64)), + display: Default::None, + crate_path: None, + } + ); + } + + #[test] + fn parse_container_display() { + let attrs: Vec = vec![parse_quote!(#[nix(display)])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::Default(parse_quote!(display)), + crate_path: None, + } + ); + } + + #[test] + fn parse_container_display_path() { + let attrs: Vec = vec![parse_quote!(#[nix(display="Path::display")])]; + let ctx = Context::new(); + let container = Container::from_ast(&ctx, &attrs); + ctx.check().unwrap(); + assert_eq!( + container, + Container { + from_str: None, + type_from: None, + type_try_from: None, + type_into: None, + type_try_into: None, + display: Default::Path(parse_quote!(Path::display)), crate_path: None, } ); diff --git a/tvix/nix-compat-derive/src/internal/mod.rs b/tvix/nix-compat-derive/src/internal/mod.rs index 07ef43b6e0bb..aa42d904718d 100644 --- a/tvix/nix-compat-derive/src/internal/mod.rs +++ b/tvix/nix-compat-derive/src/internal/mod.rs @@ -154,10 +154,6 @@ impl<'a> Remote<'a> { input: &'a inputs::RemoteInput, ) -> Option> { let attrs = attrs::Container::from_ast(ctx, &input.attrs); - if attrs.from_str.is_none() && attrs.type_from.is_none() && attrs.type_try_from.is_none() { - ctx.error_spanned(input, "Missing from_str, from or try_from attribute"); - return None; - } Some(Remote { ty: &input.ident, attrs, diff --git a/tvix/nix-compat-derive/src/internal/symbol.rs b/tvix/nix-compat-derive/src/internal/symbol.rs index ed3fe304eb5d..2bbdc069aa0f 100644 --- a/tvix/nix-compat-derive/src/internal/symbol.rs +++ b/tvix/nix-compat-derive/src/internal/symbol.rs @@ -11,6 +11,9 @@ pub const DEFAULT: Symbol = Symbol("default"); pub const FROM: Symbol = Symbol("from"); pub const TRY_FROM: Symbol = Symbol("try_from"); pub const FROM_STR: Symbol = Symbol("from_str"); +pub const INTO: Symbol = Symbol("into"); +pub const TRY_INTO: Symbol = Symbol("try_into"); +pub const DISPLAY: Symbol = Symbol("display"); pub const CRATE: Symbol = Symbol("crate"); impl PartialEq for Path { diff --git a/tvix/nix-compat-derive/src/lib.rs b/tvix/nix-compat-derive/src/lib.rs index 89735cadf315..394473b1cbf8 100644 --- a/tvix/nix-compat-derive/src/lib.rs +++ b/tvix/nix-compat-derive/src/lib.rs @@ -6,7 +6,11 @@ //! 1. [`#[nix(from_str)]`](#nixfrom_str) //! 2. [`#[nix(from = "FromType")]`](#nixfrom--fromtype) //! 3. [`#[nix(try_from = "FromType")]`](#nixtry_from--fromtype) -//! 4. [`#[nix(crate = "...")]`](#nixcrate--) +//! 4. [`#[nix(into = "IntoType")]`](#nixinto--intotype) +//! 5. [`#[nix(try_into = "IntoType")]`](#nixtry_into--intotype) +//! 6. [`#[nix(display)]`](#nixdisplay) +//! 7. [`#[nix(display = "path")]`](#nixdisplay--path) +//! 8. [`#[nix(crate = "...")]`](#nixcrate--) //! 2. [Variant attributes](#variant-attributes) //! 1. [`#[nix(version = "range")]`](#nixversion--range) //! 3. [Field attributes](#field-attributes) @@ -17,20 +21,21 @@ //! ## Overview //! //! This crate contains derive macros and function-like macros for implementing -//! `NixDeserialize` with less boilerplate. +//! `NixDeserialize` and `NixSerialize` with less boilerplate. //! //! ### Examples +//! //! ```rust -//! # use nix_compat_derive::NixDeserialize; +//! # use nix_compat_derive::{NixDeserialize, NixSerialize}; //! # -//! #[derive(NixDeserialize)] +//! #[derive(NixDeserialize, NixSerialize)] //! struct Unnamed(u64, String); //! ``` //! //! ```rust -//! # use nix_compat_derive::NixDeserialize; +//! # use nix_compat_derive::{NixDeserialize, NixSerialize}; //! # -//! #[derive(NixDeserialize)] +//! #[derive(NixDeserialize, NixSerialize)] //! struct Fields { //! number: u64, //! message: String, @@ -38,9 +43,9 @@ //! ``` //! //! ```rust -//! # use nix_compat_derive::NixDeserialize; +//! # use nix_compat_derive::{NixDeserialize, NixSerialize}; //! # -//! #[derive(NixDeserialize)] +//! #[derive(NixDeserialize, NixSerialize)] //! struct Ignored; //! ``` //! @@ -64,7 +69,7 @@ //! #[derive(NixDeserialize)] //! #[nix(crate="nix_compat")] // <-- This is also a container attribute //! enum E { -//! #[nix(version="..=9")] // <-- This is a variant attribute +//! #[nix(version="..10")] // <-- This is a variant attribute //! A(u64), //! #[nix(version="10..")] // <-- This is also a variant attribute //! B(String), @@ -156,6 +161,114 @@ //! } //! ``` //! +//! ##### `#[nix(into = "IntoType")]` +//! +//! When `into` is specified the fields are all ignored and instead the +//! container type is converted to `IntoType` using `Into::into` and +//! `IntoType` is then serialized. Before converting `Clone::clone` is +//! called. +//! +//! This means that the container must implement `Into` and `Clone` +//! and `IntoType` must implement `NixSerialize`. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # +//! #[derive(Clone, NixSerialize)] +//! #[nix(into="usize")] +//! struct MyValue(usize); +//! impl From for usize { +//! fn from(val: MyValue) -> Self { +//! val.0 +//! } +//! } +//! ``` +//! +//! ##### `#[nix(try_into = "IntoType")]` +//! +//! When `try_into` is specified the fields are all ignored and instead the +//! container type is converted to `IntoType` using `TryInto::try_into` and +//! `IntoType` is then serialized. Before converting `Clone::clone` is +//! called. +//! +//! This means that the container must implement `TryInto` and +//! `Clone` and `IntoType` must implement `NixSerialize`. +//! The error returned from `try_into` also needs to implement `Display`. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # +//! #[derive(Clone, NixSerialize)] +//! #[nix(try_into="usize")] +//! struct WrongAnswer(usize); +//! impl TryFrom for usize { +//! type Error = String; +//! fn try_from(val: WrongAnswer) -> Result { +//! if val.0 != 42 { +//! Ok(val.0) +//! } else { +//! Err("Got the answer to life the universe and everything".to_string()) +//! } +//! } +//! } +//! ``` +//! +//! ##### `#[nix(display)]` +//! +//! When `display` is specified the fields are all ignored and instead the +//! container must implement `Display` and `NixWrite::write_display` is used to +//! write the container. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # use std::fmt::{Display, Result, Formatter}; +//! # +//! #[derive(NixSerialize)] +//! #[nix(display)] +//! struct WrongAnswer(usize); +//! impl Display for WrongAnswer { +//! fn fmt(&self, f: &mut Formatter<'_>) -> Result { +//! write!(f, "Wrong Answer = {}", self.0) +//! } +//! } +//! ``` +//! +//! ##### `#[nix(display = "path")]` +//! +//! When `display` is specified the fields are all ignored and instead the +//! container the specified path must point to a function that is callable as +//! `fn(&T) -> impl Display`. The result from this call is then written with +//! `NixWrite::write_display`. +//! For example `default = "my_value"` would call `my_value(&self)` and `display = +//! "AType::empty"` would call `AType::empty(&self)`. +//! +//! ###### Example +//! +//! ```rust +//! # use nix_compat_derive::NixSerialize; +//! # use std::fmt::{Display, Result, Formatter}; +//! # +//! #[derive(NixSerialize)] +//! #[nix(display = "format_it")] +//! struct WrongAnswer(usize); +//! struct WrongDisplay<'a>(&'a WrongAnswer); +//! impl<'a> Display for WrongDisplay<'a> { +//! fn fmt(&self, f: &mut Formatter<'_>) -> Result { +//! write!(f, "Wrong Answer = {}", self.0.0) +//! } +//! } +//! +//! fn format_it(value: &WrongAnswer) -> impl Display + '_ { +//! WrongDisplay(value) +//! } +//! ``` +//! //! ##### `#[nix(crate = "...")]` //! //! Specify the path to the `nix-compat` crate instance to use when referring @@ -175,6 +288,7 @@ //! //! ```rust //! # use nix_compat_derive::NixDeserialize; +//! # //! #[derive(NixDeserialize)] //! enum Testing { //! #[nix(version="..=18")] @@ -260,6 +374,7 @@ use syn::{parse_quote, DeriveInput}; mod de; mod internal; +mod ser; #[proc_macro_derive(NixDeserialize, attributes(nix))] pub fn derive_nix_deserialize(item: TokenStream) -> TokenStream { @@ -270,6 +385,15 @@ pub fn derive_nix_deserialize(item: TokenStream) -> TokenStream { .into() } +#[proc_macro_derive(NixSerialize, attributes(nix))] +pub fn derive_nix_serialize(item: TokenStream) -> TokenStream { + let mut input = syn::parse_macro_input!(item as DeriveInput); + let crate_path: syn::Path = parse_quote!(::nix_compat); + ser::expand_nix_serialize(crate_path, &mut input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + /// Macro to implement `NixDeserialize` on a type. /// Sometimes you can't use the deriver to implement `NixDeserialize` /// (like when dealing with types in Rust standard library) but don't want @@ -301,3 +425,36 @@ pub fn nix_deserialize_remote(item: TokenStream) -> TokenStream { .unwrap_or_else(syn::Error::into_compile_error) .into() } + +/// Macro to implement `NixSerialize` on a type. +/// Sometimes you can't use the deriver to implement `NixSerialize` +/// (like when dealing with types in Rust standard library) but don't want +/// to implement it yourself. So this macro can be used for those situations +/// where you would derive using `#[nix(display)]`, `#[nix(display = "path")]`, +/// `#[nix(store_dir_display)]`, `#[nix(into = "IntoType")]` or +/// `#[nix(try_into = "IntoType")]` if you could. +/// +/// #### Example +/// +/// ```rust +/// # use nix_compat_derive::nix_serialize_remote; +/// # +/// #[derive(Clone)] +/// struct MyU64(u64); +/// +/// impl From for u64 { +/// fn from(value: MyU64) -> Self { +/// value.0 +/// } +/// } +/// +/// nix_serialize_remote!(#[nix(into="u64")] MyU64); +/// ``` +#[proc_macro] +pub fn nix_serialize_remote(item: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(item as RemoteInput); + let crate_path = parse_quote!(::nix_compat); + ser::expand_nix_serialize_remote(crate_path, &input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/tvix/nix-compat-derive/src/ser.rs b/tvix/nix-compat-derive/src/ser.rs new file mode 100644 index 000000000000..47ddfa39366d --- /dev/null +++ b/tvix/nix-compat-derive/src/ser.rs @@ -0,0 +1,227 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; +use syn::{DeriveInput, Generics, Path, Type}; + +use crate::internal::attrs::Default; +use crate::internal::inputs::RemoteInput; +use crate::internal::{attrs, Container, Context, Data, Field, Remote, Style, Variant}; + +pub fn expand_nix_serialize(crate_path: Path, input: &mut DeriveInput) -> syn::Result { + let cx = Context::new(); + let cont = Container::from_ast(&cx, crate_path, input); + cx.check()?; + let cont = cont.unwrap(); + + let ty = cont.ident_type(); + let body = nix_serialize_body(&cont); + let crate_path = cont.crate_path(); + + Ok(nix_serialize_impl( + crate_path, + &ty, + &cont.original.generics, + body, + )) +} + +pub fn expand_nix_serialize_remote( + crate_path: Path, + input: &RemoteInput, +) -> syn::Result { + let cx = Context::new(); + let remote = Remote::from_ast(&cx, crate_path, input); + if let Some(attrs) = remote.as_ref().map(|r| &r.attrs) { + if attrs.display.is_none() && attrs.type_into.is_none() && attrs.type_try_into.is_none() { + cx.error_spanned(input, "Missing into, try_into or display attribute"); + } + } + cx.check()?; + let remote = remote.unwrap(); + + let crate_path = remote.crate_path(); + let body = nix_serialize_body_into(crate_path, &remote.attrs).expect("From tokenstream"); + let generics = Generics::default(); + Ok(nix_serialize_impl(crate_path, remote.ty, &generics, body)) +} + +fn nix_serialize_impl( + crate_path: &Path, + ty: &Type, + generics: &Generics, + body: TokenStream, +) -> TokenStream { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + quote! { + #[automatically_derived] + impl #impl_generics #crate_path::nix_daemon::ser::NixSerialize for #ty #ty_generics + #where_clause + { + async fn serialize(&self, writer: &mut W) -> std::result::Result<(), W::Error> + where W: #crate_path::nix_daemon::ser::NixWrite + { + use #crate_path::nix_daemon::ser::Error as _; + #body + } + } + } +} + +fn nix_serialize_body_into( + crate_path: &syn::Path, + attrs: &attrs::Container, +) -> Option { + if let Default::Default(span) = &attrs.display { + Some(nix_serialize_display(span.span())) + } else if let Default::Path(path) = &attrs.display { + Some(nix_serialize_display_path(path)) + } else if let Some(type_into) = attrs.type_into.as_ref() { + Some(nix_serialize_into(type_into)) + } else { + attrs + .type_try_into + .as_ref() + .map(|type_try_into| nix_serialize_try_into(crate_path, type_try_into)) + } +} + +fn nix_serialize_body(cont: &Container) -> TokenStream { + if let Some(tokens) = nix_serialize_body_into(cont.crate_path(), &cont.attrs) { + tokens + } else { + match &cont.data { + Data::Struct(_style, fields) => nix_serialize_struct(fields), + Data::Enum(variants) => nix_serialize_enum(variants), + } + } +} + +fn nix_serialize_struct(fields: &[Field<'_>]) -> TokenStream { + let write_fields = fields.iter().map(|f| { + let field = &f.member; + let ty = f.ty; + let write_value = quote_spanned! { + ty.span()=> writer.write_value(&self.#field).await? + }; + if let Some(version) = f.attrs.version.as_ref() { + quote! { + if (#version).contains(&writer.version().minor()) { + #write_value; + } + } + } else { + quote! { + #write_value; + } + } + }); + + quote! { + #(#write_fields)* + Ok(()) + } +} + +fn nix_serialize_variant(variant: &Variant<'_>) -> TokenStream { + let ident = variant.ident; + let write_fields = variant.fields.iter().map(|f| { + let field = f.var_ident(); + let ty = f.ty; + let write_value = quote_spanned! { + ty.span()=> writer.write_value(#field).await? + }; + if let Some(version) = f.attrs.version.as_ref() { + quote! { + if (#version).contains(&writer.version().minor()) { + #write_value; + } + } + } else { + quote! { + #write_value; + } + } + }); + let field_names = variant.fields.iter().map(|f| f.var_ident()); + let destructure = match variant.style { + Style::Struct => { + quote! { + Self::#ident { #(#field_names),* } + } + } + Style::Tuple => { + quote! { + Self::#ident(#(#field_names),*) + } + } + Style::Unit => quote!(Self::#ident), + }; + let ignore = match variant.style { + Style::Struct => { + quote! { + Self::#ident { .. } + } + } + Style::Tuple => { + quote! { + Self::#ident(_, ..) + } + } + Style::Unit => quote!(Self::#ident), + }; + let version = &variant.attrs.version; + quote! { + #destructure if (#version).contains(&writer.version().minor()) => { + #(#write_fields)* + } + #ignore => { + return Err(W::Error::invalid_enum(format!("{} is not valid for version {}", stringify!(#ident), writer.version()))); + } + } +} + +fn nix_serialize_enum(variants: &[Variant<'_>]) -> TokenStream { + let match_variant = variants + .iter() + .map(|variant| nix_serialize_variant(variant)); + quote! { + match self { + #(#match_variant)* + } + Ok(()) + } +} + +fn nix_serialize_into(ty: &Type) -> TokenStream { + quote_spanned! { + ty.span() => + { + let other : #ty = ::clone(self).into(); + writer.write_value(&other).await + } + } +} + +fn nix_serialize_try_into(crate_path: &Path, ty: &Type) -> TokenStream { + quote_spanned! { + ty.span() => + { + use #crate_path::nix_daemon::ser::Error; + let other : #ty = ::clone(self).try_into().map_err(Error::unsupported_data)?; + writer.write_value(&other).await + } + } +} + +fn nix_serialize_display(span: Span) -> TokenStream { + quote_spanned! { + span => writer.write_display(self).await + } +} + +fn nix_serialize_display_path(path: &syn::ExprPath) -> TokenStream { + quote_spanned! { + path.span() => writer.write_display(#path(self)).await + } +} diff --git a/tvix/nix-compat/Cargo.toml b/tvix/nix-compat/Cargo.toml index f430a5461829..cbbf97175d14 100644 --- a/tvix/nix-compat/Cargo.toml +++ b/tvix/nix-compat/Cargo.toml @@ -43,6 +43,7 @@ futures = { workspace = true } hex-literal = { workspace = true } mimalloc = { workspace = true } pretty_assertions = { workspace = true } +proptest = { workspace = true, features = ["std", "alloc", "tempfile"] } rstest = { workspace = true } serde_json = { workspace = true } smol_str = { workspace = true } diff --git a/tvix/nix-compat/src/nix_daemon/mod.rs b/tvix/nix-compat/src/nix_daemon/mod.rs index 11413e85fd1b..a943b279f891 100644 --- a/tvix/nix-compat/src/nix_daemon/mod.rs +++ b/tvix/nix-compat/src/nix_daemon/mod.rs @@ -4,3 +4,4 @@ mod protocol_version; pub use protocol_version::ProtocolVersion; pub mod de; +pub mod ser; diff --git a/tvix/nix-compat/src/nix_daemon/ser/bytes.rs b/tvix/nix-compat/src/nix_daemon/ser/bytes.rs new file mode 100644 index 000000000000..19494934ff32 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/bytes.rs @@ -0,0 +1,89 @@ +use bytes::Bytes; + +use super::{NixSerialize, NixWrite}; + +impl NixSerialize for Bytes { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl<'a> NixSerialize for &'a [u8] { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self).await + } +} + +impl NixSerialize for String { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +impl NixSerialize for str { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_slice(self.as_bytes()).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_str(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[case::utf8("The quick brown 🦊 jumps over 13 lazy 🐶.", &hex!("2D00 0000 0000 0000 5468 6520 7175 6963 6b20 6272 6f77 6e20 f09f a68a 206a 756d 7073 206f 7665 7220 3133 206c 617a 7920 f09f 90b6 2e00 0000"))] + #[tokio::test] + async fn test_write_string(#[case] value: &str, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value.to_string()).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/collections.rs b/tvix/nix-compat/src/nix_daemon/ser/collections.rs new file mode 100644 index 000000000000..70c32e1c79ac --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/collections.rs @@ -0,0 +1,94 @@ +use std::collections::BTreeMap; +use std::future::Future; + +use super::{NixSerialize, NixWrite}; + +impl NixSerialize for Vec +where + T: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize(&self, writer: &mut W) -> impl Future> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for value in self.iter() { + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +impl NixSerialize for BTreeMap +where + K: NixSerialize + Ord + Send + Sync, + V: NixSerialize + Send + Sync, +{ + #[allow(clippy::manual_async_fn)] + fn serialize(&self, writer: &mut W) -> impl Future> + Send + where + W: NixWrite, + { + async move { + writer.write_value(&self.len()).await?; + for (key, value) in self.iter() { + writer.write_value(key).await?; + writer.write_value(value).await?; + } + Ok(()) + } + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + use std::fmt; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::{NixSerialize, NixWrite, NixWriter}; + + #[rstest] + #[case::empty(vec![], &hex!("0000 0000 0000 0000"))] + #[case::one(vec![0x29], &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(vec![0x7469, 10], &hex!("0200 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_vec(#[case] value: Vec, #[case] data: &[u8]) { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + fn empty_map() -> BTreeMap { + BTreeMap::new() + } + macro_rules! map { + ($($key:expr => $value:expr),*) => {{ + let mut ret = BTreeMap::new(); + $(ret.insert($key, $value);)* + ret + }}; + } + + #[rstest] + #[case::empty(empty_map(), &hex!("0000 0000 0000 0000"))] + #[case::one(map![0x7469usize => 10u64], &hex!("0100 0000 0000 0000 6974 0000 0000 0000 0A00 0000 0000 0000"))] + #[tokio::test] + async fn test_write_small_btree_map(#[case] value: E, #[case] data: &[u8]) + where + E: NixSerialize + Send + PartialEq + fmt::Debug, + { + let mock = Builder::new().write(data).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/display.rs b/tvix/nix-compat/src/nix_daemon/ser/display.rs new file mode 100644 index 000000000000..a3438d50d8ff --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/display.rs @@ -0,0 +1,8 @@ +use nix_compat_derive::nix_serialize_remote; + +use crate::nixhash; + +nix_serialize_remote!( + #[nix(display)] + nixhash::HashAlgo +); diff --git a/tvix/nix-compat/src/nix_daemon/ser/int.rs b/tvix/nix-compat/src/nix_daemon/ser/int.rs new file mode 100644 index 000000000000..1be06442e322 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/int.rs @@ -0,0 +1,108 @@ +#[cfg(feature = "nix-compat-derive")] +use nix_compat_derive::nix_serialize_remote; + +use super::{Error, NixSerialize, NixWrite}; + +impl NixSerialize for u64 { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self).await + } +} + +impl NixSerialize for usize { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + let v = (*self).try_into().map_err(W::Error::unsupported_data)?; + writer.write_number(v).await + } +} + +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u8 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u16 +); +#[cfg(feature = "nix-compat-derive")] +nix_serialize_remote!( + #[nix(into = "u64")] + u32 +); + +impl NixSerialize for bool { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + if *self { + writer.write_number(1).await + } else { + writer.write_number(0).await + } + } +} + +impl NixSerialize for i64 { + async fn serialize(&self, writer: &mut W) -> Result<(), W::Error> + where + W: NixWrite, + { + writer.write_number(*self as u64).await + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::{NixWrite, NixWriter}; + + #[rstest] + #[case::simple_false(false, &hex!("0000 0000 0000 0000"))] + #[case::simple_true(true, &hex!("0100 0000 0000 0000"))] + #[tokio::test] + async fn test_write_bool(#[case] value: bool, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_u64(#[case] value: u64, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } + + #[rstest] + #[case::zero(0, &hex!("0000 0000 0000 0000"))] + #[case::one(1, &hex!("0100 0000 0000 0000"))] + #[case::other(0x563412, &hex!("1234 5600 0000 0000"))] + #[case::max_value(usize::MAX, &usize::MAX.to_le_bytes())] + #[tokio::test] + async fn test_write_usize(#[case] value: usize, #[case] expected: &[u8]) { + let mock = Builder::new().write(expected).build(); + let mut writer = NixWriter::new(mock); + writer.write_value(&value).await.unwrap(); + writer.flush().await.unwrap(); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/mock.rs b/tvix/nix-compat/src/nix_daemon/ser/mock.rs new file mode 100644 index 000000000000..1319a8da3228 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/mock.rs @@ -0,0 +1,672 @@ +use std::collections::VecDeque; +use std::fmt; +use std::io; +use std::thread; + +#[cfg(test)] +use ::proptest::prelude::TestCaseError; +use thiserror::Error; + +use crate::nix_daemon::ProtocolVersion; + +use super::NixWrite; + +#[derive(Debug, Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("custom error '{0}'")] + Custom(String), + #[error("unsupported data error '{0}'")] + UnsupportedData(String), + #[error("Invalid enum: {0}")] + InvalidEnum(String), + #[error("IO error {0} '{1}'")] + IO(io::ErrorKind, String), + #[error("wrong write: expected {0} got {1}")] + WrongWrite(OperationType, OperationType), + #[error("unexpected write: got an extra {0}")] + ExtraWrite(OperationType), + #[error("got an unexpected number {0} in write_number")] + UnexpectedNumber(u64), + #[error("got an unexpected slice '{0:?}' in write_slice")] + UnexpectedSlice(Vec), + #[error("got an unexpected display '{0:?}' in write_slice")] + UnexpectedDisplay(String), +} + +impl Error { + pub fn unexpected_write_number(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteNumber) + } + + pub fn extra_write_number() -> Error { + Error::ExtraWrite(OperationType::WriteNumber) + } + + pub fn unexpected_write_slice(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteSlice) + } + + pub fn unexpected_write_display(expected: OperationType) -> Error { + Error::WrongWrite(expected, OperationType::WriteDisplay) + } +} + +impl super::Error for Error { + fn custom(msg: T) -> Self { + Self::Custom(msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + Self::IO(err.kind(), err.to_string()) + } + + fn unsupported_data(msg: T) -> Self { + Self::UnsupportedData(msg.to_string()) + } + + fn invalid_enum(msg: T) -> Self { + Self::InvalidEnum(msg.to_string()) + } +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OperationType { + WriteNumber, + WriteSlice, + WriteDisplay, +} + +impl fmt::Display for OperationType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::WriteNumber => write!(f, "write_number"), + Self::WriteSlice => write!(f, "write_slice"), + Self::WriteDisplay => write!(f, "write_display"), + } + } +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, PartialEq, Eq)] +enum Operation { + WriteNumber(u64, Result<(), Error>), + WriteSlice(Vec, Result<(), Error>), + WriteDisplay(String, Result<(), Error>), +} + +impl From for OperationType { + fn from(value: Operation) -> Self { + match value { + Operation::WriteNumber(_, _) => OperationType::WriteNumber, + Operation::WriteSlice(_, _) => OperationType::WriteSlice, + Operation::WriteDisplay(_, _) => OperationType::WriteDisplay, + } + } +} + +pub struct Builder { + version: ProtocolVersion, + ops: VecDeque, +} + +impl Builder { + pub fn new() -> Builder { + Builder { + version: Default::default(), + ops: VecDeque::new(), + } + } + + pub fn version>(&mut self, version: V) -> &mut Self { + self.version = version.into(); + self + } + + pub fn write_number(&mut self, value: u64) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Ok(()))); + self + } + + pub fn write_number_error(&mut self, value: u64, err: Error) -> &mut Self { + self.ops.push_back(Operation::WriteNumber(value, Err(err))); + self + } + + pub fn write_slice(&mut self, value: &[u8]) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Ok(()))); + self + } + + pub fn write_slice_error(&mut self, value: &[u8], err: Error) -> &mut Self { + self.ops + .push_back(Operation::WriteSlice(value.to_vec(), Err(err))); + self + } + + pub fn write_display(&mut self, value: D) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Ok(()))); + self + } + + pub fn write_display_error(&mut self, value: D, err: Error) -> &mut Self + where + D: fmt::Display, + { + let msg = value.to_string(); + self.ops.push_back(Operation::WriteDisplay(msg, Err(err))); + self + } + + #[cfg(test)] + fn write_operation_type(&mut self, op: OperationType) -> &mut Self { + match op { + OperationType::WriteNumber => self.write_number(10), + OperationType::WriteSlice => self.write_slice(b"testing"), + OperationType::WriteDisplay => self.write_display("testing"), + } + } + + #[cfg(test)] + fn write_operation(&mut self, op: &Operation) -> &mut Self { + match op { + Operation::WriteNumber(value, Ok(_)) => self.write_number(*value), + Operation::WriteNumber(value, Err(Error::UnexpectedNumber(_))) => { + self.write_number(*value) + } + Operation::WriteNumber(_, Err(Error::ExtraWrite(OperationType::WriteNumber))) => self, + Operation::WriteNumber(_, Err(Error::WrongWrite(op, OperationType::WriteNumber))) => { + self.write_operation_type(*op) + } + Operation::WriteNumber(value, Err(Error::Custom(msg))) => { + self.write_number_error(*value, Error::Custom(msg.clone())) + } + Operation::WriteNumber(value, Err(Error::IO(kind, msg))) => { + self.write_number_error(*value, Error::IO(*kind, msg.clone())) + } + Operation::WriteSlice(value, Ok(_)) => self.write_slice(value), + Operation::WriteSlice(value, Err(Error::UnexpectedSlice(_))) => self.write_slice(value), + Operation::WriteSlice(_, Err(Error::ExtraWrite(OperationType::WriteSlice))) => self, + Operation::WriteSlice(_, Err(Error::WrongWrite(op, OperationType::WriteSlice))) => { + self.write_operation_type(*op) + } + Operation::WriteSlice(value, Err(Error::Custom(msg))) => { + self.write_slice_error(value, Error::Custom(msg.clone())) + } + Operation::WriteSlice(value, Err(Error::IO(kind, msg))) => { + self.write_slice_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Ok(_)) => self.write_display(value), + Operation::WriteDisplay(value, Err(Error::Custom(msg))) => { + self.write_display_error(value, Error::Custom(msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::IO(kind, msg))) => { + self.write_display_error(value, Error::IO(*kind, msg.clone())) + } + Operation::WriteDisplay(value, Err(Error::UnexpectedDisplay(_))) => { + self.write_display(value) + } + Operation::WriteDisplay(_, Err(Error::ExtraWrite(OperationType::WriteDisplay))) => self, + Operation::WriteDisplay(_, Err(Error::WrongWrite(op, OperationType::WriteDisplay))) => { + self.write_operation_type(*op) + } + s => panic!("Invalid operation {:?}", s), + } + } + + pub fn build(&mut self) -> Mock { + Mock { + version: self.version, + ops: self.ops.clone(), + } + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +pub struct Mock { + version: ProtocolVersion, + ops: VecDeque, +} + +impl Mock { + #[cfg(test)] + #[allow(dead_code)] + async fn assert_operation(&mut self, op: Operation) { + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + assert_eq!(self.write_display(value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + assert_eq!(self.write_display(value).await, res); + } + } + } + + #[cfg(test)] + async fn prop_assert_operation(&mut self, op: Operation) -> Result<(), TestCaseError> { + use ::proptest::prop_assert_eq; + + match op { + Operation::WriteNumber(_, Err(Error::UnexpectedNumber(value))) => { + prop_assert_eq!( + self.write_number(value).await, + Err(Error::UnexpectedNumber(value)) + ); + } + Operation::WriteNumber(value, res) => { + prop_assert_eq!(self.write_number(value).await, res); + } + Operation::WriteSlice(_, ref res @ Err(Error::UnexpectedSlice(ref value))) => { + prop_assert_eq!(self.write_slice(value).await, res.clone()); + } + Operation::WriteSlice(value, res) => { + prop_assert_eq!(self.write_slice(&value).await, res); + } + Operation::WriteDisplay(_, ref res @ Err(Error::UnexpectedDisplay(ref value))) => { + prop_assert_eq!(self.write_display(&value).await, res.clone()); + } + Operation::WriteDisplay(value, res) => { + prop_assert_eq!(self.write_display(&value).await, res); + } + } + Ok(()) + } +} + +impl NixWrite for Mock { + type Error = Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteNumber(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedNumber(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_number(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteNumber)), + } + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + match self.ops.pop_front() { + Some(Operation::WriteSlice(expected, ret)) => { + if buf != expected { + return Err(Error::UnexpectedSlice(buf.to_vec())); + } + ret + } + Some(op) => Err(Error::unexpected_write_slice(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteSlice)), + } + } + + async fn write_display(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + let value = msg.to_string(); + match self.ops.pop_front() { + Some(Operation::WriteDisplay(expected, ret)) => { + if value != expected { + return Err(Error::UnexpectedDisplay(value)); + } + ret + } + Some(op) => Err(Error::unexpected_write_display(op.into())), + _ => Err(Error::ExtraWrite(OperationType::WriteDisplay)), + } + } +} + +impl Drop for Mock { + fn drop(&mut self) { + // No need to panic again + if thread::panicking() { + return; + } + if let Some(op) = self.ops.front() { + panic!("reader dropped with {op:?} operation still unread") + } + } +} + +#[cfg(test)] +mod proptest { + use std::io; + + use proptest::{ + prelude::{any, Arbitrary, BoxedStrategy, Just, Strategy}, + prop_oneof, + }; + + use super::{Error, Operation, OperationType}; + + pub fn arb_write_number_operation() -> impl Strategy { + ( + any::(), + prop_oneof![ + Just(Ok(())), + any::().prop_map(|v| Err(Error::UnexpectedNumber(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteNumber + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteNumber + ))), + any::().prop_map(|s| Err(Error::Custom(s))), + (any::(), any::()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same number", |(v, res)| match res { + Err(Error::UnexpectedNumber(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteNumber(v, res)) + } + + pub fn arb_write_slice_operation() -> impl Strategy { + ( + any::>(), + prop_oneof![ + Just(Ok(())), + any::>().prop_map(|v| Err(Error::UnexpectedSlice(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteSlice + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteDisplay, + OperationType::WriteSlice + ))), + any::().prop_map(|s| Err(Error::Custom(s))), + (any::(), any::()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same slice", |(v, res)| match res { + Err(Error::UnexpectedSlice(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteSlice(v, res)) + } + + #[allow(dead_code)] + pub fn arb_extra_write() -> impl Strategy { + prop_oneof![ + any::().prop_map(|msg| { + Operation::WriteNumber(msg, Err(Error::ExtraWrite(OperationType::WriteNumber))) + }), + any::>().prop_map(|msg| { + Operation::WriteSlice(msg, Err(Error::ExtraWrite(OperationType::WriteSlice))) + }), + any::().prop_map(|msg| { + Operation::WriteDisplay(msg, Err(Error::ExtraWrite(OperationType::WriteDisplay))) + }), + ] + } + + pub fn arb_write_display_operation() -> impl Strategy { + ( + any::(), + prop_oneof![ + Just(Ok(())), + any::().prop_map(|v| Err(Error::UnexpectedDisplay(v))), + Just(Err(Error::WrongWrite( + OperationType::WriteNumber, + OperationType::WriteDisplay + ))), + Just(Err(Error::WrongWrite( + OperationType::WriteSlice, + OperationType::WriteDisplay + ))), + any::().prop_map(|s| Err(Error::Custom(s))), + (any::(), any::()) + .prop_map(|(kind, msg)| Err(Error::IO(kind, msg))), + ], + ) + .prop_filter("same string", |(v, res)| match res { + Err(Error::UnexpectedDisplay(exp_v)) => v != exp_v, + _ => true, + }) + .prop_map(|(v, res)| Operation::WriteDisplay(v, res)) + } + + pub fn arb_operation() -> impl Strategy { + prop_oneof![ + arb_write_number_operation(), + arb_write_slice_operation(), + arb_write_display_operation(), + ] + } + + impl Arbitrary for Operation { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + arb_operation().boxed() + } + } +} + +#[cfg(test)] +mod test { + use hex_literal::hex; + use proptest::prelude::any; + use proptest::prelude::TestCaseError; + use proptest::proptest; + + use crate::nix_daemon::ser::mock::proptest::arb_extra_write; + use crate::nix_daemon::ser::mock::Operation; + use crate::nix_daemon::ser::mock::OperationType; + use crate::nix_daemon::ser::Error as _; + use crate::nix_daemon::ser::NixWrite; + + use super::{Builder, Error}; + + #[tokio::test] + async fn write_number() { + let mut mock = Builder::new().write_number(10).build(); + mock.write_number(10).await.unwrap(); + } + + #[tokio::test] + async fn write_number_error() { + let mut mock = Builder::new() + .write_number_error(10, Error::custom("bad number")) + .build(); + assert_eq!( + Err(Error::custom("bad number")), + mock.write_number(10).await + ); + } + + #[tokio::test] + async fn write_number_unexpected() { + let mut mock = Builder::new().write_slice(b"").build(); + assert_eq!( + Err(Error::unexpected_write_number(OperationType::WriteSlice)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn write_number_unexpected_number() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::UnexpectedNumber(11)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn extra_write_number() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteNumber)), + mock.write_number(11).await + ); + } + + #[tokio::test] + async fn write_slice() { + let mut mock = Builder::new() + .write_slice(&[]) + .write_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .build(); + mock.write_slice(&[]).await.expect("write_slice empty"); + mock.write_slice(&hex!("0000 1234 5678 9ABC DEFF")) + .await + .expect("write_slice"); + } + + #[tokio::test] + async fn write_slice_error() { + let mut mock = Builder::new() + .write_slice_error(&[], Error::custom("bad slice")) + .build(); + assert_eq!(Err(Error::custom("bad slice")), mock.write_slice(&[]).await); + } + + #[tokio::test] + async fn write_slice_unexpected() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::unexpected_write_slice(OperationType::WriteNumber)), + mock.write_slice(b"").await + ); + } + + #[tokio::test] + async fn write_slice_unexpected_slice() { + let mut mock = Builder::new().write_slice(b"").build(); + assert_eq!( + Err(Error::UnexpectedSlice(b"bad slice".to_vec())), + mock.write_slice(b"bad slice").await + ); + } + + #[tokio::test] + async fn extra_write_slice() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteSlice)), + mock.write_slice(b"extra slice").await + ); + } + + #[tokio::test] + async fn write_display() { + let mut mock = Builder::new().write_display("testing").build(); + mock.write_display("testing").await.unwrap(); + } + + #[tokio::test] + async fn write_display_error() { + let mut mock = Builder::new() + .write_display_error("testing", Error::custom("bad number")) + .build(); + assert_eq!( + Err(Error::custom("bad number")), + mock.write_display("testing").await + ); + } + + #[tokio::test] + async fn write_display_unexpected() { + let mut mock = Builder::new().write_number(10).build(); + assert_eq!( + Err(Error::unexpected_write_display(OperationType::WriteNumber)), + mock.write_display("").await + ); + } + + #[tokio::test] + async fn write_display_unexpected_display() { + let mut mock = Builder::new().write_display("").build(); + assert_eq!( + Err(Error::UnexpectedDisplay("bad display".to_string())), + mock.write_display("bad display").await + ); + } + + #[tokio::test] + async fn extra_write_display() { + let mut mock = Builder::new().build(); + assert_eq!( + Err(Error::ExtraWrite(OperationType::WriteDisplay)), + mock.write_display("extra slice").await + ); + } + + #[test] + #[should_panic] + fn operations_left() { + let _ = Builder::new().write_number(10).build(); + } + + #[test] + fn proptest_mock() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + proptest!(|( + operations in any::>(), + extra_operations in proptest::collection::vec(arb_extra_write(), 0..3) + )| { + rt.block_on(async { + let mut builder = Builder::new(); + for op in operations.iter() { + builder.write_operation(op); + } + for op in extra_operations.iter() { + builder.write_operation(op); + } + let mut mock = builder.build(); + for op in operations { + mock.prop_assert_operation(op).await?; + } + for op in extra_operations { + mock.prop_assert_operation(op).await?; + } + Ok(()) as Result<(), TestCaseError> + })?; + }); + } +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/mod.rs b/tvix/nix-compat/src/nix_daemon/ser/mod.rs new file mode 100644 index 000000000000..5860226f39eb --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/mod.rs @@ -0,0 +1,124 @@ +use std::error::Error as StdError; +use std::future::Future; +use std::{fmt, io}; + +use super::ProtocolVersion; + +mod bytes; +mod collections; +#[cfg(feature = "nix-compat-derive")] +mod display; +mod int; +#[cfg(any(test, feature = "test"))] +pub mod mock; +mod writer; + +pub use writer::{NixWriter, NixWriterBuilder}; + +pub trait Error: Sized + StdError { + fn custom(msg: T) -> Self; + + fn io_error(err: std::io::Error) -> Self { + Self::custom(format_args!("There was an I/O error {}", err)) + } + + fn unsupported_data(msg: T) -> Self { + Self::custom(msg) + } + + fn invalid_enum(msg: T) -> Self { + Self::custom(msg) + } +} + +impl Error for io::Error { + fn custom(msg: T) -> Self { + io::Error::new(io::ErrorKind::Other, msg.to_string()) + } + + fn io_error(err: std::io::Error) -> Self { + err + } + + fn unsupported_data(msg: T) -> Self { + io::Error::new(io::ErrorKind::InvalidData, msg.to_string()) + } +} + +pub trait NixWrite: Send { + type Error: Error; + + /// Some types are serialized differently depending on the version + /// of the protocol and so this can be used for implementing that. + fn version(&self) -> ProtocolVersion; + + /// Write a single u64 to the protocol. + fn write_number(&mut self, value: u64) -> impl Future> + Send; + + /// Write a slice of bytes to the protocol. + fn write_slice(&mut self, buf: &[u8]) -> impl Future> + Send; + + /// Write a value that implements `std::fmt::Display` to the protocol. + /// The protocol uses many small string formats and instead of allocating + /// a `String` each time we want to write one an implementation of `NixWrite` + /// can instead use `Display` to dump these formats to a reusable buffer. + fn write_display(&mut self, msg: D) -> impl Future> + Send + where + D: fmt::Display + Send, + Self: Sized, + { + async move { + let s = msg.to_string(); + self.write_slice(s.as_bytes()).await + } + } + + /// Write a value to the protocol. + /// Uses `NixSerialize::serialize` to write the value. + fn write_value(&mut self, value: &V) -> impl Future> + Send + where + V: NixSerialize + Send + ?Sized, + Self: Sized, + { + value.serialize(self) + } +} + +impl NixWrite for &mut T { + type Error = T::Error; + + fn version(&self) -> ProtocolVersion { + (**self).version() + } + + fn write_number(&mut self, value: u64) -> impl Future> + Send { + (**self).write_number(value) + } + + fn write_slice(&mut self, buf: &[u8]) -> impl Future> + Send { + (**self).write_slice(buf) + } + + fn write_display(&mut self, msg: D) -> impl Future> + Send + where + D: fmt::Display + Send, + Self: Sized, + { + (**self).write_display(msg) + } + + fn write_value(&mut self, value: &V) -> impl Future> + Send + where + V: NixSerialize + Send + ?Sized, + Self: Sized, + { + (**self).write_value(value) + } +} + +pub trait NixSerialize { + /// Write a value to the writer. + fn serialize(&self, writer: &mut W) -> impl Future> + Send + where + W: NixWrite; +} diff --git a/tvix/nix-compat/src/nix_daemon/ser/writer.rs b/tvix/nix-compat/src/nix_daemon/ser/writer.rs new file mode 100644 index 000000000000..87e30580af34 --- /dev/null +++ b/tvix/nix-compat/src/nix_daemon/ser/writer.rs @@ -0,0 +1,308 @@ +use std::fmt::{self, Write as _}; +use std::future::poll_fn; +use std::io::{self, Cursor}; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +use bytes::{Buf, BufMut, BytesMut}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use crate::nix_daemon::ProtocolVersion; +use crate::wire::padding_len; +use crate::wire::EMPTY_BYTES; + +use super::{Error, NixWrite}; + +pub struct NixWriterBuilder { + buf: Option, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, +} + +impl Default for NixWriterBuilder { + fn default() -> Self { + Self { + buf: Default::default(), + reserved_buf_size: 8192, + max_buf_size: 8192, + version: Default::default(), + } + } +} + +impl NixWriterBuilder { + pub fn set_buffer(mut self, buf: BytesMut) -> Self { + self.buf = Some(buf); + self + } + + pub fn set_reserved_buf_size(mut self, size: usize) -> Self { + self.reserved_buf_size = size; + self + } + + pub fn set_max_buf_size(mut self, size: usize) -> Self { + self.max_buf_size = size; + self + } + + pub fn set_version(mut self, version: ProtocolVersion) -> Self { + self.version = version; + self + } + + pub fn build(self, writer: W) -> NixWriter { + let buf = self + .buf + .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size)); + NixWriter { + buf, + inner: writer, + reserved_buf_size: self.reserved_buf_size, + max_buf_size: self.max_buf_size, + version: self.version, + } + } +} + +pin_project! { + pub struct NixWriter { + #[pin] + inner: W, + buf: BytesMut, + reserved_buf_size: usize, + max_buf_size: usize, + version: ProtocolVersion, + } +} + +impl NixWriter>> { + pub fn builder() -> NixWriterBuilder { + NixWriterBuilder::default() + } +} + +impl NixWriter +where + W: AsyncWriteExt, +{ + pub fn new(writer: W) -> NixWriter { + NixWriter::builder().build(writer) + } + + pub fn buffer(&self) -> &[u8] { + &self.buf[..] + } + + pub fn set_version(&mut self, version: ProtocolVersion) { + self.version = version; + } + + /// Remaining capacity in internal buffer + pub fn remaining_mut(&self) -> usize { + self.buf.capacity() - self.buf.len() + } + + fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + while !this.buf.is_empty() { + let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?; + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write the buffer", + ))); + } + this.buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl NixWriter +where + W: AsyncWriteExt + Unpin, +{ + async fn flush_buf(&mut self) -> Result<(), io::Error> { + let mut s = Pin::new(self); + poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await + } +} + +impl AsyncWrite for NixWriter +where + W: AsyncWrite, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // Flush + if self.remaining_mut() < buf.len() { + ready!(self.as_mut().poll_flush_buf(cx))?; + } + let this = self.project(); + if buf.len() > this.buf.capacity() { + this.inner.poll_write(cx, buf) + } else { + this.buf.put_slice(buf); + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush_buf(cx))?; + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + ready!(self.as_mut().poll_flush_buf(cx))?; + self.project().inner.poll_shutdown(cx) + } +} + +impl NixWrite for NixWriter +where + W: AsyncWrite + Send + Unpin, +{ + type Error = io::Error; + + fn version(&self) -> ProtocolVersion { + self.version + } + + async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> { + let mut buf = [0u8; 8]; + BufMut::put_u64_le(&mut &mut buf[..], value); + self.write_all(&buf).await + } + + async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + let padding = padding_len(buf.len() as u64) as usize; + self.write_value(&buf.len()).await?; + self.write_all(buf).await?; + if padding > 0 { + self.write_all(&EMPTY_BYTES[..padding]).await + } else { + Ok(()) + } + } + + async fn write_display(&mut self, msg: D) -> Result<(), Self::Error> + where + D: fmt::Display + Send, + Self: Sized, + { + // Ensure that buffer has space for at least reserved_buf_size bytes + if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() { + self.flush_buf().await?; + } + let offset = self.buf.len(); + self.buf.put_u64_le(0); + if let Err(err) = write!(self.buf, "{}", msg) { + self.buf.truncate(offset); + return Err(Self::Error::unsupported_data(err)); + } + let len = self.buf.len() - offset - 8; + BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64); + let padding = padding_len(len as u64) as usize; + self.write_all(&EMPTY_BYTES[..padding]).await + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use hex_literal::hex; + use rstest::rstest; + use tokio::io::AsyncWriteExt as _; + use tokio_test::io::Builder; + + use crate::nix_daemon::ser::NixWrite; + + use super::NixWriter; + + #[rstest] + #[case(1, &hex!("0100 0000 0000 0000"))] + #[case::evil(666, &hex!("9A02 0000 0000 0000"))] + #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))] + #[tokio::test] + async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) { + let mock = Builder::new().write(buf).build(); + let mut writer = NixWriter::new(mock); + + writer.write_number(number).await.unwrap(); + assert_eq!(writer.buffer(), buf); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } + + #[rstest] + #[case::empty(b"", &hex!("0000 0000 0000 0000"))] + #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_write_slice( + #[case] value: &[u8], + #[case] buf: &[u8], + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize, + ) { + let mut builder = Builder::new(); + for chunk in buf.chunks(chunks_size) { + builder.write(chunk); + builder.wait(Duration::ZERO); + } + let mock = builder.build(); + let mut writer = NixWriter::builder().set_max_buf_size(buf_size).build(mock); + + writer.write_slice(value).await.unwrap(); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } + + #[rstest] + #[case::empty("", &hex!("0000 0000 0000 0000"))] + #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))] + #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))] + #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))] + #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))] + #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))] + #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))] + #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))] + #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))] + #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))] + #[tokio::test] + async fn test_write_display( + #[case] value: &str, + #[case] buf: &[u8], + #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize, + ) { + let mut builder = Builder::new(); + for chunk in buf.chunks(chunks_size) { + builder.write(chunk); + builder.wait(Duration::ZERO); + } + let mock = builder.build(); + let mut writer = NixWriter::builder().build(mock); + + writer.write_display(value).await.unwrap(); + assert_eq!(writer.buffer(), buf); + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), b""); + } +} diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index 74adfb49b6a4..9b981fbbd2c0 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -181,7 +181,7 @@ pub async fn write_bytes>( /// Computes the number of bytes we should add to len (a length in /// bytes) to be aligned on 64 bits (8 bytes). -fn padding_len(len: u64) -> u8 { +pub(crate) fn padding_len(len: u64) -> u8 { let aligned = len.wrapping_add(7) & !7; aligned.wrapping_sub(len) as u8 } -- cgit 1.4.1