diff --git a/Cargo.lock b/Cargo.lock index 73943d62b5..060d749c49 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1599,7 +1599,6 @@ dependencies = [ "async-graphql-value", "async-trait", "dotenvy", - "futures", "hyper 0.14.30", "reqwest", "serde_json", diff --git a/examples/extension-i18n/Cargo.toml b/examples/extension-i18n/Cargo.toml index 81433494ec..31b0bda98d 100644 --- a/examples/extension-i18n/Cargo.toml +++ b/examples/extension-i18n/Cargo.toml @@ -11,10 +11,9 @@ anyhow = { workspace = true } tokio = { workspace = true } async-graphql = { workspace = true } async-graphql-value = "7.0.3" -futures = "0.3.30" +async-trait = "0.1.80" [dev-dependencies] reqwest = { workspace = true } serde_json = { workspace = true } -async-trait = "0.1.80" hyper = { version = "0.14.28", default-features = false } \ No newline at end of file diff --git a/examples/extension-i18n/src/main.rs b/examples/extension-i18n/src/main.rs index 5944e54ddb..67993ee0b0 100644 --- a/examples/extension-i18n/src/main.rs +++ b/examples/extension-i18n/src/main.rs @@ -1,61 +1,100 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use async_graphql_value::ConstValue; use dotenvy::dotenv; -use futures::executor::block_on; use tailcall::cli::runtime; use tailcall::cli::server::Server; -use tailcall::core::blueprint::{Blueprint, ExtensionLoader}; +use tailcall::core::blueprint::{Blueprint, ExtensionTrait, PrepareContext, ProcessContext}; use tailcall::core::config::reader::ConfigReader; use tailcall::core::config::KeyValue; use tailcall::core::helpers::headers::to_mustache_headers; use tailcall::core::valid::Validator; #[derive(Clone, Debug)] -pub struct TranslateExtension; +pub struct TranslateExtension { + pub load_counter: Arc>, + pub prepare_counter: Arc>, + pub process_counter: Arc>, +} + +impl Default for TranslateExtension { + fn default() -> Self { + Self { + load_counter: Arc::new(Mutex::new(0)), + prepare_counter: Arc::new(Mutex::new(0)), + process_counter: Arc::new(Mutex::new(0)), + } + } +} -impl ExtensionLoader for TranslateExtension { - fn load(&self) {} +#[async_trait::async_trait] +impl ExtensionTrait for TranslateExtension { + fn load(&self) { + *(self.load_counter.lock().unwrap()) += 1; + } - fn prepare( + async fn prepare( &self, - ir: Box, - _params: ConstValue, + context: PrepareContext, ) -> Box { - ir + *(self.prepare_counter.lock().unwrap()) += 1; + context.ir } - fn process( + async fn process( &self, - _params: ConstValue, - value: ConstValue, + context: ProcessContext, ) -> Result { - if let ConstValue::String(value) = value { - let new_value = block_on(translate(&value)); + *(self.process_counter.lock().unwrap()) += 1; + if let ConstValue::String(value) = context.value { + let new_value = match value.as_str() { + "Multi-layered client-server neural-net" => { + "Red neuronal cliente-servidor multicapa".to_string() + } + "Leanne Graham" => "Leona Grahm".to_string(), + _ => value.to_string(), + }; Ok(ConstValue::String(new_value)) } else { - Ok(value) + Ok(context.value) } } } #[derive(Clone, Debug)] -pub struct ModifyIrExtension; +pub struct ModifyIrExtension { + pub load_counter: Arc>, + pub prepare_counter: Arc>, + pub process_counter: Arc>, +} + +impl Default for ModifyIrExtension { + fn default() -> Self { + Self { + load_counter: Arc::new(Mutex::new(0)), + prepare_counter: Arc::new(Mutex::new(0)), + process_counter: Arc::new(Mutex::new(0)), + } + } +} -impl ExtensionLoader for ModifyIrExtension { - fn load(&self) {} +#[async_trait::async_trait] +impl ExtensionTrait for ModifyIrExtension { + fn load(&self) { + *(self.load_counter.lock().unwrap()) += 1; + } - fn prepare( + async fn prepare( &self, - ir: Box, - _params: ConstValue, + context: PrepareContext, ) -> Box { + *(self.prepare_counter.lock().unwrap()) += 1; if let tailcall::core::ir::model::IR::IO(tailcall::core::ir::model::IO::Http { req_template, group_by, dl_id, http_filter, - }) = *ir + }) = *context.ir { let mut req_template = req_template; let headers = to_mustache_headers(&[KeyValue { @@ -78,33 +117,23 @@ impl ExtensionLoader for ModifyIrExtension { }); Box::new(ir) } else { - ir + context.ir } } - fn process( + async fn process( &self, - _params: ConstValue, - value: ConstValue, + context: ProcessContext, ) -> Result { - Ok(value) - } -} - -async fn translate(value: &str) -> String { - match value { - "Multi-layered client-server neural-net" => { - "Red neuronal cliente-servidor multicapa".to_string() - } - "Leanne Graham" => "Leona Grahm".to_string(), - _ => value.to_string(), + *(self.process_counter.lock().unwrap()) += 1; + Ok(context.value) } } #[tokio::main] async fn main() -> anyhow::Result<()> { - let translate_ext = Arc::new(TranslateExtension {}); - let modify_ir_ext = Arc::new(ModifyIrExtension {}); + let translate_ext = Arc::new(TranslateExtension::default()); + let modify_ir_ext = Arc::new(ModifyIrExtension::default()); if let Ok(path) = dotenv() { tracing::info!("Env file: {:?} loaded", path); } @@ -128,7 +157,6 @@ async fn main() -> anyhow::Result<()> { #[cfg(test)] mod tests { use hyper::{Body, Request}; - use serde_json::json; use tailcall::core::app_context::AppContext; use tailcall::core::async_graphql_hyper::GraphQLRequest; @@ -180,8 +208,8 @@ mod tests { #[tokio::test] async fn test_tailcall_extensions() { - let translate_ext = Arc::new(TranslateExtension {}); - let modify_ir_ext = Arc::new(ModifyIrExtension {}); + let translate_ext = Arc::new(TranslateExtension::default()); + let modify_ir_ext = Arc::new(ModifyIrExtension::default()); if let Ok(path) = dotenv() { tracing::info!("Env file: {:?} loaded", path); } @@ -194,10 +222,10 @@ mod tests { let mut extensions = config_module.extensions().clone(); extensions .plugin_extensions - .insert("translate".to_string(), translate_ext); + .insert("translate".to_string(), translate_ext.clone()); extensions .plugin_extensions - .insert("modify_ir".to_string(), modify_ir_ext); + .insert("modify_ir".to_string(), modify_ir_ext.clone()); let config_module = config_module.merge_extensions(extensions); let blueprint = Blueprint::try_from(&config_module).unwrap(); let app_context = AppContext::new(blueprint, runtime, EndpointSet::default()); @@ -236,5 +264,13 @@ mod tests { hyper::body::Bytes::from(expected_response.to_string()), "Unexpected response from server" ); + + assert_eq!(translate_ext.load_counter.lock().unwrap().to_owned(), 2); + assert_eq!(translate_ext.process_counter.lock().unwrap().to_owned(), 2); + assert_eq!(translate_ext.prepare_counter.lock().unwrap().to_owned(), 2); + + assert_eq!(modify_ir_ext.load_counter.lock().unwrap().to_owned(), 1); + assert_eq!(modify_ir_ext.process_counter.lock().unwrap().to_owned(), 1); + assert_eq!(modify_ir_ext.prepare_counter.lock().unwrap().to_owned(), 1); } } diff --git a/src/core/blueprint/operators/extension.rs b/src/core/blueprint/operators/extension.rs index b6b844e74e..96d1e9bca5 100644 --- a/src/core/blueprint/operators/extension.rs +++ b/src/core/blueprint/operators/extension.rs @@ -5,6 +5,7 @@ use crate::core::config; use crate::core::config::Field; use crate::core::ir::model::IR; use crate::core::ir::Error; +use crate::core::json::JsonLikeOwned; use crate::core::try_fold::TryFold; use crate::core::valid::Valid; @@ -54,14 +55,35 @@ pub fn update_extension<'a>( ) } -pub trait ExtensionLoader: std::fmt::Debug + Send + Sync { +pub type ExtensionLoader = dyn ExtensionTrait; + +#[async_trait::async_trait] +pub trait ExtensionTrait: std::fmt::Debug + Send + Sync { fn load(&self) {} - fn modify_inner(&self, ir: Box) -> Box { - ir + async fn prepare(&self, context: PrepareContext) -> Box; + + async fn process(&self, context: ProcessContext) -> Result; +} + +pub struct PrepareContext { + pub params: Json, + pub ir: Box, +} + +impl PrepareContext { + pub fn new(ir: Box, params: Json) -> Self { + Self { ir, params } } +} - fn prepare(&self, ir: Box, params: ConstValue) -> Box; +pub struct ProcessContext { + pub params: Json, + pub value: Json, +} - fn process(&self, params: ConstValue, value: ConstValue) -> Result; +impl ProcessContext { + pub fn new(params: Json, value: Json) -> Self { + Self { params, value } + } } diff --git a/src/core/config/config_module.rs b/src/core/config/config_module.rs index 2ecf9e62a5..4580ff8528 100644 --- a/src/core/config/config_module.rs +++ b/src/core/config/config_module.rs @@ -133,7 +133,7 @@ pub struct Extensions { pub jwks: Vec>, - pub plugin_extensions: HashMap>, + pub plugin_extensions: HashMap>, } impl Extensions { diff --git a/src/core/ir/eval.rs b/src/core/ir/eval.rs index 1e53f6ba8f..e9d86f80a3 100644 --- a/src/core/ir/eval.rs +++ b/src/core/ir/eval.rs @@ -6,6 +6,7 @@ use async_graphql_value::ConstValue; use super::eval_io::eval_io; use super::model::{Cache, CacheKey, Map, IR}; use super::{Error, EvalContext, ResolverContextLike}; +use crate::core::blueprint::{PrepareContext, ProcessContext}; use crate::core::json::JsonLike; use crate::core::serde_value_ext::ValueExt; @@ -96,14 +97,17 @@ impl IR { Ok(value) }), IR::Extension { plugin, params, ir } => { - let ir = plugin.prepare(ir.clone(), params.render_value(ctx)); + let context = PrepareContext::new(ir.clone(), params.render_value(ctx)); + let ir = plugin.prepare(context).await; let value = { let mut ctx = ctx.clone(); ir.eval(&mut ctx).await? }; - plugin.process(params.render_value(ctx), value) + let context = ProcessContext::new(params.render_value(ctx), value); + + plugin.process(context).await } } }) diff --git a/src/core/ir/model.rs b/src/core/ir/model.rs index e455976994..c6518f2ac6 100644 --- a/src/core/ir/model.rs +++ b/src/core/ir/model.rs @@ -29,7 +29,7 @@ pub enum IR { Discriminate(Discriminator, Box), Extension { // path: Vec, - plugin: Arc, + plugin: Arc, params: DynamicValue, ir: Box, }, @@ -156,10 +156,7 @@ impl IR { IR::Discriminate(discriminator, expr) => { IR::Discriminate(discriminator, expr.modify_box(modifier)) } - IR::Extension { plugin, params, ir } => { - let ir = plugin.modify_inner(ir); - IR::Extension { plugin, params, ir } - } + IR::Extension { ir: _, params: _, plugin: _ } => expr, } } }