diff --git a/src/core/data_loader/dedupe.rs b/src/core/data_loader/dedupe.rs index 4ab458aa4e..fc96bdae9d 100644 --- a/src/core/data_loader/dedupe.rs +++ b/src/core/data_loader/dedupe.rs @@ -8,8 +8,8 @@ use tokio::sync::broadcast; pub trait Key: Send + Sync + Eq + Hash + Clone {} impl Key for A {} -pub trait Value: Send + Sync + Clone {} -impl Value for A {} +pub trait Value: Send + Sync {} +impl Value for A {} /// /// Allows deduplication of async operations based on a key. @@ -25,25 +25,25 @@ pub struct Dedupe { /// Represents the current state of the operation. enum State { /// Means that the operation has been executed and the result is stored. - Ready(Value), + Ready(Arc), /// Means that the operation is in progress and the result can be sent via /// the stored sender whenever it's available in the future. - Pending(Weak>), + Pending(Weak>>), } /// Represents the next steps enum Step { /// The operation has been executed and the result must be returned. - Return(Value), + Return(Arc), /// The operation is in progress and the result must be awaited on the /// receiver. - Await(broadcast::Receiver), + Await(broadcast::Receiver>), /// The operation needs to be executed and the result needs to be sent to /// the provided sender. - Init(Arc>), + Init(Arc>>), } impl Dedupe { @@ -51,10 +51,10 @@ impl Dedupe { Self { cache: Arc::new(Mutex::new(HashMap::new())), size, persist } } - pub async fn dedupe<'a, Fn, Fut>(&'a self, key: &'a K, or_else: Fn) -> V + pub async fn dedupe<'a, Fn, Fut>(&'a self, key: &'a K, or_else: Fn) -> Arc where Fn: FnOnce() -> Fut, - Fut: Future, + Fut: Future>, { loop { let value = match self.step(key) { @@ -123,10 +123,10 @@ impl DedupeResult { } impl DedupeResult { - pub async fn dedupe<'a, Fn, Fut>(&'a self, key: &'a K, or_else: Fn) -> Result + pub async fn dedupe<'a, Fn, Fut>(&'a self, key: &'a K, or_else: Fn) -> Arc> where Fn: FnOnce() -> Fut, - Fut: Future>, + Fut: Future>>, { self.0.dedupe(key, or_else).await } @@ -147,17 +147,17 @@ mod tests { #[tokio::test] async fn test_no_key() { let cache = Arc::new(Dedupe::::new(1000, true)); - let actual = cache.dedupe(&1, || Box::pin(async { 1 })).await; - assert_eq!(actual, 1); + let actual = cache.dedupe(&1, || Box::pin(async { Arc::new(1) })).await; + assert_eq!(*actual, 1); } #[tokio::test] async fn test_with_key() { let cache = Arc::new(Dedupe::::new(1000, true)); - cache.dedupe(&1, || Box::pin(async { 1 })).await; + cache.dedupe(&1, || Box::pin(async { Arc::new(1) })).await; - let actual = cache.dedupe(&1, || Box::pin(async { 2 })).await; - assert_eq!(actual, 1); + let actual = cache.dedupe(&1, || Box::pin(async { Arc::new(2) })).await; + assert_eq!(*actual, 1); } #[tokio::test] @@ -165,11 +165,13 @@ mod tests { let cache = Arc::new(Dedupe::::new(1000, true)); for i in 0..100 { - cache.dedupe(&1, || Box::pin(async move { i })).await; + cache + .dedupe(&1, || Box::pin(async move { Arc::new(i) })) + .await; } - let actual = cache.dedupe(&1, || Box::pin(async { 2 })).await; - assert_eq!(actual, 0); + let actual = cache.dedupe(&1, || Box::pin(async { Arc::new(2) })).await; + assert_eq!(*actual, 0); } #[tokio::test] @@ -179,13 +181,13 @@ mod tests { let a = cache.dedupe(&1, || { Box::pin(async move { sleep(Duration::from_millis(1)).await; - 1 + Arc::new(1) }) }); let b = cache.dedupe(&1, || { Box::pin(async move { sleep(Duration::from_millis(1)).await; - 2 + Arc::new(2) }) }); let (a, b) = join!(a, b); @@ -193,10 +195,10 @@ mod tests { assert_eq!(a, b); } - async fn compute_value(counter: Arc) -> String { + async fn compute_value(counter: Arc) -> Arc { counter.fetch_add(1, Ordering::SeqCst); sleep(Duration::from_millis(1)).await; - format!("value_{}", counter.load(Ordering::SeqCst)) + Arc::new(format!("value_{}", counter.load(Ordering::SeqCst))) } #[tokio::test(worker_threads = 16, flavor = "multi_thread")] @@ -237,6 +239,7 @@ mod tests { let task = cache.dedupe(&1, move || async move { sleep(Duration::from_millis(100)).await; + Arc::new(()) }); // drops the task since the underlying sleep timeout is higher than the @@ -249,6 +252,7 @@ mod tests { cache .dedupe(&1, move || async move { sleep(Duration::from_millis(100)).await; + Arc::new(()) }) .await; } @@ -263,7 +267,7 @@ mod tests { cache_1 .dedupe(&1, move || async move { sleep(Duration::from_millis(100)).await; - 100 + Arc::new(100) }) .await }); @@ -272,7 +276,7 @@ mod tests { cache_2 .dedupe(&1, move || async move { sleep(Duration::from_millis(100)).await; - 200 + Arc::new(200) }) .await }); @@ -283,7 +287,7 @@ mod tests { task_1.abort(); let actual = task_2.await.unwrap(); - assert_eq!(actual, 200) + assert_eq!(*actual, 200) } // TODO: This is a failing test @@ -313,6 +317,7 @@ mod tests { .dedupe(&1, move || async move { sleep(Duration::from_millis(100)).await; status_1.lock().unwrap().call_1 = true; + Arc::new(()) }) .await }); @@ -326,6 +331,7 @@ mod tests { .dedupe(&1, move || async move { sleep(Duration::from_millis(120)).await; status_2.lock().unwrap().call_2 = true; + Arc::new(()) }) .await }); @@ -368,6 +374,7 @@ mod tests { .dedupe(&1, move || async move { sleep(Duration::from_millis(100)).await; status_1.lock().unwrap().call_1 = true; + Arc::new(()) }) .await }); @@ -378,6 +385,7 @@ mod tests { .dedupe(&1, move || async move { sleep(Duration::from_millis(150)).await; status_2.lock().unwrap().call_2 = true; + Arc::new(()) }) .await }); diff --git a/src/core/helpers/value.rs b/src/core/helpers/value.rs index 26a972932f..9cb5801e47 100644 --- a/src/core/helpers/value.rs +++ b/src/core/helpers/value.rs @@ -1,4 +1,7 @@ -use std::hash::{Hash, Hasher}; +use std::{ + hash::{Hash, Hasher}, + sync::Arc, +}; use async_graphql_value::ConstValue; @@ -36,3 +39,10 @@ pub fn hash(const_value: &ConstValue, state: &mut H) { } } } + +pub fn arc_result_to_result(arc_result: Arc>) -> Result { + match &*arc_result { + Ok(t) => Ok(t.clone()), + Err(e) => Err(e.clone()), + } +} diff --git a/src/core/ir/eval.rs b/src/core/ir/eval.rs index e7b0a8c179..da9d566672 100644 --- a/src/core/ir/eval.rs +++ b/src/core/ir/eval.rs @@ -10,6 +10,7 @@ use super::eval_io::eval_io; use super::model::{Cache, CacheKey, Map, IR}; use super::{Error, EvalContext, ResolverContextLike, TypedValue}; use crate::core::auth::verify::{AuthVerifier, Verify}; +use crate::core::helpers::value::arc_result_to_result; use crate::core::json::{JsonLike, JsonObjectLike}; use crate::core::merge_right::MergeRight; use crate::core::serde_value_ext::ValueExt; @@ -43,7 +44,7 @@ impl IR { expr.eval(ctx).await } - IR::IO(io) => eval_io(io, ctx).await, + IR::IO(io) => arc_result_to_result(eval_io(io, ctx).await), IR::Cache(Cache { max_age, io }) => { let io = io.deref(); let key = io.cache_key(ctx); @@ -51,7 +52,8 @@ impl IR { if let Some(val) = ctx.request_ctx.runtime.cache.get(&key).await? { Ok(val) } else { - let val = eval_io(io, ctx).await?; + let result = arc_result_to_result(eval_io(io, ctx).await); + let val = result?; ctx.request_ctx .runtime .cache @@ -60,7 +62,7 @@ impl IR { Ok(val) } } else { - eval_io(io, ctx).await + arc_result_to_result(eval_io(io, ctx).await) } } IR::Map(Map { input, map }) => { diff --git a/src/core/ir/eval_io.rs b/src/core/ir/eval_io.rs index f9ef59b3da..11a11d192d 100644 --- a/src/core/ir/eval_io.rs +++ b/src/core/ir/eval_io.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_graphql_value::ConstValue; use super::eval_http::{ @@ -14,7 +16,7 @@ use crate::core::grpc::data_loader::GrpcDataLoader; use crate::core::http::DataLoaderRequest; use crate::core::ir::Error; -pub async fn eval_io(io: &IO, ctx: &mut EvalContext<'_, Ctx>) -> Result +pub async fn eval_io(io: &IO, ctx: &mut EvalContext<'_, Ctx>) -> Arc> where Ctx: ResolverContextLike + Sync, { @@ -23,7 +25,7 @@ where let dedupe = io.dedupe(); if !dedupe || !ctx.is_query() { - return eval_io_inner(io, ctx).await; + return eval_io_inner_arc(io, ctx).await; } if let Some(key) = io.cache_key(ctx) { ctx.request_ctx @@ -31,12 +33,12 @@ where .dedupe(&key, || async { ctx.request_ctx .dedupe_handler - .dedupe(&key, || eval_io_inner(io, ctx)) + .dedupe(&key, || eval_io_inner_arc(io, ctx)) .await }) .await } else { - eval_io_inner(io, ctx).await + eval_io_inner_arc(io, ctx).await } } @@ -116,3 +118,13 @@ where } } } + +async fn eval_io_inner_arc( + io: &IO, + ctx: &mut EvalContext<'_, Ctx>, +) -> Arc> +where + Ctx: ResolverContextLike + Sync, +{ + return Arc::new(eval_io_inner(io, ctx).await); +} diff --git a/src/core/jit/graphql_executor.rs b/src/core/jit/graphql_executor.rs index e1c131c374..5533ee0b1a 100644 --- a/src/core/jit/graphql_executor.rs +++ b/src/core/jit/graphql_executor.rs @@ -12,6 +12,7 @@ use tailcall_hasher::TailcallHasher; use super::{AnyResponse, BatchResponse, Response}; use crate::core::app_context::AppContext; use crate::core::async_graphql_hyper::OperationId; +use crate::core::helpers::value::arc_result_to_result; use crate::core::http::RequestContext; use crate::core::jit::{self, ConstValueExecutor, OPHash, Pos, Positioned}; @@ -53,12 +54,12 @@ impl JITExecutor { .dedupe(&self.operation_id, || { Box::pin(async move { let resp = self.exec(exec, jit_request).await; - Ok(resp) + Arc::new(Ok(resp)) }) }) .await; - out.unwrap_or_default() + arc_result_to_result(out).unwrap_or_default() } #[inline(always)]