diff options
author | Vincent Ambo <mail@tazj.in> | 2023-01-02T10·39+0300 |
---|---|---|
committer | tazjin <tazjin@tvl.su> | 2023-01-04T17·21+0000 |
commit | 34be6466d4a5da7dd3ad55ce80c951f21e45520c (patch) | |
tree | 28456ec15c4a2c6cae8eab74ebb5434be599b8e0 /tvix/serde/src/de.rs | |
parent | 0e88eb83efb194427329ccffd3b48671e1d72107 (diff) |
feat(tvix/serde): implement enum deserialisation r/5585
Implements externally tagged enum deserialisation. Other serialisation methods are handled by serde internally using the existing methods. See the tests for examples. Change-Id: Ic4a9da3b5a32ddbb5918b1512e70c3ac5ce64f04 Reviewed-on: https://cl.tvl.fyi/c/depot/+/7721 Tested-by: BuildkiteCI Autosubmit: tazjin <tazjin@tvl.su> Reviewed-by: flokli <flokli@flokli.de>
Diffstat (limited to 'tvix/serde/src/de.rs')
-rw-r--r-- | tvix/serde/src/de.rs | 81 |
1 files changed, 75 insertions, 6 deletions
diff --git a/tvix/serde/src/de.rs b/tvix/serde/src/de.rs index 2f7b2ba4d66d..e6bcf41cf231 100644 --- a/tvix/serde/src/de.rs +++ b/tvix/serde/src/de.rs @@ -1,7 +1,7 @@ //! Deserialisation from Nix to Rust values. -use serde::de; use serde::de::value::{MapDeserializer, SeqDeserializer}; +use serde::de::{self, EnumAccess, VariantAccess}; use tvix_eval::Value; use crate::error::Error; @@ -221,14 +221,14 @@ impl<'de> de::Deserializer<'de> for NixDeserializer { Err(unexpected("string", &self.value)) } - fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error> + fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error> where V: de::Visitor<'de>, { unimplemented!() } - fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error> + fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error> where V: de::Visitor<'de>, { @@ -307,7 +307,7 @@ impl<'de> de::Deserializer<'de> for NixDeserializer { fn deserialize_tuple_struct<V>( self, _name: &'static str, - len: usize, + _len: usize, visitor: V, ) -> Result<V::Value, Self::Error> where @@ -348,16 +348,27 @@ impl<'de> de::Deserializer<'de> for NixDeserializer { self.deserialize_map(visitor) } + // This method is responsible for deserializing the externally + // tagged enum variant serialisation. fn deserialize_enum<V>( self, name: &'static str, - variants: &'static [&'static str], + _variants: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error> where V: de::Visitor<'de>, { - todo!() + match self.value { + // a string represents a unit variant + Value::String(s) => visitor.visit_enum(de::value::StrDeserializer::new(s.as_str())), + + // an attribute set however represents an externally + // tagged enum with content + Value::Attrs(attrs) => visitor.visit_enum(Enum(*attrs)), + + _ => Err(unexpected(name, &self.value)), + } } fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error> @@ -374,3 +385,61 @@ impl<'de> de::Deserializer<'de> for NixDeserializer { visitor.visit_unit() } } + +struct Enum(tvix_eval::NixAttrs); + +impl<'de> EnumAccess<'de> for Enum { + type Error = Error; + type Variant = NixDeserializer; + + // TODO: pass the known variants down here and check against them + fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + if self.0.len() != 1 { + return Err(Error::AmbiguousEnum); + } + + let (key, value) = self.0.into_iter().next().expect("length asserted above"); + let val = seed.deserialize(de::value::StrDeserializer::<Error>::new(key.as_str()))?; + + Ok((val, NixDeserializer::new(value))) + } +} + +impl<'de> VariantAccess<'de> for NixDeserializer { + type Error = Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + // If this case is hit, a user specified the name of a unit + // enum variant but gave it content. Unit enum deserialisation + // is handled in `deserialize_enum` above. + Err(Error::UnitEnumContent) + } + + fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error> + where + T: de::DeserializeSeed<'de>, + { + seed.deserialize(self) + } + + fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + de::Deserializer::deserialize_seq(self, visitor) + } + + fn struct_variant<V>( + self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: de::Visitor<'de>, + { + de::Deserializer::deserialize_map(self, visitor) + } +} |