about summary refs log tree commit diff
path: root/tvix/castore/src/composition.rs
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/castore/src/composition.rs')
-rw-r--r--tvix/castore/src/composition.rs90
1 files changed, 54 insertions, 36 deletions
diff --git a/tvix/castore/src/composition.rs b/tvix/castore/src/composition.rs
index 9e7b3712fb7a..18a7672846b8 100644
--- a/tvix/castore/src/composition.rs
+++ b/tvix/castore/src/composition.rs
@@ -26,7 +26,7 @@
 //! #[tonic::async_trait]
 //! impl ServiceBuilder for MyBlobServiceConfig {
 //!     type Output = dyn BlobService;
-//!     async fn build(&self, _: &str, _: &CompositionContext<Self::Output>) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>> {
+//!     async fn build(&self, _: &str, _: &CompositionContext) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>> {
 //!         todo!()
 //!     }
 //! }
@@ -49,6 +49,7 @@
 //! ### Example 2.: Composing stores to get one store
 //!
 //! ```
+//! use std::sync::Arc;
 //! use tvix_castore::composition::*;
 //! use tvix_castore::blobservice::BlobService;
 //!
@@ -69,8 +70,9 @@
 //! });
 //!
 //! let blob_services_configs = with_registry(&REG, || serde_json::from_value(blob_services_configs_json))?;
-//! let blob_service_composition = Composition::<dyn BlobService>::from_configs(blob_services_configs);
-//! let blob_service = blob_service_composition.build("default").await?;
+//! let mut blob_service_composition = Composition::default();
+//! blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
+//! let blob_service: Arc<dyn BlobService> = blob_service_composition.build("default").await?;
 //! # Ok(())
 //! # })
 //! # }
@@ -271,12 +273,12 @@ pub fn add_default_services(reg: &mut Registry) {
     crate::directoryservice::register_directory_services(reg);
 }
 
-pub struct CompositionContext<'a, T: ?Sized> {
+pub struct CompositionContext<'a> {
     stack: Vec<String>,
-    composition: Option<&'a Composition<T>>,
+    composition: Option<&'a Composition>,
 }
 
-impl<'a, T: ?Sized + Send + Sync + 'static> CompositionContext<'a, T> {
+impl<'a> CompositionContext<'a> {
     pub fn blank() -> Self {
         Self {
             stack: Default::default(),
@@ -284,7 +286,7 @@ impl<'a, T: ?Sized + Send + Sync + 'static> CompositionContext<'a, T> {
         }
     }
 
-    pub async fn resolve(
+    pub async fn resolve<T: ?Sized + Send + Sync + 'static>(
         &self,
         entrypoint: String,
     ) -> Result<Arc<T>, Box<dyn std::error::Error + Send + Sync + 'static>> {
@@ -307,7 +309,7 @@ pub trait ServiceBuilder: Send + Sync {
     async fn build(
         &self,
         instance_name: &str,
-        context: &CompositionContext<Self::Output>,
+        context: &CompositionContext,
     ) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>>;
 }
 
@@ -325,8 +327,9 @@ enum InstantiationState<T: ?Sized> {
     Done(Result<Arc<T>, CompositionError>),
 }
 
-pub struct Composition<T: ?Sized> {
-    stores: std::sync::Mutex<HashMap<String, InstantiationState<T>>>,
+#[derive(Default)]
+pub struct Composition {
+    stores: std::sync::Mutex<HashMap<(TypeId, String), Box<dyn Any + Send + Sync>>>,
 }
 
 #[derive(thiserror::Error, Clone, Debug)]
@@ -341,44 +344,57 @@ pub enum CompositionError {
     Failed(String, Arc<dyn std::error::Error + Send + Sync>),
 }
 
-impl<T: ?Sized + Send + Sync + 'static> Composition<T> {
-    pub fn from_configs(
-        // Keep the concrete `HashMap` type here since it allows for type
-        // inference of what type is previously being deserialized.
-        configs: HashMap<String, DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>>,
-    ) -> Self {
-        Self::from_configs_iter(configs)
-    }
-
-    pub fn from_configs_iter(
-        configs: impl IntoIterator<
+impl<T: ?Sized + Send + Sync + 'static>
+    Extend<(
+        String,
+        DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
+    )> for Composition
+{
+    fn extend<I>(&mut self, configs: I)
+    where
+        I: IntoIterator<
             Item = (
                 String,
                 DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
             ),
         >,
-    ) -> Self {
-        Composition {
-            stores: std::sync::Mutex::new(
-                configs
-                    .into_iter()
-                    .map(|(k, v)| (k, InstantiationState::Config(v.0)))
-                    .collect(),
-            ),
-        }
+    {
+        self.stores
+            .lock()
+            .unwrap()
+            .extend(configs.into_iter().map(|(k, v)| {
+                (
+                    (TypeId::of::<T>(), k),
+                    Box::new(InstantiationState::Config(v.0)) as Box<dyn Any + Send + Sync>,
+                )
+            }))
+    }
+}
+
+impl Composition {
+    pub fn extend_with_configs<T: ?Sized + Send + Sync + 'static>(
+        &mut self,
+        // Keep the concrete `HashMap` type here since it allows for type
+        // inference of what type is previously being deserialized.
+        configs: HashMap<String, DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>>,
+    ) {
+        self.extend(configs);
     }
 
-    pub async fn build(&self, entrypoint: &str) -> Result<Arc<T>, CompositionError> {
+    pub async fn build<T: ?Sized + Send + Sync + 'static>(
+        &self,
+        entrypoint: &str,
+    ) -> Result<Arc<T>, CompositionError> {
         self.build_internal(vec![], entrypoint.to_string()).await
     }
 
-    fn build_internal(
+    fn build_internal<T: ?Sized + Send + Sync + 'static>(
         &self,
         stack: Vec<String>,
         entrypoint: String,
     ) -> BoxFuture<'_, Result<Arc<T>, CompositionError>> {
         let mut stores = self.stores.lock().unwrap();
-        let entry = match stores.get_mut(&entrypoint) {
+        let entry = match stores.get_mut(&(TypeId::of::<T>(), entrypoint.clone())) {
             Some(v) => v,
             None => return Box::pin(futures::future::err(CompositionError::NotFound(entrypoint))),
         };
@@ -387,9 +403,11 @@ impl<T: ?Sized + Send + Sync + 'static> Composition<T> {
         // this temporary value.
         let prev_val = std::mem::replace(
             entry,
-            InstantiationState::Done(Err(CompositionError::Poisoned(entrypoint.clone()))),
+            Box::new(InstantiationState::<T>::Done(Err(
+                CompositionError::Poisoned(entrypoint.clone()),
+            ))),
         );
-        let (new_val, ret) = match prev_val {
+        let (new_val, ret) = match *prev_val.downcast::<InstantiationState<T>>().unwrap() {
             InstantiationState::Done(service) => (
                 InstantiationState::Done(service.clone()),
                 futures::future::ready(service).boxed(),
@@ -433,7 +451,7 @@ impl<T: ?Sized + Send + Sync + 'static> Composition<T> {
                 })
             }
         };
-        *entry = new_val;
+        *entry = Box::new(new_val);
         ret
     }
 }