diff options
Diffstat (limited to 'tvix/castore/src/composition.rs')
-rw-r--r-- | tvix/castore/src/composition.rs | 581 |
1 files changed, 581 insertions, 0 deletions
diff --git a/tvix/castore/src/composition.rs b/tvix/castore/src/composition.rs new file mode 100644 index 000000000000..b251187e1c34 --- /dev/null +++ b/tvix/castore/src/composition.rs @@ -0,0 +1,581 @@ +//! The composition module allows composing different kinds of services based on a set of service +//! configurations _at runtime_. +//! +//! Store configs are deserialized with serde. The registry provides a stateful mapping from the +//! `type` tag of an internally tagged enum on the serde side to a Config struct which is +//! deserialized and then returned as a `Box<dyn ServiceBuilder<Output = dyn BlobService>>` +//! (the same for DirectoryService instead of BlobService etc). +//! +//! ### Example 1.: Implementing a new BlobService +//! +//! You need a Config struct which implements `DeserializeOwned` and +//! `ServiceBuilder<Output = dyn BlobService>`. +//! Provide the user with a function to call with +//! their registry. You register your new type as: +//! +//! ``` +//! use std::sync::Arc; +//! +//! use tvix_castore::composition::*; +//! use tvix_castore::blobservice::BlobService; +//! +//! #[derive(serde::Deserialize)] +//! struct MyBlobServiceConfig { +//! } +//! +//! #[tonic::async_trait] +//! impl ServiceBuilder for MyBlobServiceConfig { +//! type Output = dyn BlobService; +//! async fn build(&self, _: &str, _: &CompositionContext) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>> { +//! todo!() +//! } +//! } +//! +//! impl TryFrom<url::Url> for MyBlobServiceConfig { +//! type Error = Box<dyn std::error::Error + Send + Sync>; +//! fn try_from(url: url::Url) -> Result<Self, Self::Error> { +//! todo!() +//! } +//! } +//! +//! pub fn add_my_service(reg: &mut Registry) { +//! reg.register::<Box<dyn ServiceBuilder<Output = dyn BlobService>>, MyBlobServiceConfig>("myblobservicetype"); +//! } +//! ``` +//! +//! Now, when a user deserializes a store config with the type tag "myblobservicetype" into a +//! `Box<dyn ServiceBuilder<Output = Arc<dyn BlobService>>>`, it will be done via `MyBlobServiceConfig`. +//! +//! ### Example 2.: Composing stores to get one store +//! +//! ``` +//! use std::sync::Arc; +//! use tvix_castore::composition::*; +//! use tvix_castore::blobservice::BlobService; +//! +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { +//! # tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async move { +//! let blob_services_configs_json = serde_json::json!({ +//! "blobstore1": { +//! "type": "memory" +//! }, +//! "blobstore2": { +//! "type": "memory" +//! }, +//! "default": { +//! "type": "combined", +//! "local": "blobstore1", +//! "remote": "blobstore2" +//! } +//! }); +//! +//! let blob_services_configs = with_registry(®, || serde_json::from_value(blob_services_configs_json))?; +//! let mut blob_service_composition = Composition::new(®); +//! blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs); +//! let blob_service: Arc<dyn BlobService> = blob_service_composition.build("default").await?; +//! # Ok(()) +//! # }) +//! # } +//! ``` +//! +//! ### Example 3.: Creating another registry extending the default registry with third-party types +//! +//! ``` +//! # pub fn add_my_service(reg: &mut tvix_castore::composition::Registry) {} +//! let mut my_registry = tvix_castore::composition::Registry::default(); +//! tvix_castore::composition::add_default_services(&mut my_registry); +//! add_my_service(&mut my_registry); +//! ``` +//! +//! Continue with Example 2, with my_registry instead of REG +//! +//! EXPERIMENTAL: If the xp-store-composition feature is enabled, +//! entrypoints can also be URL strings, which are created as +//! anonymous stores. Instantiations of the same URL will +//! result in a new, distinct anonymous store each time, so creating +//! two `memory://` stores with this method will not share the same view. +//! This behavior might change in the future. + +use erased_serde::deserialize; +use futures::future::BoxFuture; +use futures::FutureExt; +use serde::de::DeserializeOwned; +use serde_tagged::de::{BoxFnSeed, SeedFactory}; +use serde_tagged::util::TagString; +use std::any::{Any, TypeId}; +use std::cell::Cell; +use std::collections::BTreeMap; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::{Arc, LazyLock}; +use tonic::async_trait; + +/// Resolves tag names to the corresponding Config type. +// Registry implementation details: +// This is really ugly. Really we would want to store this as a generic static field: +// +// ``` +// struct Registry<T>(BTreeMap<(&'static str), RegistryEntry<T>); +// static REG<T>: Registry<T>; +// ``` +// +// so that one version of the static is generated for each Type that the registry is accessed for. +// However, this is not possible, because generics are only a thing in functions, and even there +// they will not interact with static items: +// https://doc.rust-lang.org/reference/items/static-items.html#statics--generics +// +// So instead, we make this lookup at runtime by putting the TypeId into the key. +// But now we can no longer store the `BoxFnSeed<T>` because we are lacking the generic parameter +// T, so instead store it as `Box<dyn Any>` and downcast to `&BoxFnSeed<T>` when performing the +// lookup. +// I said it was ugly... +#[derive(Default)] +pub struct Registry(BTreeMap<(TypeId, &'static str), Box<dyn Any + Sync>>); +pub type FromUrlSeed<T> = + Box<dyn Fn(url::Url) -> Result<T, Box<dyn std::error::Error + Send + Sync>> + Sync>; +pub struct RegistryEntry<T> { + serde_deserialize_seed: BoxFnSeed<DeserializeWithRegistry<T>>, + from_url_seed: FromUrlSeed<DeserializeWithRegistry<T>>, +} + +struct RegistryWithFakeType<'r, T>(&'r Registry, PhantomData<T>); + +impl<'r, 'de: 'r, T: 'static> SeedFactory<'de, TagString<'de>> for RegistryWithFakeType<'r, T> { + type Value = DeserializeWithRegistry<T>; + type Seed = &'r BoxFnSeed<Self::Value>; + + // Required method + fn seed<E>(self, tag: TagString<'de>) -> Result<Self::Seed, E> + where + E: serde::de::Error, + { + // using find() and not get() because of https://github.com/rust-lang/rust/issues/80389 + let seed: &Box<dyn Any + Sync> = self + .0 + .0 + .iter() + .find(|(k, _)| *k == &(TypeId::of::<T>(), tag.as_ref())) + .ok_or_else(|| serde::de::Error::custom(format!("Unknown type: {}", tag)))? + .1; + + let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap(); + + Ok(&entry.serde_deserialize_seed) + } +} + +/// Wrapper type which implements Deserialize using the registry +/// +/// Wrap your type in this in order to deserialize it using a registry, e.g. +/// `RegistryWithFakeType<Box<dyn MyTrait>>`, then the types registered for `Box<dyn MyTrait>` +/// will be used. +pub struct DeserializeWithRegistry<T>(pub T); + +impl Registry { + /// Registers a mapping from type tag to a concrete type into the registry. + /// + /// The type parameters are very important: + /// After calling `register::<Box<dyn FooTrait>, FooStruct>("footype")`, when a user + /// deserializes into an input with the type tag "myblobservicetype" into a + /// `Box<dyn FooTrait>`, it will first call the Deserialize imple of `FooStruct` and + /// then convert it into a `Box<dyn FooTrait>` using From::from. + pub fn register< + T: 'static, + C: DeserializeOwned + + TryFrom<url::Url, Error = Box<dyn std::error::Error + Send + Sync>> + + Into<T>, + >( + &mut self, + type_name: &'static str, + ) { + self.0.insert( + (TypeId::of::<T>(), type_name), + Box::new(RegistryEntry { + serde_deserialize_seed: BoxFnSeed::new(|x| { + deserialize::<C>(x) + .map(Into::into) + .map(DeserializeWithRegistry) + }), + from_url_seed: Box::new(|url| { + C::try_from(url) + .map(Into::into) + .map(DeserializeWithRegistry) + }), + }), + ); + } +} + +impl<'de, T: 'static> serde::Deserialize<'de> for DeserializeWithRegistry<T> { + fn deserialize<D>(de: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + serde_tagged::de::internal::deserialize( + de, + "type", + RegistryWithFakeType(ACTIVE_REG.get().unwrap(), PhantomData::<T>), + ) + } +} + +#[derive(Debug, thiserror::Error)] +enum TryFromUrlError { + #[error("Unknown type: {0}")] + UnknownTag(String), +} + +impl<T: 'static> TryFrom<url::Url> for DeserializeWithRegistry<T> { + type Error = Box<dyn std::error::Error + Send + Sync>; + fn try_from(url: url::Url) -> Result<Self, Self::Error> { + let tag = url.scheme().split('+').next().unwrap(); + // same as in the SeedFactory impl: using find() and not get() because of https://github.com/rust-lang/rust/issues/80389 + let seed = ACTIVE_REG + .get() + .unwrap() + .0 + .iter() + .find(|(k, _)| *k == &(TypeId::of::<T>(), tag)) + .ok_or_else(|| Box::new(TryFromUrlError::UnknownTag(tag.into())))? + .1; + let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap(); + (entry.from_url_seed)(url) + } +} + +thread_local! { + /// The active Registry is global state, because there is no convenient and universal way to pass state + /// into the functions usually used for deserialization, e.g. `serde_json::from_str`, `toml::from_str`, + /// `serde_qs::from_str`. + static ACTIVE_REG: Cell<Option<&'static Registry>> = panic!("reg was accessed before initialization"); +} + +/// Run the provided closure with a registry context. +/// Any serde deserialize calls within the closure will use the registry to resolve tag names to +/// the corresponding Config type. +pub fn with_registry<R>(reg: &'static Registry, f: impl FnOnce() -> R) -> R { + ACTIVE_REG.set(Some(reg)); + let result = f(); + ACTIVE_REG.set(None); + result +} + +/// The provided registry of tvix_castore, with all builtin BlobStore/DirectoryStore implementations +pub static REG: LazyLock<&'static Registry> = LazyLock::new(|| { + let mut reg = Default::default(); + add_default_services(&mut reg); + // explicitly leak to get an &'static, so that we gain `&Registry: Send` from `Registry: Sync` + Box::leak(Box::new(reg)) +}); + +// ---------- End of generic registry code --------- // + +/// Register the builtin services of tvix_castore with the given registry. +/// This is useful for creating your own registry with the builtin types _and_ +/// extra third party types. +pub fn add_default_services(reg: &mut Registry) { + crate::blobservice::register_blob_services(reg); + crate::directoryservice::register_directory_services(reg); +} + +pub struct CompositionContext<'a> { + // The stack used to detect recursive instantiations and prevent deadlocks + // The TypeId of the trait object is included to distinguish e.g. the + // BlobService "default" and the DirectoryService "default". + stack: Vec<(TypeId, String)>, + registry: &'static Registry, + composition: Option<&'a Composition>, +} + +impl<'a> CompositionContext<'a> { + /// Get a composition context for one-off store creation. + pub fn blank(registry: &'static Registry) -> Self { + Self { + registry, + stack: Default::default(), + composition: None, + } + } + + pub async fn resolve<T: ?Sized + Send + Sync + 'static>( + &self, + entrypoint: String, + ) -> Result<Arc<T>, Box<dyn std::error::Error + Send + Sync + 'static>> { + // disallow recursion + if self + .stack + .contains(&(TypeId::of::<T>(), entrypoint.clone())) + { + return Err(CompositionError::Recursion( + self.stack.iter().map(|(_, n)| n.clone()).collect(), + ) + .into()); + } + + Ok(self.build_internal(entrypoint).await?) + } + + #[cfg(feature = "xp-store-composition")] + async fn build_anonymous<T: ?Sized + Send + Sync + 'static>( + &self, + entrypoint: String, + ) -> Result<Arc<T>, Box<dyn std::error::Error + Send + Sync>> { + let url = url::Url::parse(&entrypoint)?; + let config: DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>> = + with_registry(self.registry, || url.try_into())?; + config.0.build("anonymous", self).await + } + + fn build_internal<T: ?Sized + Send + Sync + 'static>( + &self, + entrypoint: String, + ) -> BoxFuture<'_, Result<Arc<T>, CompositionError>> { + #[cfg(feature = "xp-store-composition")] + if entrypoint.contains("://") { + // There is a chance this is a url. we are building an anonymous store + return Box::pin(async move { + self.build_anonymous(entrypoint.clone()) + .await + .map_err(|e| CompositionError::Failed(entrypoint, Arc::from(e))) + }); + } + + let mut stores = match self.composition { + Some(comp) => comp.stores.lock().unwrap(), + None => return Box::pin(futures::future::err(CompositionError::NotFound(entrypoint))), + }; + let entry = match stores.get_mut(&(TypeId::of::<T>(), entrypoint.clone())) { + Some(v) => v, + None => return Box::pin(futures::future::err(CompositionError::NotFound(entrypoint))), + }; + // for lifetime reasons, we put a placeholder value in the hashmap while we figure out what + // the new value should be. the Mutex stays locked the entire time, so nobody will ever see + // this temporary value. + let prev_val = std::mem::replace( + entry, + Box::new(InstantiationState::<T>::Done(Err( + CompositionError::Poisoned(entrypoint.clone()), + ))), + ); + let (new_val, ret) = match *prev_val.downcast::<InstantiationState<T>>().unwrap() { + InstantiationState::Done(service) => ( + InstantiationState::Done(service.clone()), + futures::future::ready(service).boxed(), + ), + // the construction of the store has not started yet. + InstantiationState::Config(config) => { + let (tx, rx) = tokio::sync::watch::channel(None); + ( + InstantiationState::InProgress(rx), + (async move { + let mut new_context = CompositionContext { + composition: self.composition, + registry: self.registry, + stack: self.stack.clone(), + }; + new_context + .stack + .push((TypeId::of::<T>(), entrypoint.clone())); + let res = + config.build(&entrypoint, &new_context).await.map_err(|e| { + match e.downcast() { + Ok(e) => *e, + Err(e) => CompositionError::Failed(entrypoint, e.into()), + } + }); + tx.send(Some(res.clone())).unwrap(); + res + }) + .boxed(), + ) + } + // there is already a task driving forward the construction of this store, wait for it + // to notify us via the provided channel + InstantiationState::InProgress(mut recv) => { + (InstantiationState::InProgress(recv.clone()), { + (async move { + loop { + if let Some(v) = + recv.borrow_and_update().as_ref().map(|res| res.clone()) + { + break v; + } + recv.changed().await.unwrap(); + } + }) + .boxed() + }) + } + }; + *entry = Box::new(new_val); + ret + } +} + +#[async_trait] +/// This is the trait usually implemented on a per-store-type Config struct and +/// used to instantiate it. +pub trait ServiceBuilder: Send + Sync { + type Output: ?Sized; + async fn build( + &self, + instance_name: &str, + context: &CompositionContext, + ) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>>; +} + +impl<T: ?Sized, S: ServiceBuilder<Output = T> + 'static> From<S> + for Box<dyn ServiceBuilder<Output = T>> +{ + fn from(t: S) -> Self { + Box::new(t) + } +} + +enum InstantiationState<T: ?Sized> { + Config(Box<dyn ServiceBuilder<Output = T>>), + InProgress(tokio::sync::watch::Receiver<Option<Result<Arc<T>, CompositionError>>>), + Done(Result<Arc<T>, CompositionError>), +} + +pub struct Composition { + registry: &'static Registry, + stores: std::sync::Mutex<HashMap<(TypeId, String), Box<dyn Any + Send + Sync>>>, +} + +#[derive(thiserror::Error, Clone, Debug)] +pub enum CompositionError { + #[error("store not found: {0}")] + NotFound(String), + #[error("recursion not allowed {0:?}")] + Recursion(Vec<String>), + #[error("store construction panicked {0}")] + Poisoned(String), + #[error("instantiation of service {0} failed: {1}")] + Failed(String, Arc<dyn std::error::Error + Send + Sync>), +} + +impl<T: ?Sized + Send + Sync + 'static> + Extend<( + String, + DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>, + )> for Composition +{ + fn extend<I>(&mut self, configs: I) + where + I: IntoIterator< + Item = ( + String, + DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>, + ), + >, + { + self.stores + .lock() + .unwrap() + .extend(configs.into_iter().map(|(k, v)| { + ( + (TypeId::of::<T>(), k), + Box::new(InstantiationState::Config(v.0)) as Box<dyn Any + Send + Sync>, + ) + })) + } +} + +impl Composition { + /// The given registry will be used for creation of anonymous stores during composition + pub fn new(registry: &'static Registry) -> Self { + Self { + registry, + stores: Default::default(), + } + } + + pub fn extend_with_configs<T: ?Sized + Send + Sync + 'static>( + &mut self, + // Keep the concrete `HashMap` type here since it allows for type + // inference of what type is previously being deserialized. + configs: HashMap<String, DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>>, + ) { + self.extend(configs); + } + + /// Looks up the entrypoint name in the composition and returns an instantiated service. + pub async fn build<T: ?Sized + Send + Sync + 'static>( + &self, + entrypoint: &str, + ) -> Result<Arc<T>, CompositionError> { + self.context().build_internal(entrypoint.to_string()).await + } + + pub fn context(&self) -> CompositionContext { + CompositionContext { + registry: self.registry, + stack: vec![], + composition: Some(self), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::blobservice::BlobService; + use std::sync::Arc; + + /// Test that we return a reference to the same instance of MemoryBlobService (via ptr_eq) + /// when instantiating the same entrypoint twice. By instantiating concurrently, we also + /// test the channels notifying the second consumer when the store has been instantiated. + #[tokio::test] + async fn concurrent() { + let blob_services_configs_json = serde_json::json!({ + "default": { + "type": "memory", + } + }); + + let blob_services_configs = + with_registry(®, || serde_json::from_value(blob_services_configs_json)).unwrap(); + let mut blob_service_composition = Composition::new(®); + blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs); + let (blob_service1, blob_service2) = tokio::join!( + blob_service_composition.build::<dyn BlobService>("default"), + blob_service_composition.build::<dyn BlobService>("default") + ); + assert!(Arc::ptr_eq( + &blob_service1.unwrap(), + &blob_service2.unwrap() + )); + } + + /// Test that we throw the correct error when an instantiation would recurse (deadlock) + #[tokio::test] + async fn reject_recursion() { + let blob_services_configs_json = serde_json::json!({ + "default": { + "type": "combined", + "local": "other", + "remote": "other" + }, + "other": { + "type": "combined", + "local": "default", + "remote": "default" + } + }); + + let blob_services_configs = + with_registry(®, || serde_json::from_value(blob_services_configs_json)).unwrap(); + let mut blob_service_composition = Composition::new(®); + blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs); + match blob_service_composition + .build::<dyn BlobService>("default") + .await + { + Err(CompositionError::Recursion(stack)) => { + assert_eq!(stack, vec!["default".to_string(), "other".to_string()]) + } + other => panic!("should have returned an error, returned: {:?}", other.err()), + } + } +} |