about summary refs log tree commit diff
path: root/tvix/serde/src
diff options
context:
space:
mode:
authorVincent Ambo <mail@tazj.in>2023-01-02T10·39+0300
committertazjin <tazjin@tvl.su>2023-01-04T17·21+0000
commit34be6466d4a5da7dd3ad55ce80c951f21e45520c (patch)
tree28456ec15c4a2c6cae8eab74ebb5434be599b8e0 /tvix/serde/src
parent0e88eb83efb194427329ccffd3b48671e1d72107 (diff)
feat(tvix/serde): implement enum deserialisation r/5585
Implements externally tagged enum deserialisation. Other serialisation
methods are handled by serde internally using the existing methods.

See the tests for examples.

Change-Id: Ic4a9da3b5a32ddbb5918b1512e70c3ac5ce64f04
Reviewed-on: https://cl.tvl.fyi/c/depot/+/7721
Tested-by: BuildkiteCI
Autosubmit: tazjin <tazjin@tvl.su>
Reviewed-by: flokli <flokli@flokli.de>
Diffstat (limited to 'tvix/serde/src')
-rw-r--r--tvix/serde/src/de.rs81
-rw-r--r--tvix/serde/src/de_tests.rs103
-rw-r--r--tvix/serde/src/error.rs10
3 files changed, 188 insertions, 6 deletions
diff --git a/tvix/serde/src/de.rs b/tvix/serde/src/de.rs
index 2f7b2ba4d66d..e6bcf41cf231 100644
--- a/tvix/serde/src/de.rs
+++ b/tvix/serde/src/de.rs
@@ -1,7 +1,7 @@
 //! Deserialisation from Nix to Rust values.
 
-use serde::de;
 use serde::de::value::{MapDeserializer, SeqDeserializer};
+use serde::de::{self, EnumAccess, VariantAccess};
 use tvix_eval::Value;
 
 use crate::error::Error;
@@ -221,14 +221,14 @@ impl<'de> de::Deserializer<'de> for NixDeserializer {
         Err(unexpected("string", &self.value))
     }
 
-    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
     where
         V: de::Visitor<'de>,
     {
         unimplemented!()
     }
 
-    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
     where
         V: de::Visitor<'de>,
     {
@@ -307,7 +307,7 @@ impl<'de> de::Deserializer<'de> for NixDeserializer {
     fn deserialize_tuple_struct<V>(
         self,
         _name: &'static str,
-        len: usize,
+        _len: usize,
         visitor: V,
     ) -> Result<V::Value, Self::Error>
     where
@@ -348,16 +348,27 @@ impl<'de> de::Deserializer<'de> for NixDeserializer {
         self.deserialize_map(visitor)
     }
 
+    // This method is responsible for deserializing the externally
+    // tagged enum variant serialisation.
     fn deserialize_enum<V>(
         self,
         name: &'static str,
-        variants: &'static [&'static str],
+        _variants: &'static [&'static str],
         visitor: V,
     ) -> Result<V::Value, Self::Error>
     where
         V: de::Visitor<'de>,
     {
-        todo!()
+        match self.value {
+            // a string represents a unit variant
+            Value::String(s) => visitor.visit_enum(de::value::StrDeserializer::new(s.as_str())),
+
+            // an attribute set however represents an externally
+            // tagged enum with content
+            Value::Attrs(attrs) => visitor.visit_enum(Enum(*attrs)),
+
+            _ => Err(unexpected(name, &self.value)),
+        }
     }
 
     fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
@@ -374,3 +385,61 @@ impl<'de> de::Deserializer<'de> for NixDeserializer {
         visitor.visit_unit()
     }
 }
