diff options
Diffstat (limited to 'tvix/serde/src')
-rw-r--r-- | tvix/serde/src/de.rs | 463 | ||||
-rw-r--r-- | tvix/serde/src/de_tests.rs | 244 | ||||
-rw-r--r-- | tvix/serde/src/error.rs | 102 | ||||
-rw-r--r-- | tvix/serde/src/lib.rs | 12 |
4 files changed, 821 insertions, 0 deletions
diff --git a/tvix/serde/src/de.rs b/tvix/serde/src/de.rs new file mode 100644 index 000000000000..2e8a9618e637 --- /dev/null +++ b/tvix/serde/src/de.rs @@ -0,0 +1,463 @@ +//! Deserialisation from Nix to Rust values. + +use serde::de::value::{MapDeserializer, SeqDeserializer}; +use serde::de::{self, EnumAccess, VariantAccess}; +pub use tvix_eval::Evaluation; +use tvix_eval::Value; + +use crate::error::Error; + +struct NixDeserializer { + value: tvix_eval::Value, +} + +impl NixDeserializer { + fn new(value: Value) -> Self { + if let Value::Thunk(thunk) = value { + Self::new(thunk.value().clone()) + } else { + Self { value } + } + } +} + +impl de::IntoDeserializer<'_, Error> for NixDeserializer { + type Deserializer = Self; + + fn into_deserializer(self) -> Self::Deserializer { + self + } +} + +/// Evaluate the Nix code in `src` and attempt to deserialise the +/// value it returns to `T`. +pub fn from_str<'code, T>(src: &'code str) -> Result<T, Error> +where + T: serde::Deserialize<'code>, +{ + from_str_with_config(src, |_| /* no extra config */ ()) +} + +/// Evaluate the Nix code in `src`, with extra configuration for the +/// `tvix_eval::Evaluation` provided by the given closure. +pub fn from_str_with_config<'code, T, F>(src: &'code str, config: F) -> Result<T, Error> +where + T: serde::Deserialize<'code>, + F: FnOnce(&mut Evaluation), +{ + // First step is to evaluate the Nix code ... + let mut eval = Evaluation::new(src, None); + config(&mut eval); + + eval.strict = true; + let source = eval.source_map(); + let result = eval.evaluate(); + + if !result.errors.is_empty() { + return Err(Error::NixErrors { + errors: result.errors, + source, + }); + } + + let de = NixDeserializer::new(result.value.expect("value should be present on success")); + + T::deserialize(de) +} + +fn unexpected(expected: &'static str, got: &Value) -> Error { + Error::UnexpectedType { + expected, + got: got.type_of(), + } +} + +fn visit_integer<I: TryFrom<i64>>(v: &Value) -> Result<I, Error> { + match v { + Value::Integer(i) => I::try_from(*i).map_err(|_| Error::IntegerConversion { + got: *i, + need: std::any::type_name::<I>(), + }), + + _ => Err(unexpected("integer", v)), + } +} + +impl<'de> de::Deserializer<'de> for NixDeserializer { + type Error = Error; + + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + match self.value { + Value::Null => visitor.visit_unit(), + Value::Bool(b) => visitor.visit_bool(b), + Value::Integer(i) => visitor.visit_i64(i), + Value::Float(f) => visitor.visit_f64(f), + Value::String(s) => visitor.visit_string(s.to_string()), + Value::Path(p) => visitor.visit_string(p.to_string_lossy().into()), // TODO: hmm + Value::Attrs(_) => self.deserialize_map(visitor), + Value::List(_) => self.deserialize_seq(visitor), + + // tvix-eval types that can not be deserialized through serde. + Value::Closure(_) + | Value::Builtin(_) + | Value::Thunk(_) + | Value::AttrNotFound + | Value::Blueprint(_) + | Value::DeferredUpvalue(_) + | Value::UnresolvedPath(_) + | Value::Json(_) + | Value::Catchable(_) + | Value::FinaliseRequest(_) => Err(Error::Unserializable { + value_type: self.value.type_of(), + }), + } + } + + fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + match self.value { + Value::Bool(b) => visitor.visit_bool(b), + _ => Err(unexpected("bool", &self.value)), + } + } + + fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_i8(visit_integer(&self.value)?) + } + + fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_i16(visit_integer(&self.value)?) + } + + fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_i32(visit_integer(&self.value)?) + } + + fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_i64(visit_integer(&self.value)?) + } + + fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_u8(visit_integer(&self.value)?) + } + + fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_u16(visit_integer(&self.value)?) + } + + fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_u32(visit_integer(&self.value)?) + } + + fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_u64(visit_integer(&self.value)?) + } + + fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::Float(f) = self.value { + return visitor.visit_f32(f as f32); + } + + Err(unexpected("float", &self.value)) + } + + fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::Float(f) = self.value { + return visitor.visit_f64(f); + } + + Err(unexpected("float", &self.value)) + } + + fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::String(s) = &self.value { + let chars = s.as_str().chars().collect::<Vec<_>>(); + if chars.len() == 1 { + return visitor.visit_char(chars[0]); + } + } + + Err(unexpected("char", &self.value)) + } + + fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::String(s) = &self.value { + return visitor.visit_str(s.as_str()); + } + + Err(unexpected("string", &self.value)) + } + + fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::String(s) = &self.value { + return visitor.visit_str(s.as_str()); + } + + Err(unexpected("string", &self.value)) + } + + fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + unimplemented!() + } + + fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + unimplemented!() + } + + // Note that this can not distinguish between a serialisation of + // `Some(())` and `None`. + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::Null = self.value { + visitor.visit_none() + } else { + visitor.visit_some(self) + } + } + + fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::Null = self.value { + return visitor.visit_unit(); + } + + Err(unexpected("null", &self.value)) + } + + fn deserialize_unit_struct<V>( + self, + _name: &'static str, + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct<V>( + self, + _name: &'static str, + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::List(list) = self.value { + let mut seq = SeqDeserializer::new(list.into_iter().map(NixDeserializer::new)); + let result = visitor.visit_seq(&mut seq)?; + seq.end()?; + return Ok(result); + } + + Err(unexpected("list", &self.value)) + } + + fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + // just represent tuples as lists ... + self.deserialize_seq(visitor) + } + + fn deserialize_tuple_struct<V>( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + // same as above + self.deserialize_seq(visitor) + } + + fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + if let Value::Attrs(attrs) = self.value { + let mut map = MapDeserializer::new(attrs.into_iter().map(|(k, v)| { + ( + NixDeserializer::new(Value::String(k)), + NixDeserializer::new(v), + ) + })); + let result = visitor.visit_map(&mut map)?; + map.end()?; + return Ok(result); + } + + Err(unexpected("map", &self.value)) + } + + fn deserialize_struct<V>( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + self.deserialize_map(visitor) + } + + // This method is responsible for deserializing the externally + // tagged enum variant serialisation. + fn deserialize_enum<V>( + self, + name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + match self.value { + // a string represents a unit variant + Value::String(s) => visitor.visit_enum(de::value::StrDeserializer::new(s.as_str())), + + // an attribute set however represents an externally + // tagged enum with content + Value::Attrs(attrs) => visitor.visit_enum(Enum(*attrs)), + + _ => Err(unexpected(name, &self.value)), + } + } + + fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + self.deserialize_str(visitor) + } + + fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + visitor.visit_unit() + } +} + +struct Enum(tvix_eval::NixAttrs); + +impl<'de> EnumAccess<'de> for Enum { + type Error = Error; + type Variant = NixDeserializer; + + // TODO: pass the known variants down here and check against them + fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + if self.0.len() != 1 { + return Err(Error::AmbiguousEnum); + } + + let (key, value) = self.0.into_iter().next().expect("length asserted above"); + let val = seed.deserialize(de::value::StrDeserializer::<Error>::new(key.as_str()))?; + + Ok((val, NixDeserializer::new(value))) + } +} + +impl<'de> VariantAccess<'de> for NixDeserializer { + type Error = Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + // If this case is hit, a user specified the name of a unit + // enum variant but gave it content. Unit enum deserialisation + // is handled in `deserialize_enum` above. + Err(Error::UnitEnumContent) + } + + fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + seed.deserialize(self) + } + + fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + de::Deserializer::deserialize_seq(self, visitor) + } + + fn struct_variant<V>( + self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + de::Deserializer::deserialize_map(self, visitor) + } +} diff --git a/tvix/serde/src/de_tests.rs b/tvix/serde/src/de_tests.rs new file mode 100644 index 000000000000..807d953c77fa --- /dev/null +++ b/tvix/serde/src/de_tests.rs @@ -0,0 +1,244 @@ +use serde::Deserialize; +use std::collections::HashMap; +use tvix_eval::builtin_macros::builtins; + +use crate::de::{from_str, from_str_with_config}; + +#[test] +fn deserialize_none() { + let result: Option<usize> = from_str("null").expect("should deserialize"); + assert_eq!(None, result); +} + +#[test] +fn deserialize_some() { + let result: Option<usize> = from_str("40 + 2").expect("should deserialize"); + assert_eq!(Some(42), result); +} + +#[test] +fn deserialize_string() { + let result: String = from_str( + r#" + let greeter = name: "Hello ${name}!"; + in greeter "Slartibartfast" + "#, + ) + .expect("should deserialize"); + + assert_eq!(result, "Hello Slartibartfast!"); +} + +#[test] +fn deserialize_empty_list() { + let result: Vec<usize> = from_str("[ ]").expect("should deserialize"); + assert!(result.is_empty()) +} + +#[test] +fn deserialize_integer_list() { + let result: Vec<usize> = + from_str("builtins.map (n: n + 2) [ 21 40 67 ]").expect("should deserialize"); + assert_eq!(result, vec![23, 42, 69]); +} + +#[test] +fn deserialize_empty_map() { + let result: HashMap<String, usize> = from_str("{ }").expect("should deserialize"); + assert!(result.is_empty()); +} + +#[test] +fn deserialize_integer_map() { + let result: HashMap<String, usize> = from_str("{ age = 40 + 2; }").expect("should deserialize"); + assert_eq!(result.len(), 1); + assert_eq!(*result.get("age").unwrap(), 42); +} + +#[test] +fn deserialize_struct() { + #[derive(Debug, Deserialize, PartialEq)] + struct Person { + name: String, + age: usize, + } + + let result: Person = from_str( + r#" + { + name = "Slartibartfast"; + age = 42; + } + "#, + ) + .expect("should deserialize"); + + assert_eq!( + result, + Person { + name: "Slartibartfast".into(), + age: 42, + } + ); +} + +#[test] +fn deserialize_newtype() { + #[derive(Debug, Deserialize, PartialEq)] + struct Number(usize); + + let result: Number = from_str("42").expect("should deserialize"); + assert_eq!(result, Number(42)); +} + +#[test] +fn deserialize_tuple() { + let result: (String, usize) = from_str(r#" [ "foo" 42 ] "#).expect("should deserialize"); + assert_eq!(result, ("foo".into(), 42)); +} + +#[test] +fn deserialize_unit_enum() { + #[derive(Debug, Deserialize, PartialEq)] + enum Foo { + Bar, + Baz, + } + + let result: Foo = from_str("\"Baz\"").expect("should deserialize"); + assert_eq!(result, Foo::Baz); +} + +#[test] +fn deserialize_tuple_enum() { + #[derive(Debug, Deserialize, PartialEq)] + enum Foo { + Bar, + Baz(String, usize), + } + + let result: Foo = from_str( + r#" + { + Baz = [ "Slartibartfast" 42 ]; + } + "#, + ) + .expect("should deserialize"); + + assert_eq!(result, Foo::Baz("Slartibartfast".into(), 42)); +} + +#[test] +fn deserialize_struct_enum() { + #[derive(Debug, Deserialize, PartialEq)] + enum Foo { + Bar, + Baz { name: String, age: usize }, + } + + let result: Foo = from_str( + r#" + { + Baz.name = "Slartibartfast"; + Baz.age = 42; + } + "#, + ) + .expect("should deserialize"); + + assert_eq!( + result, + Foo::Baz { + name: "Slartibartfast".into(), + age: 42 + } + ); +} + +#[test] +fn deserialize_enum_all() { + #[derive(Debug, Deserialize, PartialEq)] + #[serde(rename_all = "snake_case")] + enum TestEnum { + Unit, + Tuple(String, String), + Struct { name: String, age: usize }, + } + + let result: Vec<TestEnum> = from_str( + r#" + let + mkTuple = country: drink: { tuple = [ country drink ]; }; + in + [ + (mkTuple "UK" "cask ale") + + "unit" + + { + struct.name = "Slartibartfast"; + struct.age = 42; + } + + (mkTuple "Russia" "квас") + ] + "#, + ) + .expect("should deserialize"); + + let expected = vec![ + TestEnum::Tuple("UK".into(), "cask ale".into()), + TestEnum::Unit, + TestEnum::Struct { + name: "Slartibartfast".into(), + age: 42, + }, + TestEnum::Tuple("Russia".into(), "квас".into()), + ]; + + assert_eq!(result, expected); +} + +#[test] +fn deserialize_with_config() { + let result: String = from_str_with_config("builtins.testWithConfig", |eval| { + // Add a literal string builtin that just returns `"ok"`. + eval.src_builtins.push(("testWithConfig", "\"ok\"")); + }) + .expect("should deserialize"); + + assert_eq!(result, "ok"); +} + +#[builtins] +mod test_builtins { + use tvix_eval::generators::{Gen, GenCo}; + use tvix_eval::{ErrorKind, NixString, Value}; + + #[builtin("prependHello")] + pub async fn builtin_prepend_hello(co: GenCo, x: Value) -> Result<Value, ErrorKind> { + match x { + Value::String(s) => { + let new_string = NixString::from(format!("hello {}", s.as_str())); + Ok(Value::String(new_string)) + } + _ => Err(ErrorKind::TypeError { + expected: "string", + actual: "not string", + }), + } + } +} + +#[test] +fn deserialize_with_extra_builtin() { + let code = "builtins.prependHello \"world\""; + + let result: String = from_str_with_config(code, |eval| { + eval.builtins.append(&mut test_builtins::builtins()); + }) + .expect("should deserialize"); + + assert_eq!(result, "hello world"); +} diff --git a/tvix/serde/src/error.rs b/tvix/serde/src/error.rs new file mode 100644 index 000000000000..c1d2258bbfe4 --- /dev/null +++ b/tvix/serde/src/error.rs @@ -0,0 +1,102 @@ +//! When serialising Nix goes wrong ... + +use std::error; +use std::fmt::Display; + +#[derive(Clone, Debug)] +pub enum Error { + /// Attempted to deserialise an unsupported Nix value (such as a + /// function) that can not be represented by the + /// [`serde::Deserialize`] trait. + Unserializable { value_type: &'static str }, + + /// Expected to deserialize a value that is unsupported by Nix. + Unsupported { wanted: &'static str }, + + /// Expected a specific type, but got something else on the Nix side. + UnexpectedType { + expected: &'static str, + got: &'static str, + }, + + /// Deserialisation error returned from `serde::de`. + Deserialization(String), + + /// Deserialized integer did not fit. + IntegerConversion { got: i64, need: &'static str }, + + /// Evaluation of the supplied Nix code failed while computing the + /// value for deserialisation. + NixErrors { + errors: Vec<tvix_eval::Error>, + source: tvix_eval::SourceCode, + }, + + /// Could not determine an externally tagged enum representation. + AmbiguousEnum, + + /// Attempted to provide content to a unit enum. + UnitEnumContent, +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::Unserializable { value_type } => write!( + f, + "can not deserialise a Nix '{}' into a Rust type", + value_type + ), + + Error::Unsupported { wanted } => { + write!(f, "can not deserialize a '{}' from a Nix value", wanted) + } + + Error::UnexpectedType { expected, got } => { + write!(f, "expected type {}, but got Nix type {}", expected, got) + } + + Error::NixErrors { errors, source } => { + writeln!( + f, + "{} occured during Nix evaluation: ", + if errors.len() == 1 { "error" } else { "errors" } + )?; + + for err in errors { + writeln!(f, "{}", err.fancy_format_str(source))?; + } + + Ok(()) + } + + Error::Deserialization(err) => write!(f, "deserialisation error occured: {}", err), + + Error::IntegerConversion { got, need } => { + write!(f, "i64({}) does not fit in a {}", got, need) + } + + Error::AmbiguousEnum => write!(f, "could not determine enum variant: ambiguous keys"), + + Error::UnitEnumContent => write!(f, "provided content for unit enum variant"), + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match self { + Self::NixErrors { errors, .. } => errors.first().map(|e| e as &dyn error::Error), + _ => None, + } + } +} + +impl serde::de::Error for Error { + fn custom<T>(err: T) -> Self + where + T: Display, + { + Self::Deserialization(err.to_string()) + } +} diff --git a/tvix/serde/src/lib.rs b/tvix/serde/src/lib.rs new file mode 100644 index 000000000000..6a44affdc0e1 --- /dev/null +++ b/tvix/serde/src/lib.rs @@ -0,0 +1,12 @@ +//! `tvix-serde` implements (de-)serialisation of Rust data structures +//! to/from Nix. This is intended to make it easy to use Nix as as +//! configuration language. + +mod de; +mod error; + +pub use de::from_str; +pub use de::from_str_with_config; + +#[cfg(test)] +mod de_tests; |