about summary refs log tree commit diff
path: root/tvix/castore/src/composition.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/composition.rs')
-rw-r--r--tvix/castore/src/composition.rs79
1 files changed, 67 insertions, 12 deletions
diff --git a/tvix/castore/src/composition.rs b/tvix/castore/src/composition.rs
index cd4064af9a3e..9e7b3712fb7a 100644
--- a/tvix/castore/src/composition.rs
+++ b/tvix/castore/src/composition.rs
@@ -31,6 +31,13 @@
 //!     }
 //! }
 //!
+//! impl TryFrom<url::Url> for MyBlobServiceConfig {
+//!     type Error = Box<dyn std::error::Error + Send + Sync>;
+//!     fn try_from(url: url::Url) -> Result<Self, Self::Error> {
+//!         todo!()
+//!     }
+//! }
+//!
 //! pub fn add_my_service(reg: &mut Registry) {
 //!     reg.register::<Box<dyn ServiceBuilder<Output = dyn BlobService>>, MyBlobServiceConfig>("myblobservicetype");
 //! }
@@ -100,7 +107,7 @@ use tonic::async_trait;
 // This is really ugly. Really we would want to store this as a generic static field:
 //
 // ```
-// struct Registry<T>(BTreeMap<(&'static str), BoxSeedFn<T>);
+// struct Registry<T>(BTreeMap<(&'static str), RegistryEntry<T>);
 // static REG<T>: Registry<T>;
 // ```
 //
@@ -116,6 +123,12 @@ use tonic::async_trait;
 // I said it was ugly...
 #[derive(Default)]
 pub struct Registry(BTreeMap<(TypeId, &'static str), Box<dyn Any + Sync>>);
+pub type FromUrlSeed<T> =
+    Box<dyn Fn(url::Url) -> Result<T, Box<dyn std::error::Error + Send + Sync>> + Sync>;
+pub struct RegistryEntry<T> {
+    serde_deserialize_seed: BoxFnSeed<DeserializeWithRegistry<T>>,
+    from_url_seed: FromUrlSeed<DeserializeWithRegistry<T>>,
+}
 
 struct RegistryWithFakeType<'r, T>(&'r Registry, PhantomData<T>);
 
@@ -137,7 +150,9 @@ impl<'r, 'de: 'r, T: 'static> SeedFactory<'de, TagString<'de>> for RegistryWithF
             .ok_or_else(|| serde::de::Error::custom("Unknown tag"))?
             .1;
 
-        Ok(<dyn Any>::downcast_ref(&**seed).unwrap())
+        let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap();
+
+        Ok(&entry.serde_deserialize_seed)
     }
 }
 
@@ -146,7 +161,7 @@ impl<'r, 'de: 'r, T: 'static> SeedFactory<'de, TagString<'de>> for RegistryWithF
 /// Wrap your type in this in order to deserialize it using a registry, e.g.
 /// `RegistryWithFakeType<Box<dyn MyTrait>>`, then the types registered for `Box<dyn MyTrait>`
 /// will be used.
-pub struct DeserializeWithRegistry<T>(T);
+pub struct DeserializeWithRegistry<T>(pub T);
 
 impl Registry {
     /// Registers a mapping from type tag to a concrete type into the registry.
@@ -156,14 +171,30 @@ impl Registry {
     /// deserializes into an input with the type tag "myblobservicetype" into a
     /// `Box<dyn FooTrait>`, it will first call the Deserialize imple of `FooStruct` and
     /// then convert it into a `Box<dyn FooTrait>` using From::from.
-    pub fn register<T: 'static, C: DeserializeOwned + Into<T>>(&mut self, type_name: &'static str) {
-        let seed = BoxFnSeed::new(|x| {
-            deserialize::<C>(x)
-                .map(Into::into)
-                .map(DeserializeWithRegistry)
-        });
-        self.0
-            .insert((TypeId::of::<T>(), type_name), Box::new(seed));
+    pub fn register<
+        T: 'static,
+        C: DeserializeOwned
+            + TryFrom<url::Url, Error = Box<dyn std::error::Error + Send + Sync>>
+            + Into<T>,
+    >(
+        &mut self,
+        type_name: &'static str,
+    ) {
+        self.0.insert(
+            (TypeId::of::<T>(), type_name),
+            Box::new(RegistryEntry {
+                serde_deserialize_seed: BoxFnSeed::new(|x| {
+                    deserialize::<C>(x)
+                        .map(Into::into)
+                        .map(DeserializeWithRegistry)
+                }),
+                from_url_seed: Box::new(|url| {
+                    C::try_from(url)
+                        .map(Into::into)
+                        .map(DeserializeWithRegistry)
+                }),
+            }),
+        );
     }
 }
 
@@ -180,6 +211,30 @@ impl<'de, T: 'static> serde::Deserialize<'de> for DeserializeWithRegistry<T> {
     }
 }
 
+#[derive(Debug, thiserror::Error)]
+enum TryFromUrlError {
+    #[error("Unknown tag: {0}")]
+    UnknownTag(String),
+}
+
+impl<T: 'static> TryFrom<url::Url> for DeserializeWithRegistry<T> {
+    type Error = Box<dyn std::error::Error + Send + Sync>;
+    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
+        let tag = url.scheme().split('+').next().unwrap();
+        // same as in the SeedFactory impl: using find() and not get() because of https://github.com/rust-lang/rust/issues/80389
+        let seed = ACTIVE_REG
+            .get()
+            .unwrap()
+            .0
+            .iter()
+            .find(|(k, _)| *k == &(TypeId::of::<T>(), tag))
+            .ok_or_else(|| Box::new(TryFromUrlError::UnknownTag(tag.into())))?
+            .1;
+        let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap();
+        (entry.from_url_seed)(url)
+    }
+}
+
 thread_local! {
     /// The active Registry is global state, because there is no convenient and universal way to pass state
     /// into the functions usually used for deserialization, e.g. `serde_json::from_str`, `toml::from_str`,
@@ -200,7 +255,7 @@ pub fn with_registry<R>(reg: &'static Registry, f: impl FnOnce() -> R) -> R {
 lazy_static! {
     /// The provided registry of tvix_castore, with all builtin BlobStore/DirectoryStore implementations
     pub static ref REG: Registry = {
-        let mut reg = Registry(Default::default());
+        let mut reg = Default::default();
         add_default_services(&mut reg);
         reg
     };