diff --git a/src/core/async_graphql_hyper.rs b/src/core/async_graphql_hyper.rs index 5518ea5da7..dd7be32c34 100644 --- a/src/core/async_graphql_hyper.rs +++ b/src/core/async_graphql_hyper.rs @@ -408,6 +408,17 @@ impl GraphQLArcResponse { pub fn into_response(self) -> Result> { self.build_response(StatusCode::OK, self.default_body()?) } + + /// Transforms a plain `GraphQLResponse` into a `Response`. + /// Differs as `to_response` by flattening the response's data + /// `{"data": {"user": {"name": "John"}}}` becomes `{"name": "John"}`. + pub fn into_rest_response(self) -> Result> { + if !self.response.is_ok() { + return self.build_response(StatusCode::INTERNAL_SERVER_ERROR, self.default_body()?); + } + + self.into_response() + } } #[cfg(test)] diff --git a/src/core/http/request_handler.rs b/src/core/http/request_handler.rs index e7ab5efae2..17a25233b4 100644 --- a/src/core/http/request_handler.rs +++ b/src/core/http/request_handler.rs @@ -241,20 +241,24 @@ async fn handle_rest_apis( *request.uri_mut() = request.uri().path().replace(API_URL_PREFIX, "").parse()?; let req_ctx = Arc::new(create_request_context(&request, app_ctx.as_ref())); if let Some(p_request) = app_ctx.endpoints.matches(&request) { + let (req, body) = request.into_parts(); let http_route = format!("{API_URL_PREFIX}{}", p_request.path.as_str()); req_counter.set_http_route(&http_route); let span = tracing::info_span!( "REST", - otel.name = format!("REST {} {}", request.method(), p_request.path.as_str()), + otel.name = format!("REST {} {}", req.method, p_request.path.as_str()), otel.kind = ?SpanKind::Server, - { HTTP_REQUEST_METHOD } = %request.method(), + { HTTP_REQUEST_METHOD } = %req.method, { HTTP_ROUTE } = http_route ); return async { - let graphql_request = p_request.into_request(request).await?; + let mut graphql_request = p_request.into_request(body).await?; + let operation_id = graphql_request.operation_id(&req.headers); + let exec = JITExecutor::new(app_ctx.clone(), req_ctx.clone(), operation_id) + .flatten_response(true); let mut response = graphql_request .data(req_ctx.clone()) - .execute(&app_ctx.schema) + .execute_with_jit(exec) .await .set_cache_control( app_ctx.blueprint.server.enable_cache_control_header, diff --git a/src/core/jit/error.rs b/src/core/jit/error.rs index 086b604502..25a6c388ee 100644 --- a/src/core/jit/error.rs +++ b/src/core/jit/error.rs @@ -56,6 +56,12 @@ pub enum Error { Unknown, } +impl From for Error { + fn from(value: async_graphql::ServerError) -> Self { + Self::ServerError(value) + } +} + impl ErrorExtensions for Error { fn extend(&self) -> super::graphql_error::Error { match self { diff --git a/src/core/jit/exec_const.rs b/src/core/jit/exec_const.rs index 1606f05631..b16b7fde5f 100644 --- a/src/core/jit/exec_const.rs +++ b/src/core/jit/exec_const.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use async_graphql_value::{ConstValue, Value}; +use derive_setters::Setters; use futures_util::future::join_all; use tailcall_valid::Validator; @@ -8,7 +9,6 @@ use super::context::Context; use super::exec::{Executor, IRExecutor}; use super::graphql_error::GraphQLError; use super::{transform, AnyResponse, BuildError, Error, OperationPlan, Request, Response, Result}; -use crate::core::app_context::AppContext; use crate::core::http::RequestContext; use crate::core::ir::model::IR; use crate::core::ir::{self, EmptyResolverContext, EvalContext}; @@ -16,15 +16,19 @@ use crate::core::jit::synth::Synth; use crate::core::jit::transform::InputResolver; use crate::core::json::{JsonLike, JsonLikeList}; use crate::core::Transform; +use crate::core::{app_context::AppContext, json::JsonObjectLike}; /// A specialized executor that executes with async_graphql::Value +#[derive(Setters)] pub struct ConstValueExecutor { pub plan: OperationPlan, + + flatten_response: bool, } impl From> for ConstValueExecutor { fn from(plan: OperationPlan) -> Self { - Self { plan } + Self { plan, flatten_response: false } } } @@ -56,6 +60,7 @@ impl ConstValueExecutor { let is_introspection_query = req_ctx.server.get_enable_introspection() && self.plan.is_introspection_query; + let flatten_response = self.flatten_response; let variables = &request.variables; // Attempt to skip unnecessary fields @@ -102,13 +107,45 @@ impl ConstValueExecutor { let async_req = async_graphql::Request::from(request).only_introspection(); let async_resp = app_ctx.execute(async_req).await; - resp.merge_with(&async_resp).into() + to_any_response(resp.merge_with(&async_resp), flatten_response) } else { - resp.into() + to_any_response(resp, flatten_response) } } } +fn to_any_response( + resp: Response, + flatten: bool, +) -> AnyResponse> { + if flatten { + if resp.errors.is_empty() { + AnyResponse { + body: Arc::new( + serde_json::to_vec(flatten_response(&resp.data)).unwrap_or_default(), + ), + is_ok: true, + cache_control: resp.cache_control, + } + } else { + AnyResponse { + body: Arc::new(serde_json::to_vec(&resp).unwrap_or_default()), + is_ok: false, + cache_control: resp.cache_control, + } + } + } else { + resp.into() + } +} + +fn flatten_response<'a, T: JsonLike<'a>>(data: &'a T) -> &'a T { + match data.as_object() { + Some(obj) if obj.len() == 1 => flatten_response(obj.iter().next().unwrap().1), + _ => data, + } +} + struct ConstValueExec<'a> { plan: &'a OperationPlan, req_context: &'a RequestContext, diff --git a/src/core/jit/graphql_error.rs b/src/core/jit/graphql_error.rs index 6d2e0a7132..ade6b4af61 100644 --- a/src/core/jit/graphql_error.rs +++ b/src/core/jit/graphql_error.rs @@ -53,6 +53,10 @@ impl From> for GraphQLError { return e.into(); } + if let super::Error::ServerError(e) = inner_value { + return e.into(); + } + let ext = inner_value.extend().extensions; let mut server_error = GraphQLError::new(inner_value.to_string(), Some(position)); server_error.extensions = ext; diff --git a/src/core/jit/graphql_executor.rs b/src/core/jit/graphql_executor.rs index e1c131c374..bf2c7546be 100644 --- a/src/core/jit/graphql_executor.rs +++ b/src/core/jit/graphql_executor.rs @@ -5,21 +5,23 @@ use std::sync::Arc; use async_graphql::{BatchRequest, Value}; use async_graphql_value::{ConstValue, Extensions}; +use derive_setters::Setters; use futures_util::stream::FuturesOrdered; use futures_util::StreamExt; use tailcall_hasher::TailcallHasher; use super::{AnyResponse, BatchResponse, Response}; -use crate::core::app_context::AppContext; +use crate::core::{app_context::AppContext, async_graphql_hyper::GraphQLRequest}; use crate::core::async_graphql_hyper::OperationId; use crate::core::http::RequestContext; use crate::core::jit::{self, ConstValueExecutor, OPHash, Pos, Positioned}; -#[derive(Clone)] +#[derive(Clone, Setters)] pub struct JITExecutor { app_ctx: Arc, req_ctx: Arc, operation_id: OperationId, + flatten_response: bool, } impl JITExecutor { @@ -28,7 +30,7 @@ impl JITExecutor { req_ctx: Arc, operation_id: OperationId, ) -> Self { - Self { app_ctx, req_ctx, operation_id } + Self { app_ctx, req_ctx, operation_id, flatten_response: false } } #[inline(always)] @@ -62,21 +64,20 @@ impl JITExecutor { } #[inline(always)] - fn req_hash(request: &async_graphql::Request) -> OPHash { + fn req_hash(request: &impl Hash) -> OPHash { let mut hasher = TailcallHasher::default(); - request.query.hash(&mut hasher); + request.hash(&mut hasher); OPHash::new(hasher.finish()) } } impl JITExecutor { - pub fn execute( - &self, - request: async_graphql::Request, - ) -> impl Future>> + Send + '_ { - // TODO: hash considering only the query itself ignoring specified operation and - // variables that could differ for the same query + pub fn execute(&self, request: T) -> impl Future>> + Send + '_ + where + jit::Request: TryFrom, + T: Hash + Send + 'static, + { let hash = Self::req_hash(&request); async move { @@ -84,7 +85,14 @@ impl JITExecutor { return response.clone(); } - let jit_request = jit::Request::from(request); + let jit_request = match jit::Request::try_from(request) { + Ok(request) => request, + Err(error) => { + return Response::::default() + .with_errors(vec![Positioned::new(error, Pos::default())]) + .into() + } + }; let exec = if let Some(op) = self.app_ctx.operation_plans.get(&hash) { ConstValueExecutor::from(op.value().clone()) } else { @@ -102,6 +110,7 @@ impl JITExecutor { exec }; + let exec = exec.flatten_response(self.flatten_response); let is_const = exec.plan.is_const; let is_protected = exec.plan.is_protected; @@ -125,10 +134,10 @@ impl JITExecutor { /// Execute a GraphQL batch query. pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse> { match batch_request { - BatchRequest::Single(request) => BatchResponse::Single(self.execute(request).await), + BatchRequest::Single(request) => BatchResponse::Single(self.execute(GraphQLRequest(request)).await), BatchRequest::Batch(requests) => { let futs = FuturesOrdered::from_iter( - requests.into_iter().map(|request| self.execute(request)), + requests.into_iter().map(|request| self.execute(GraphQLRequest(request))), ); let responses = futs.collect::>().await; BatchResponse::Batch(responses) diff --git a/src/core/jit/model.rs b/src/core/jit/model.rs index 9b2950a22a..d69e9d571f 100644 --- a/src/core/jit/model.rs +++ b/src/core/jit/model.rs @@ -590,7 +590,7 @@ mod test { let bp = Blueprint::try_from(&module).unwrap(); let request = Request::new(query); - let jit_request = jit::Request::from(request); + let jit_request = jit::Request::try_from(request).unwrap(); jit_request.create_plan(&bp).unwrap() } diff --git a/src/core/jit/request.rs b/src/core/jit/request.rs index f37c4a721e..4f0a81da42 100644 --- a/src/core/jit/request.rs +++ b/src/core/jit/request.rs @@ -1,38 +1,46 @@ use std::collections::HashMap; use std::ops::DerefMut; +use async_graphql::parser::types::ExecutableDocument; use async_graphql_value::ConstValue; -use serde::Deserialize; use tailcall_valid::Validator; use super::{transform, Builder, OperationPlan, Result, Variables}; -use crate::core::blueprint::Blueprint; use crate::core::transform::TransformerOps; use crate::core::Transform; +use crate::core::{async_graphql_hyper::GraphQLRequest, blueprint::Blueprint}; -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Clone)] pub struct Request { - #[serde(default)] pub query: String, - #[serde(default, rename = "operationName")] pub operation_name: Option, - #[serde(default)] pub variables: Variables, - #[serde(default)] pub extensions: HashMap, + pub parsed_query: ExecutableDocument, } // NOTE: This is hot code and should allocate minimal memory -impl From for Request { - fn from(mut value: async_graphql::Request) -> Self { +impl TryFrom for Request { + type Error = super::Error; + + fn try_from(mut value: async_graphql::Request) -> Result { let variables = std::mem::take(value.variables.deref_mut()); - Self { + Ok(Self { + parsed_query: value.parsed_query()?.clone(), query: value.query, operation_name: value.operation_name, variables: Variables::from_iter(variables.into_iter().map(|(k, v)| (k.to_string(), v))), extensions: value.extensions.0, - } + }) + } +} + +impl TryFrom for Request { + type Error = super::Error; + + fn try_from(value: GraphQLRequest) -> Result { + Self::try_from(value.0) } } @@ -41,8 +49,7 @@ impl Request { &self, blueprint: &Blueprint, ) -> Result> { - let doc = async_graphql::parser::parse_query(&self.query)?; - let builder = Builder::new(blueprint, &doc); + let builder = Builder::new(blueprint, &self.parsed_query); let plan = builder.build(self.operation_name.as_deref())?; transform::CheckConst::new() @@ -67,6 +74,7 @@ impl Request { operation_name: None, variables: Variables::new(), extensions: HashMap::new(), + parsed_query: async_graphql::parser::parse_query(query).unwrap(), } } diff --git a/src/core/rest/partial_request.rs b/src/core/rest/partial_request.rs index cfd8053483..aaa2a31658 100644 --- a/src/core/rest/partial_request.rs +++ b/src/core/rest/partial_request.rs @@ -1,9 +1,10 @@ use async_graphql::parser::types::ExecutableDocument; use async_graphql::{Name, Variables}; use async_graphql_value::ConstValue; +use hyper::Body; use super::path::Path; -use super::{Request, Result}; +use super::Result; use crate::core::async_graphql_hyper::GraphQLRequest; /// A partial GraphQLRequest that contains a parsed executable GraphQL document. @@ -16,10 +17,10 @@ pub struct PartialRequest<'a> { } impl PartialRequest<'_> { - pub async fn into_request(self, request: Request) -> Result { + pub async fn into_request(self, body: Body) -> Result { let mut variables = self.variables; if let Some(key) = self.body { - let bytes = hyper::body::to_bytes(request.into_body()).await?; + let bytes = hyper::body::to_bytes(body).await?; let body: ConstValue = serde_json::from_slice(&bytes)?; variables.insert(Name::new(key), body); }