about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tvix/castore/src/blobservice/combinator.rs2
-rw-r--r--tvix/castore/src/blobservice/grpc.rs2
-rw-r--r--tvix/castore/src/blobservice/memory.rs2
-rw-r--r--tvix/castore/src/blobservice/object_store.rs2
-rw-r--r--tvix/castore/src/composition.rs90
-rw-r--r--tvix/castore/src/directoryservice/bigtable.rs2
-rw-r--r--tvix/castore/src/directoryservice/combinators.rs2
-rw-r--r--tvix/castore/src/directoryservice/grpc.rs2
-rw-r--r--tvix/castore/src/directoryservice/memory.rs2
-rw-r--r--tvix/castore/src/directoryservice/object_store.rs2
-rw-r--r--tvix/castore/src/directoryservice/sled.rs2
11 files changed, 64 insertions, 46 deletions
diff --git a/tvix/castore/src/blobservice/combinator.rs b/tvix/castore/src/blobservice/combinator.rs
index 8ec5a859bcda..fc33d16a3473 100644
--- a/tvix/castore/src/blobservice/combinator.rs
+++ b/tvix/castore/src/blobservice/combinator.rs
@@ -119,7 +119,7 @@ impl ServiceBuilder for CombinedBlobServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        context: &CompositionContext<dyn BlobService>,
+        context: &CompositionContext,
     ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync>> {
         let (local, remote) = futures::join!(
             context.resolve(self.local.clone()),
diff --git a/tvix/castore/src/blobservice/grpc.rs b/tvix/castore/src/blobservice/grpc.rs
index f5705adbf432..0db3dfea4ad8 100644
--- a/tvix/castore/src/blobservice/grpc.rs
+++ b/tvix/castore/src/blobservice/grpc.rs
@@ -206,7 +206,7 @@ impl ServiceBuilder for GRPCBlobServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn BlobService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         let client = proto::blob_service_client::BlobServiceClient::new(
             crate::tonic::channel_from_url(&self.url.parse()?).await?,
diff --git a/tvix/castore/src/blobservice/memory.rs b/tvix/castore/src/blobservice/memory.rs
index 83b37edb1c89..3d733f950470 100644
--- a/tvix/castore/src/blobservice/memory.rs
+++ b/tvix/castore/src/blobservice/memory.rs
@@ -59,7 +59,7 @@ impl ServiceBuilder for MemoryBlobServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn BlobService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         Ok(Arc::new(MemoryBlobService::default()))
     }
diff --git a/tvix/castore/src/blobservice/object_store.rs b/tvix/castore/src/blobservice/object_store.rs
index d898ce19e56c..5bb05cf26123 100644
--- a/tvix/castore/src/blobservice/object_store.rs
+++ b/tvix/castore/src/blobservice/object_store.rs
@@ -295,7 +295,7 @@ impl ServiceBuilder for ObjectStoreBlobServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn BlobService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         let (object_store, path) = object_store::parse_url_opts(
             &self.object_store_url.parse()?,
diff --git a/tvix/castore/src/composition.rs b/tvix/castore/src/composition.rs
index 9e7b3712fb7a..18a7672846b8 100644
--- a/tvix/castore/src/composition.rs
+++ b/tvix/castore/src/composition.rs
@@ -26,7 +26,7 @@
 //! #[tonic::async_trait]
 //! impl ServiceBuilder for MyBlobServiceConfig {
 //!     type Output = dyn BlobService;
-//!     async fn build(&self, _: &str, _: &CompositionContext<Self::Output>) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>> {
+//!     async fn build(&self, _: &str, _: &CompositionContext) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>> {
 //!         todo!()
 //!     }
 //! }
@@ -49,6 +49,7 @@
 //! ### Example 2.: Composing stores to get one store
 //!
 //! ```
+//! use std::sync::Arc;
 //! use tvix_castore::composition::*;
 //! use tvix_castore::blobservice::BlobService;
 //!
@@ -69,8 +70,9 @@
 //! });
 //!
 //! let blob_services_configs = with_registry(&REG, || serde_json::from_value(blob_services_configs_json))?;
-//! let blob_service_composition = Composition::<dyn BlobService>::from_configs(blob_services_configs);
-//! let blob_service = blob_service_composition.build("default").await?;
+//! let mut blob_service_composition = Composition::default();
+//! blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
+//! let blob_service: Arc<dyn BlobService> = blob_service_composition.build("default").await?;
 //! # Ok(())
 //! # })
 //! # }
@@ -271,12 +273,12 @@ pub fn add_default_services(reg: &mut Registry) {
     crate::directoryservice::register_directory_services(reg);
 }
 
-pub struct CompositionContext<'a, T: ?Sized> {
+pub struct CompositionContext<'a> {
     stack: Vec<String>,
-    composition: Option<&'a Composition<T>>,
+    composition: Option<&'a Composition>,
 }
 
-impl<'a, T: ?Sized + Send + Sync + 'static> CompositionContext<'a, T> {
+impl<'a> CompositionContext<'a> {
     pub fn blank() -> Self {
         Self {
             stack: Default::default(),
@@ -284,7 +286,7 @@ impl<'a, T: ?Sized + Send + Sync + 'static> CompositionContext<'a, T> {
         }
     }
 
-    pub async fn resolve(
+    pub async fn resolve<T: ?Sized + Send + Sync + 'static>(
         &self,
         entrypoint: String,
     ) -> Result<Arc<T>, Box<dyn std::error::Error + Send + Sync + 'static>> {
@@ -307,7 +309,7 @@ pub trait ServiceBuilder: Send + Sync {
     async fn build(
         &self,
         instance_name: &str,
-        context: &CompositionContext<Self::Output>,
+        context: &CompositionContext,
     ) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>>;
 }
 
@@ -325,8 +327,9 @@ enum InstantiationState<T: ?Sized> {
     Done(Result<Arc<T>, CompositionError>),
 }
 
-pub struct Composition<T: ?Sized> {
-    stores: std::sync::Mutex<HashMap<String, InstantiationState<T>>>,
+#[derive(Default)]
+pub struct Composition {
+    stores: std::sync::Mutex<HashMap<(TypeId, String), Box<dyn Any + Send + Sync>>>,
 }
 
 #[derive(thiserror::Error, Clone, Debug)]
@@ -341,44 +344,57 @@ pub enum CompositionError {
     Failed(String, Arc<dyn std::error::Error + Send + Sync>),
 }
 
-impl<T: ?Sized + Send + Sync + 'static> Composition<T> {
-    pub fn from_configs(
-        // Keep the concrete `HashMap` type here since it allows for type
-        // inference of what type is previously being deserialized.
-        configs: HashMap<String, DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>>,
-    ) -> Self {
-        Self::from_configs_iter(configs)
-    }
-
-    pub fn from_configs_iter(
-        configs: impl IntoIterator<
+impl<T: ?Sized + Send + Sync + 'static>
+    Extend<(
+        String,
+        DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
+    )> for Composition
+{
+    fn extend<I>(&mut self, configs: I)
+    where
+        I: IntoIterator<
             Item = (
                 String,
                 DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
             ),
         >,
-    ) -> Self {
-        Composition {
-            stores: std::sync::Mutex::new(
-                configs
-                    .into_iter()
-                    .map(|(k, v)| (k, InstantiationState::Config(v.0)))
-                    .collect(),
-            ),
-        }
+    {
+        self.stores
+            .lock()
+            .unwrap()
+            .extend(configs.into_iter().map(|(k, v)| {
+                (
+                    (TypeId::of::<T>(), k),
+                    Box::new(InstantiationState::Config(v.0)) as Box<dyn Any + Send + Sync>,
+                )
+            }))
+    }
+}
+
+impl Composition {
+    pub fn extend_with_configs<T: ?Sized + Send + Sync + 'static>(
+        &mut self,
+        // Keep the concrete `HashMap` type here since it allows for type
+        // inference of what type is previously being deserialized.
+        configs: HashMap<String, DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>>,
+    ) {
+        self.extend(configs);
     }
 
-    pub async fn build(&self, entrypoint: &str) -> Result<Arc<T>, CompositionError> {
+    pub async fn build<T: ?Sized + Send + Sync + 'static>(
+        &self,
+        entrypoint: &str,
+    ) -> Result<Arc<T>, CompositionError> {
         self.build_internal(vec![], entrypoint.to_string()).await
     }
 
-    fn build_internal(
+    fn build_internal<T: ?Sized + Send + Sync + 'static>(
         &self,
         stack: Vec<String>,
         entrypoint: String,
     ) -> BoxFuture<'_, Result<Arc<T>, CompositionError>> {
         let mut stores = self.stores.lock().unwrap();
-        let entry = match stores.get_mut(&entrypoint) {
+        let entry = match stores.get_mut(&(TypeId::of::<T>(), entrypoint.clone())) {
             Some(v) => v,
             None => return Box::pin(futures::future::err(CompositionError::NotFound(entrypoint))),
         };
@@ -387,9 +403,11 @@ impl<T: ?Sized + Send + Sync + 'static> Composition<T> {
         // this temporary value.
         let prev_val = std::mem::replace(
             entry,
-            InstantiationState::Done(Err(CompositionError::Poisoned(entrypoint.clone()))),
+            Box::new(InstantiationState::<T>::Done(Err(
+                CompositionError::Poisoned(entrypoint.clone()),
+            ))),
         );
-        let (new_val, ret) = match prev_val {
+        let (new_val, ret) = match *prev_val.downcast::<InstantiationState<T>>().unwrap() {
             InstantiationState::Done(service) => (
                 InstantiationState::Done(service.clone()),
                 futures::future::ready(service).boxed(),
@@ -433,7 +451,7 @@ impl<T: ?Sized + Send + Sync + 'static> Composition<T> {
                 })
             }
         };
-        *entry = new_val;
+        *entry = Box::new(new_val);
         ret
     }
 }
diff --git a/tvix/castore/src/directoryservice/bigtable.rs b/tvix/castore/src/directoryservice/bigtable.rs
index 596094930614..d10dddaf9f60 100644
--- a/tvix/castore/src/directoryservice/bigtable.rs
+++ b/tvix/castore/src/directoryservice/bigtable.rs
@@ -353,7 +353,7 @@ impl ServiceBuilder for BigtableParameters {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn DirectoryService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync>> {
         Ok(Arc::new(
             BigtableDirectoryService::connect(self.clone()).await?,
diff --git a/tvix/castore/src/directoryservice/combinators.rs b/tvix/castore/src/directoryservice/combinators.rs
index 74d02f1ad2b9..0fdc82c16cb0 100644
--- a/tvix/castore/src/directoryservice/combinators.rs
+++ b/tvix/castore/src/directoryservice/combinators.rs
@@ -167,7 +167,7 @@ impl ServiceBuilder for CacheConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        context: &CompositionContext<dyn DirectoryService>,
+        context: &CompositionContext,
     ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         let (near, far) = futures::join!(
             context.resolve(self.near.clone()),
diff --git a/tvix/castore/src/directoryservice/grpc.rs b/tvix/castore/src/directoryservice/grpc.rs
index 415796fa52cc..ff08bad4bd0f 100644
--- a/tvix/castore/src/directoryservice/grpc.rs
+++ b/tvix/castore/src/directoryservice/grpc.rs
@@ -243,7 +243,7 @@ impl ServiceBuilder for GRPCDirectoryServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn DirectoryService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         let client = proto::directory_service_client::DirectoryServiceClient::new(
             crate::tonic::channel_from_url(&self.url.parse()?).await?,
diff --git a/tvix/castore/src/directoryservice/memory.rs b/tvix/castore/src/directoryservice/memory.rs
index c1fc361f0d59..ada4606a5a57 100644
--- a/tvix/castore/src/directoryservice/memory.rs
+++ b/tvix/castore/src/directoryservice/memory.rs
@@ -108,7 +108,7 @@ impl ServiceBuilder for MemoryDirectoryServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn DirectoryService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         Ok(Arc::new(MemoryDirectoryService::default()))
     }
diff --git a/tvix/castore/src/directoryservice/object_store.rs b/tvix/castore/src/directoryservice/object_store.rs
index 0f0423a49e5b..a9a2cc8ef5c0 100644
--- a/tvix/castore/src/directoryservice/object_store.rs
+++ b/tvix/castore/src/directoryservice/object_store.rs
@@ -211,7 +211,7 @@ impl ServiceBuilder for ObjectStoreDirectoryServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn DirectoryService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         let (object_store, path) = object_store::parse_url_opts(
             &self.object_store_url.parse()?,
diff --git a/tvix/castore/src/directoryservice/sled.rs b/tvix/castore/src/directoryservice/sled.rs
index 61058b392bb3..5766dec1a5c2 100644
--- a/tvix/castore/src/directoryservice/sled.rs
+++ b/tvix/castore/src/directoryservice/sled.rs
@@ -176,7 +176,7 @@ impl ServiceBuilder for SledDirectoryServiceConfig {
     async fn build<'a>(
         &'a self,
         _instance_name: &str,
-        _context: &CompositionContext<dyn DirectoryService>,
+        _context: &CompositionContext,
     ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
         match self {
             SledDirectoryServiceConfig {