diff options
author | Brian Olsen <brian@maven-group.org> | 2024-07-22T14·51+0200 |
---|---|---|
committer | clbot <clbot@tvl.fyi> | 2024-08-25T15·05+0000 |
commit | ced05a2bb6b66d30208520d0791f4fa298c429e2 (patch) | |
tree | d295eb9766db79df8f3053bc5022e24059012f16 /tvix/nix-compat-derive/src/de.rs | |
parent | 9af69204787d47cfe551f524d01b1a726971f06e (diff) |
feat(tvix/nix-compat-derive): Add deriver for NixDeserialize r/8586
This adds a nix-compat-derive derive crate that implements a deriver for NixDeserialize implementations. This is to reduce the amount of code needed to implement deserialization for all the types used by the Nix daemon protocol. Change-Id: I484724b550e8a1d5e9adad9555d9dc1374ae95c2 Reviewed-on: https://cl.tvl.fyi/c/depot/+/12022 Autosubmit: Brian Olsen <me@griff.name> Tested-by: BuildkiteCI Reviewed-by: flokli <flokli@flokli.de>
Diffstat (limited to 'tvix/nix-compat-derive/src/de.rs')
-rw-r--r-- | tvix/nix-compat-derive/src/de.rs | 272 |
1 files changed, 272 insertions, 0 deletions
diff --git a/tvix/nix-compat-derive/src/de.rs b/tvix/nix-compat-derive/src/de.rs new file mode 100644 index 000000000000..ee79ea9d1012 --- /dev/null +++ b/tvix/nix-compat-derive/src/de.rs @@ -0,0 +1,272 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::spanned::Spanned; +use syn::{DeriveInput, Generics, Path, Type}; + +use crate::internal::attrs::Default; +use crate::internal::inputs::RemoteInput; +use crate::internal::{attrs, Container, Context, Data, Field, Remote, Style, Variant}; + +pub fn expand_nix_deserialize(nnixrs: Path, input: &mut DeriveInput) -> syn::Result<TokenStream> { + let cx = Context::new(); + let cont = Container::from_ast(&cx, nnixrs, input); + cx.check()?; + let cont = cont.unwrap(); + + let ty = cont.ident_type(); + let body = nix_deserialize_body(&cont); + let crate_path = cont.crate_path(); + + Ok(nix_deserialize_impl( + crate_path, + &ty, + &cont.original.generics, + body, + )) +} + +pub fn expand_nix_deserialize_remote( + crate_path: Path, + input: &RemoteInput, +) -> syn::Result<TokenStream> { + let cx = Context::new(); + let remote = Remote::from_ast(&cx, crate_path, input); + cx.check()?; + let remote = remote.unwrap(); + + let crate_path = remote.crate_path(); + let body = nix_deserialize_body_from(crate_path, &remote.attrs).expect("From tokenstream"); + let generics = Generics::default(); + Ok(nix_deserialize_impl(crate_path, remote.ty, &generics, body)) +} + +fn nix_deserialize_impl( + crate_path: &Path, + ty: &Type, + generics: &Generics, + body: TokenStream, +) -> TokenStream { + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + quote! { + #[automatically_derived] + impl #impl_generics #crate_path::nix_daemon::de::NixDeserialize for #ty #ty_generics + #where_clause + { + #[allow(clippy::manual_async_fn)] + fn try_deserialize<R>(reader: &mut R) -> impl ::std::future::Future<Output=Result<Option<Self>, R::Error>> + Send + '_ + where R: ?Sized + #crate_path::nix_daemon::de::NixRead + Send, + { + #body + } + } + } +} + +fn nix_deserialize_body_from( + crate_path: &syn::Path, + attrs: &attrs::Container, +) -> Option<TokenStream> { + if let Some(span) = attrs.from_str.as_ref() { + Some(nix_deserialize_from_str(crate_path, span.span())) + } else if let Some(type_from) = attrs.type_from.as_ref() { + Some(nix_deserialize_from(type_from)) + } else { + attrs + .type_try_from + .as_ref() + .map(|type_try_from| nix_deserialize_try_from(crate_path, type_try_from)) + } +} + +fn nix_deserialize_body(cont: &Container) -> TokenStream { + if let Some(tokens) = nix_deserialize_body_from(cont.crate_path(), &cont.attrs) { + tokens + } else { + match &cont.data { + Data::Struct(style, fields) => nix_deserialize_struct(*style, fields), + Data::Enum(variants) => nix_deserialize_enum(variants), + } + } +} + +fn nix_deserialize_struct(style: Style, fields: &[Field<'_>]) -> TokenStream { + let read_fields = fields.iter().map(|f| { + let field = f.var_ident(); + let ty = f.ty; + let read_value = quote_spanned! { + ty.span()=> if first__ { + first__ = false; + if let Some(v) = reader.try_read_value::<#ty>().await? { + v + } else { + return Ok(None); + } + } else { + reader.read_value::<#ty>().await? + } + }; + if let Some(version) = f.attrs.version.as_ref() { + let default = match &f.attrs.default { + Default::Default => quote_spanned!(ty.span()=>::std::default::Default::default), + Default::Path(path) => path.to_token_stream(), + _ => panic!("No default for versioned field"), + }; + quote! { + let #field : #ty = if (#version).contains(&reader.version().minor()) { + #read_value + } else { + #default() + }; + } + } else { + quote! { + let #field : #ty = #read_value; + } + } + }); + + let field_names = fields.iter().map(|f| f.var_ident()); + let construct = match style { + Style::Struct => { + quote! { + Self { #(#field_names),* } + } + } + Style::Tuple => { + quote! { + Self(#(#field_names),*) + } + } + Style::Unit => quote!(Self), + }; + quote! { + #[allow(unused_assignments)] + async move { + let mut first__ = true; + #(#read_fields)* + Ok(Some(#construct)) + } + } +} + +fn nix_deserialize_variant(variant: &Variant<'_>) -> TokenStream { + let ident = variant.ident; + let read_fields = variant.fields.iter().map(|f| { + let field = f.var_ident(); + let ty = f.ty; + let read_value = quote_spanned! { + ty.span()=> if first__ { + first__ = false; + if let Some(v) = reader.try_read_value::<#ty>().await? { + v + } else { + return Ok(None); + } + } else { + reader.read_value::<#ty>().await? + } + }; + if let Some(version) = f.attrs.version.as_ref() { + let default = match &f.attrs.default { + Default::Default => quote_spanned!(ty.span()=>::std::default::Default::default), + Default::Path(path) => path.to_token_stream(), + _ => panic!("No default for versioned field"), + }; + quote! { + let #field : #ty = if (#version).contains(&reader.version().minor()) { + #read_value + } else { + #default() + }; + } + } else { + quote! { + let #field : #ty = #read_value; + } + } + }); + let field_names = variant.fields.iter().map(|f| f.var_ident()); + let construct = match variant.style { + Style::Struct => { + quote! { + Self::#ident { #(#field_names),* } + } + } + Style::Tuple => { + quote! { + Self::#ident(#(#field_names),*) + } + } + Style::Unit => quote!(Self::#ident), + }; + let version = &variant.attrs.version; + quote! { + #version => { + #(#read_fields)* + Ok(Some(#construct)) + } + } +} + +fn nix_deserialize_enum(variants: &[Variant<'_>]) -> TokenStream { + let match_variant = variants + .iter() + .map(|variant| nix_deserialize_variant(variant)); + quote! { + #[allow(unused_assignments)] + async move { + let mut first__ = true; + match reader.version().minor() { + #(#match_variant)* + } + } + } +} + +fn nix_deserialize_from(ty: &Type) -> TokenStream { + quote_spanned! { + ty.span() => + async move { + if let Some(value) = reader.try_read_value::<#ty>().await? { + Ok(Some(<Self as ::std::convert::From<#ty>>::from(value))) + } else { + Ok(None) + } + } + } +} + +fn nix_deserialize_try_from(crate_path: &Path, ty: &Type) -> TokenStream { + quote_spanned! { + ty.span() => + async move { + use #crate_path::nix_daemon::de::Error; + if let Some(item) = reader.try_read_value::<#ty>().await? { + <Self as ::std::convert::TryFrom<#ty>>::try_from(item) + .map_err(Error::invalid_data) + .map(Some) + } else { + Ok(None) + } + } + } +} + +fn nix_deserialize_from_str(crate_path: &Path, span: Span) -> TokenStream { + quote_spanned! { + span => + async move { + use #crate_path::nix_daemon::de::Error; + if let Some(buf) = reader.try_read_bytes().await? { + let s = ::std::str::from_utf8(&buf) + .map_err(Error::invalid_data)?; + <Self as ::std::str::FromStr>::from_str(s) + .map_err(Error::invalid_data) + .map(Some) + } else { + Ok(None) + } + } + } +} |