+
+struct Enum(tvix_eval::NixAttrs);
+
+impl<'de> EnumAccess<'de> for Enum {
+    type Error = Error;
+    type Variant = NixDeserializer;
+
+    // TODO: pass the known variants down here and check against them
+    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
+    where
+        V: de::DeserializeSeed<'de>,
+    {
+        if self.0.len() != 1 {
+            return Err(Error::AmbiguousEnum);
+        }
+
+        let (key, value) = self.0.into_iter().next().expect("length asserted above");
+        let val = seed.deserialize(de::value::StrDeserializer::<Error>::new(key.as_str()))?;
+
+        Ok((val, NixDeserializer::new(value)))
+    }
+}
+
+impl<'de> VariantAccess<'de> for NixDeserializer {
+    type Error = Error;
+
+    fn unit_variant(self) -> Result<(), Self::Error> {
+        // If this case is hit, a user specified the name of a unit
+        // enum variant but gave it content. Unit enum deserialisation
+        // is handled in `deserialize_enum` above.
+        Err(Error::UnitEnumContent)
+    }
+
+    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
+    where
+        T: de::DeserializeSeed<'de>,
+    {
+        seed.deserialize(self)
+    }
+
+    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: de::Visitor<'de>,
+    {
+        de::Deserializer::deserialize_seq(self, visitor)
+    }
+
+    fn struct_variant<V>(
+        self,
+        _fields: &'static [&'static str],
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: de::Visitor<'de>,
+    {
+        de::Deserializer::deserialize_map(self, visitor)
+    }
+}
diff --git a/tvix/serde/src/de_tests.rs b/tvix/serde/src/de_tests.rs
index 1613b874d949..8fe15a17e378 100644
--- a/tvix/serde/src/de_tests.rs
+++ b/tvix/serde/src/de_tests.rs
@@ -95,3 +95,106 @@ fn deserialize_tuple() {
     let result: (String, usize) = from_str(r#" [ "foo" 42 ] "#).expect("should deserialize");
     assert_eq!(result, ("foo".into(), 42));
 }
+
+#[test]
+fn deserialize_unit_enum() {
+    #[derive(Debug, Deserialize, PartialEq)]
+    enum Foo {
+        Bar,
+        Baz,
+    }
+
+    let result: Foo = from_str("\"Baz\"").expect("should deserialize");
+    assert_eq!(result, Foo::Baz);
+}
+
+#[test]
+fn deserialize_tuple_enum() {
+    #[derive(Debug, Deserialize, PartialEq)]
+    enum Foo {
+        Bar,
+        Baz(String, usize),
+    }
+
+    let result: Foo = from_str(
+        r#"
+    {
+      Baz = [ "Slartibartfast" 42 ];
+    }
+    "#,
+    )
+    .expect("should deserialize");
+
+    assert_eq!(result, Foo::Baz("Slartibartfast".into(), 42));
+}
+
+#[test]
+fn deserialize_struct_enum() {
+    #[derive(Debug, Deserialize, PartialEq)]
+    enum Foo {
+        Bar,
+        Baz { name: String, age: usize },
+    }
+
+    let result: Foo = from_str(
+        r#"
+    {
+      Baz.name = "Slartibartfast";
+      Baz.age = 42;
+    }
+    "#,
+    )
+    .expect("should deserialize");
+
+    assert_eq!(
+        result,
+        Foo::Baz {
+            name: "Slartibartfast".into(),
+            age: 42
+        }
+    );
+}
+
+#[test]
+fn deserialize_enum_all() {
+    #[derive(Debug, Deserialize, PartialEq)]
+    #[serde(rename_all = "snake_case")]
+    enum TestEnum {
+        UnitVariant,
+        TupleVariant(String, String),
+        StructVariant { name: String, age: usize },
+    }
+
+    let result: Vec<TestEnum> = from_str(
+        r#"
+      let
+        mkTuple = country: drink: { tuple_variant = [ country drink ]; };
+      in
+      [
+        (mkTuple "UK" "cask ale")
+
+        "unit_variant"
+
+        {
+          struct_variant.name = "Slartibartfast";
+          struct_variant.age = 42;
+        }
+
+        (mkTuple "Russia" "квас")
+      ]
+    "#,
+    )
+    .expect("should deserialize");
+
+    let expected = vec![
+        TestEnum::TupleVariant("UK".into(), "cask ale".into()),
+        TestEnum::UnitVariant,
+        TestEnum::StructVariant {
+            name: "Slartibartfast".into(),
+            age: 42,
+        },
+        TestEnum::TupleVariant("Russia".into(), "квас".into()),
+    ];
+
+    assert_eq!(result, expected);
+}
diff --git a/tvix/serde/src/error.rs b/tvix/serde/src/error.rs
index fb83105cd210..f206b830e95f 100644
--- a/tvix/serde/src/error.rs
+++ b/tvix/serde/src/error.rs
@@ -31,6 +31,12 @@ pub enum Error {
         errors: Vec<tvix_eval::Error>,
         source: tvix_eval::SourceCode,
     },
+
+    /// Could not determine an externally tagged enum representation.
+    AmbiguousEnum,
+
+    /// Attempted to provide content to a unit enum.
+    UnitEnumContent,
 }
 
 impl Display for Error {
@@ -69,6 +75,10 @@ impl Display for Error {
             Error::IntegerConversion { got, need } => {
                 write!(f, "i64({}) does not fit in a {}", got, need)
             }
+
+            Error::AmbiguousEnum => write!(f, "could not determine enum variant: ambiguous keys"),
+
+            Error::UnitEnumContent => write!(f, "provided content for unit enum variant"),
         }
     }
 }