diff --git a/Cargo.lock b/Cargo.lock index 86ce4851a6..b512556e63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -180,6 +180,7 @@ dependencies = [ "opentelemetry-semantic-conventions", "opentelemetry-zipkin", "paste", + "pin-project-lite", "prometheus", "regex", "reqwest", @@ -3873,9 +3874,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e280fbe77cc62c91527259e9442153f4688736748d24660126286329742b4c6c" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" [[package]] name = "pin-utils" diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index a1744ce343..3c2ff71e7a 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -29,6 +29,29 @@ By [@USERNAME](https://github.com/USERNAME) in https://github.com/apollographql/ ## 🚀 Features +### Add support of global rate limit and timeout. [PR #1347](https://github.com/apollographql/router/pull/1347) + +Additions to the traffic shaping plugin: +- **Global rate limit** - If you want to rate limit requests to subgraphs or to the router itself. +- **Timeout**: - Set a timeout to subgraphs and router requests. + +```yaml +traffic_shaping: + router: # Rules applied to requests from clients to the router + global_rate_limit: # Accept a maximum of 10 requests per 5 secs. Excess requests must be rejected. + capacity: 10 + interval: 5s # Must not be greater than 18_446_744_073_709_551_615 milliseconds and not less than 0 milliseconds + timeout: 50s # If a request to the router takes more than 50secs then cancel the request (30 sec by default) + subgraphs: # Rules applied to requests from the router to individual subgraphs + products: + global_rate_limit: # Accept a maximum of 10 requests per 5 secs from the router. Excess requests must be rejected. + capacity: 10 + interval: 5s # Must not be greater than 18_446_744_073_709_551_615 milliseconds and not less than 0 milliseconds + timeout: 50s # If a request to the subgraph 'products' takes more than 50secs then cancel the request (30 sec by default) +``` + +By [@bnjjj](https://github.com/bnjjj) in https://github.com/apollographql/router/pull/1347 + ## 🐛 Fixes ## 🛠 Maintenance diff --git a/apollo-router/Cargo.toml b/apollo-router/Cargo.toml index 644b3ded15..9495c6dedb 100644 --- a/apollo-router/Cargo.toml +++ b/apollo-router/Cargo.toml @@ -133,6 +133,7 @@ tower-http = { version = "0.3.4", features = [ "decompression-br", "decompression-deflate", "decompression-gzip", + "timeout" ] } tower-service = "0.3.2" tower-test = "0.4.0" @@ -144,7 +145,7 @@ tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] } url = { version = "2.2.2", features = ["serde"] } urlencoding = "2.1.0" yaml-rust = "0.4.5" - +pin-project-lite = "0.2.9" [target.'cfg(macos)'.dependencies] uname = "0.1.1" diff --git a/apollo-router/src/axum_http_server_factory.rs b/apollo-router/src/axum_http_server_factory.rs index 5e22b529eb..14faf10e20 100644 --- a/apollo-router/src/axum_http_server_factory.rs +++ b/apollo-router/src/axum_http_server_factory.rs @@ -62,6 +62,8 @@ use crate::http_server_factory::HttpServerHandle; use crate::http_server_factory::Listener; use crate::http_server_factory::NetworkStream; use crate::plugin::Handler; +use crate::plugins::traffic_shaping::Elapsed; +use crate::plugins::traffic_shaping::RateLimited; use crate::router::ApolloRouterError; use crate::router_factory::RouterServiceFactory; @@ -504,6 +506,14 @@ where match service.call(Request::from_parts(head, body).into()).await { Err(e) => { + if let Some(source_err) = e.source() { + if source_err.is::() { + return RateLimited::new().into_response(); + } + if source_err.is::() { + return Elapsed::new().into_response(); + } + } tracing::error!("router service call failed: {}", e); ( StatusCode::INTERNAL_SERVER_ERROR, @@ -559,6 +569,15 @@ where } Err(e) => { tracing::error!("router service is not available to process request: {}", e); + if let Some(source_err) = e.source() { + if source_err.is::() { + return RateLimited::new().into_response(); + } + if source_err.is::() { + return Elapsed::new().into_response(); + } + } + ( StatusCode::SERVICE_UNAVAILABLE, "router service is not available to process request", diff --git a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap index eaa127e84b..622c2fa8e1 100644 --- a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap +++ b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap @@ -1,5 +1,6 @@ --- source: apollo-router/src/configuration/mod.rs +assertion_line: 914 expression: "&schema" --- { @@ -1887,6 +1888,7 @@ expression: "&schema" "type": "object", "properties": { "all": { + "description": "Applied on all subgraphs", "type": "object", "properties": { "compression": { @@ -1899,16 +1901,79 @@ expression: "&schema" ], "nullable": true }, + "global_rate_limit": { + "description": "Enable global rate limiting", + "type": "object", + "required": [ + "capacity", + "interval" + ], + "properties": { + "capacity": { + "description": "Number of requests allowed", + "type": "integer", + "format": "uint64", + "minimum": 1.0 + }, + "interval": { + "description": "Per interval", + "type": "string" + } + }, + "additionalProperties": false, + "nullable": true + }, "query_deduplication": { "description": "Enable query deduplication", "type": "boolean", "nullable": true + }, + "timeout": { + "description": "Enable timeout for incoming requests", + "default": null, + "type": "string" + } + }, + "additionalProperties": false, + "nullable": true + }, + "router": { + "description": "Applied at the router level", + "type": "object", + "properties": { + "global_rate_limit": { + "description": "Enable global rate limiting", + "type": "object", + "required": [ + "capacity", + "interval" + ], + "properties": { + "capacity": { + "description": "Number of requests allowed", + "type": "integer", + "format": "uint64", + "minimum": 1.0 + }, + "interval": { + "description": "Per interval", + "type": "string" + } + }, + "additionalProperties": false, + "nullable": true + }, + "timeout": { + "description": "Enable timeout for incoming requests", + "default": null, + "type": "string" } }, "additionalProperties": false, "nullable": true }, "subgraphs": { + "description": "Applied on specific subgraphs", "type": "object", "additionalProperties": { "type": "object", @@ -1923,10 +1988,37 @@ expression: "&schema" ], "nullable": true }, + "global_rate_limit": { + "description": "Enable global rate limiting", + "type": "object", + "required": [ + "capacity", + "interval" + ], + "properties": { + "capacity": { + "description": "Number of requests allowed", + "type": "integer", + "format": "uint64", + "minimum": 1.0 + }, + "interval": { + "description": "Per interval", + "type": "string" + } + }, + "additionalProperties": false, + "nullable": true + }, "query_deduplication": { "description": "Enable query deduplication", "type": "boolean", "nullable": true + }, + "timeout": { + "description": "Enable timeout for incoming requests", + "default": null, + "type": "string" } }, "additionalProperties": false diff --git a/apollo-router/src/plugins/telemetry/mod.rs b/apollo-router/src/plugins/telemetry/mod.rs index befd75a6ba..bdf0978b9a 100644 --- a/apollo-router/src/plugins/telemetry/mod.rs +++ b/apollo-router/src/plugins/telemetry/mod.rs @@ -225,8 +225,6 @@ impl Plugin for Telemetry { // The trace provider will not be shut down if drop is not called and it will result in a hang. // Don't add anything fallible after the tracer provider has been created. let tracer_provider = Self::create_tracer_provider(&config)?; - // - // let metrics_response_queries = Self::create_metrics_queries(&config)?; let plugin = Ok(Telemetry { spaceport_shutdown: shutdown_tx, @@ -236,7 +234,6 @@ impl Plugin for Telemetry { meter_provider: builder.meter_provider(), apollo_metrics_sender: builder.apollo_metrics_provider(), config, - // metrics_response_queries, }); // We're safe now for shutdown. diff --git a/apollo-router/src/plugins/traffic_shaping/mod.rs b/apollo-router/src/plugins/traffic_shaping/mod.rs index c1b24ee7bc..973d18ef8f 100644 --- a/apollo-router/src/plugins/traffic_shaping/mod.rs +++ b/apollo-router/src/plugins/traffic_shaping/mod.rs @@ -10,8 +10,13 @@ //! mod deduplication; +mod rate; +mod timeout; use std::collections::HashMap; +use std::num::NonZeroU64; +use std::sync::Mutex; +use std::time::Duration; use http::header::ACCEPT_ENCODING; use http::header::CONTENT_ENCODING; @@ -23,17 +28,29 @@ use tower::BoxError; use tower::ServiceBuilder; use tower::ServiceExt; +use self::rate::RateLimitLayer; +pub(crate) use self::rate::RateLimited; +pub(crate) use self::timeout::Elapsed; +use self::timeout::TimeoutLayer; use crate::layers::ServiceBuilderExt; use crate::plugin::Plugin; use crate::plugin::PluginInit; use crate::plugins::traffic_shaping::deduplication::QueryDeduplicationLayer; use crate::register_plugin; use crate::services::subgraph_service::Compression; +use crate::services::RouterRequest; +use crate::services::RouterResponse; +use crate::ConfigurationError; use crate::QueryPlannerRequest; use crate::QueryPlannerResponse; use crate::SubgraphRequest; use crate::SubgraphResponse; +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); +trait Merge { + fn merge(&self, fallback: Option<&Self>) -> Self; +} + #[derive(PartialEq, Debug, Clone, Deserialize, JsonSchema)] #[serde(deny_unknown_fields)] struct Shaping { @@ -41,33 +58,86 @@ struct Shaping { query_deduplication: Option, /// Enable compression for subgraphs (available compressions are deflate, br, gzip) compression: Option, + /// Enable global rate limiting + global_rate_limit: Option, + #[serde(deserialize_with = "humantime_serde::deserialize", default)] + #[schemars(with = "String", default)] + /// Enable timeout for incoming requests + timeout: Option, } -impl Shaping { - fn merge(&self, fallback: Option<&Shaping>) -> Shaping { +impl Merge for Shaping { + fn merge(&self, fallback: Option<&Self>) -> Self { match fallback { None => self.clone(), Some(fallback) => Shaping { query_deduplication: self.query_deduplication.or(fallback.query_deduplication), compression: self.compression.or(fallback.compression), + timeout: self.timeout.or(fallback.timeout), + global_rate_limit: self + .global_rate_limit + .as_ref() + .or(fallback.global_rate_limit.as_ref()) + .cloned(), }, } } } +#[derive(PartialEq, Debug, Clone, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct RouterShaping { + /// Enable global rate limiting + global_rate_limit: Option, + #[serde(deserialize_with = "humantime_serde::deserialize", default)] + #[schemars(with = "String", default)] + /// Enable timeout for incoming requests + timeout: Option, +} + #[derive(PartialEq, Debug, Clone, Deserialize, JsonSchema)] #[serde(deny_unknown_fields)] struct Config { #[serde(default)] + /// Applied at the router level + router: Option, + #[serde(default)] + /// Applied on all subgraphs all: Option, #[serde(default)] + /// Applied on specific subgraphs subgraphs: HashMap, /// Enable variable deduplication optimization when sending requests to subgraphs (https://github.com/apollographql/router/issues/87) variables_deduplication: Option, } +#[derive(PartialEq, Debug, Clone, Deserialize, JsonSchema)] +#[serde(deny_unknown_fields)] +struct RateLimitConf { + /// Number of requests allowed + capacity: NonZeroU64, + #[serde(deserialize_with = "humantime_serde::deserialize")] + #[schemars(with = "String")] + /// Per interval + interval: Duration, +} + +impl Merge for RateLimitConf { + fn merge(&self, fallback: Option<&Self>) -> Self { + match fallback { + None => self.clone(), + Some(fallback) => Self { + capacity: fallback.capacity, + interval: fallback.interval, + }, + } + } +} + struct TrafficShaping { config: Config, + rate_limit_router: Option, + rate_limit_subgraphs: Mutex>, } #[async_trait::async_trait] @@ -75,11 +145,53 @@ impl Plugin for TrafficShaping { type Config = Config; async fn new(init: PluginInit) -> Result { + let rate_limit_router = init + .config + .router + .as_ref() + .and_then(|r| r.global_rate_limit.as_ref()) + .map(|router_rate_limit_conf| { + if router_rate_limit_conf.interval.as_millis() > u64::MAX as u128 { + Err(ConfigurationError::InvalidConfiguration { + message: "bad configuration for traffic_shaping plugin", + error: format!( + "cannot set an interval for the rate limit greater than {} ms", + u64::MAX + ), + }) + } else { + Ok(RateLimitLayer::new( + router_rate_limit_conf.capacity, + router_rate_limit_conf.interval, + )) + } + }) + .transpose()?; + Ok(Self { config: init.config, + rate_limit_router, + rate_limit_subgraphs: Mutex::new(HashMap::new()), }) } + fn router_service( + &self, + service: BoxService, + ) -> BoxService { + ServiceBuilder::new() + .layer(TimeoutLayer::new( + self.config + .router + .as_ref() + .and_then(|r| r.timeout) + .unwrap_or(DEFAULT_TIMEOUT), + )) + .option_layer(self.rate_limit_router.clone()) + .service(service) + .boxed() + } + fn subgraph_service( &self, name: &str, @@ -91,6 +203,16 @@ impl Plugin for TrafficShaping { let final_config = Self::merge_config(all_config, subgraph_config); if let Some(config) = final_config { + let rate_limit = config.global_rate_limit.as_ref().map(|rate_limit_conf| { + self.rate_limit_subgraphs + .lock() + .unwrap() + .entry(name.to_string()) + .or_insert_with(|| { + RateLimitLayer::new(rate_limit_conf.capacity, rate_limit_conf.interval) + }) + .clone() + }); ServiceBuilder::new() .option_layer(config.query_deduplication.unwrap_or_default().then(|| { // Buffer is required because dedup layer requires a clone service. @@ -98,6 +220,12 @@ impl Plugin for TrafficShaping { .layer(QueryDeduplicationLayer::default()) .buffered() })) + .layer(TimeoutLayer::new( + config + .timeout + .unwrap_or(DEFAULT_TIMEOUT), + )) + .option_layer(rate_limit) .service(service) .map_request(move |mut req: SubgraphRequest| { if let Some(compression) = config.compression { @@ -132,10 +260,10 @@ impl Plugin for TrafficShaping { } impl TrafficShaping { - fn merge_config( - all_config: Option<&Shaping>, - subgraph_config: Option<&Shaping>, - ) -> Option { + fn merge_config( + all_config: Option<&T>, + subgraph_config: Option<&T>, + ) -> Option { let merged_subgraph_config = subgraph_config.map(|c| c.merge(all_config)); merged_subgraph_config.or_else(|| all_config.cloned()) } @@ -148,6 +276,7 @@ mod test { use std::sync::Arc; use once_cell::sync::Lazy; + use serde_json_bytes::json; use serde_json_bytes::ByteString; use serde_json_bytes::Value; use tower::util::BoxCloneService; @@ -156,6 +285,7 @@ mod test { use super::*; use crate::graphql::Response; use crate::json_ext::Object; + use crate::plugin::test::MockRouterService; use crate::plugin::test::MockSubgraph; use crate::plugin::DynPlugin; use crate::PluggableRouterServiceBuilder; @@ -323,7 +453,7 @@ mod test { ) .unwrap(); - assert_eq!(TrafficShaping::merge_config(None, None), None); + assert_eq!(TrafficShaping::merge_config::(None, None), None); assert_eq!( TrafficShaping::merge_config(config.all.as_ref(), None), config.all @@ -339,4 +469,93 @@ mod test { config.subgraphs.get("products") ); } + + #[tokio::test] + async fn it_rate_limit_subgraph_requests() { + let config = serde_yaml::from_str::( + r#" + subgraphs: + test: + global_rate_limit: + capacity: 1 + interval: 300ms + timeout: 500ms + "#, + ) + .unwrap(); + + let plugin = get_traffic_shaping_plugin(&config).await; + + let test_service = MockSubgraph::new(HashMap::new()); + + let _response = plugin + .subgraph_service("test", test_service.clone().boxed()) + .oneshot(SubgraphRequest::fake_builder().build()) + .await + .unwrap(); + let _response = plugin + .subgraph_service("test", test_service.clone().boxed()) + .oneshot(SubgraphRequest::fake_builder().build()) + .await + .expect_err("should be in error due to a timeout and rate limit"); + let _response = plugin + .subgraph_service("another", test_service.clone().boxed()) + .oneshot(SubgraphRequest::fake_builder().build()) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(300)).await; + let _response = plugin + .subgraph_service("test", test_service.boxed()) + .oneshot(SubgraphRequest::fake_builder().build()) + .await + .unwrap(); + } + + #[tokio::test] + async fn it_rate_limit_router_requests() { + let config = serde_yaml::from_str::( + r#" + router: + global_rate_limit: + capacity: 1 + interval: 300ms + timeout: 500ms + "#, + ) + .unwrap(); + + let plugin = get_traffic_shaping_plugin(&config).await; + let mut mock_service = MockRouterService::new(); + mock_service.expect_call().times(2).returning(move |_| { + Ok(RouterResponse::fake_builder() + .data(json!({ "test": 1234_u32 })) + .build() + .unwrap()) + }); + let mock_service = mock_service.build(); + + let _response = plugin + .router_service(mock_service.clone().boxed()) + .oneshot(RouterRequest::fake_builder().build().unwrap()) + .await + .unwrap() + .next_response() + .await + .unwrap(); + + assert!(plugin + .router_service(mock_service.clone().boxed()) + .oneshot(RouterRequest::fake_builder().build().unwrap()) + .await + .is_err()); + tokio::time::sleep(Duration::from_millis(300)).await; + let _response = plugin + .router_service(mock_service.clone().boxed()) + .oneshot(RouterRequest::fake_builder().build().unwrap()) + .await + .unwrap() + .next_response() + .await + .unwrap(); + } } diff --git a/apollo-router/src/plugins/traffic_shaping/rate/error.rs b/apollo-router/src/plugins/traffic_shaping/rate/error.rs new file mode 100644 index 0000000000..6e06c5823a --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/rate/error.rs @@ -0,0 +1,32 @@ +//! Error types + +use std::error; +use std::fmt; + +use axum::response::IntoResponse; +use http::StatusCode; + +/// The rate limit error. +#[derive(Debug, Default)] +pub(crate) struct RateLimited; + +impl RateLimited { + /// Construct a new RateLimited error + pub(crate) fn new() -> Self { + RateLimited {} + } +} + +impl fmt::Display for RateLimited { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("your request has been rate limited") + } +} + +impl IntoResponse for RateLimited { + fn into_response(self) -> axum::response::Response { + (StatusCode::TOO_MANY_REQUESTS, self.to_string()).into_response() + } +} + +impl error::Error for RateLimited {} diff --git a/apollo-router/src/plugins/traffic_shaping/rate/future.rs b/apollo-router/src/plugins/traffic_shaping/rate/future.rs new file mode 100644 index 0000000000..cb3ed8eb1e --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/rate/future.rs @@ -0,0 +1,39 @@ +//! Future types + +use std::future::Future; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use pin_project_lite::pin_project; + +pin_project! { + #[derive(Debug)] + pub(crate) struct ResponseFuture { + #[pin] + response: T, + } +} + +impl ResponseFuture { + pub(crate) fn new(response: T) -> Self { + ResponseFuture { response } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match this.response.poll(cx) { + Poll::Ready(v) => Poll::Ready(v.map_err(Into::into)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/apollo-router/src/plugins/traffic_shaping/rate/layer.rs b/apollo-router/src/plugins/traffic_shaping/rate/layer.rs new file mode 100644 index 0000000000..107b394a70 --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/rate/layer.rs @@ -0,0 +1,53 @@ +use std::num::NonZeroU64; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use std::time::Duration; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +use tower::Layer; + +use super::Rate; +use super::RateLimit; +/// Enforces a rate limit on the number of requests the underlying +/// service can handle over a period of time. +#[derive(Debug, Clone)] +pub(crate) struct RateLimitLayer { + rate: Rate, + window_start: Arc, + previous_nb_requests: Arc, + current_nb_requests: Arc, +} + +impl RateLimitLayer { + /// Create new rate limit layer. + pub(crate) fn new(num: NonZeroU64, per: Duration) -> Self { + let rate = Rate::new(num, per); + RateLimitLayer { + rate, + window_start: Arc::new(AtomicU64::new( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after EPOCH") + .as_millis() as u64, + )), + previous_nb_requests: Arc::default(), + current_nb_requests: Arc::new(AtomicUsize::new(1)), + } + } +} + +impl Layer for RateLimitLayer { + type Service = RateLimit; + + fn layer(&self, service: S) -> Self::Service { + RateLimit { + inner: service, + rate: self.rate, + window_start: self.window_start.clone(), + previous_nb_requests: self.previous_nb_requests.clone(), + current_nb_requests: self.current_nb_requests.clone(), + } + } +} diff --git a/apollo-router/src/plugins/traffic_shaping/rate/mod.rs b/apollo-router/src/plugins/traffic_shaping/rate/mod.rs new file mode 100644 index 0000000000..6a2288a7c0 --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/rate/mod.rs @@ -0,0 +1,13 @@ +//! Limit the rate at which requests are processed. + +mod error; +mod future; +mod layer; +#[allow(clippy::module_inception)] +mod rate; +mod service; + +pub(crate) use self::error::RateLimited; +pub(crate) use self::layer::RateLimitLayer; +pub(crate) use self::rate::Rate; +pub(crate) use self::service::RateLimit; diff --git a/apollo-router/src/plugins/traffic_shaping/rate/rate.rs b/apollo-router/src/plugins/traffic_shaping/rate/rate.rs new file mode 100644 index 0000000000..eb73f74f10 --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/rate/rate.rs @@ -0,0 +1,33 @@ +use std::num::NonZeroU64; +use std::time::Duration; + +/// A rate of requests per time period. +#[derive(Debug, Copy, Clone)] +pub(crate) struct Rate { + num: u64, + per: Duration, +} + +impl Rate { + /// Create a new rate. + /// + /// # Panics + /// + /// This function panics if `num` or `per` is 0. + pub(crate) fn new(num: NonZeroU64, per: Duration) -> Self { + assert!(per > Duration::default()); + + Rate { + num: num.into(), + per, + } + } + + pub(crate) fn num(&self) -> u64 { + self.num + } + + pub(crate) fn per(&self) -> Duration { + self.per + } +} diff --git a/apollo-router/src/plugins/traffic_shaping/rate/service.rs b/apollo-router/src/plugins/traffic_shaping/rate/service.rs new file mode 100644 index 0000000000..cc762f1e7d --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/rate/service.rs @@ -0,0 +1,83 @@ +use std::sync::atomic::AtomicU64; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +use futures::ready; +use tower::Service; + +use super::future::ResponseFuture; +use super::Rate; +use crate::plugins::traffic_shaping::rate::error::RateLimited; + +#[derive(Debug)] +pub(crate) struct RateLimit { + pub(crate) inner: T, + pub(crate) rate: Rate, + /// We're using an atomic u64 because it's basically a timestamp in milliseconds for the start of the window + /// Instead of using an Instant which is not thread safe we're using an atomic u64 + /// It's ok to have an u64 because we just care about milliseconds for this use case + pub(crate) window_start: Arc, + pub(crate) previous_nb_requests: Arc, + pub(crate) current_nb_requests: Arc, +} + +impl Service for RateLimit +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + type Error = tower::BoxError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + let time_unit = self.rate.per().as_millis() as u64; + + let updated = + self.window_start + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |window_start| { + let duration_now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after EPOCH") + .as_millis() as u64; + if duration_now - window_start > self.rate.per().as_millis() as u64 { + Some(duration_now) + } else { + None + } + }); + // If it has been updated + if let Ok(_updated_window_start) = updated { + self.previous_nb_requests.swap( + self.current_nb_requests.load(Ordering::SeqCst), + Ordering::SeqCst, + ); + self.current_nb_requests.swap(1, Ordering::SeqCst); + } + + let estimated_cap = (self.previous_nb_requests.load(Ordering::SeqCst) + * (time_unit + .checked_sub(self.window_start.load(Ordering::SeqCst)) + .unwrap_or_default() + / time_unit) as usize) + + self.current_nb_requests.load(Ordering::SeqCst); + + if estimated_cap as u64 > self.rate.num() { + tracing::trace!("rate limit exceeded; sleeping."); + return Poll::Ready(Err(RateLimited::new().into())); + } + + self.current_nb_requests.fetch_add(1, Ordering::SeqCst); + + Poll::Ready(ready!(self.inner.poll_ready(cx)).map_err(Into::into)) + } + + fn call(&mut self, request: Request) -> Self::Future { + ResponseFuture::new(self.inner.call(request)) + } +} diff --git a/apollo-router/src/plugins/traffic_shaping/timeout/error.rs b/apollo-router/src/plugins/traffic_shaping/timeout/error.rs new file mode 100644 index 0000000000..2d7f51cf18 --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/timeout/error.rs @@ -0,0 +1,32 @@ +//! Error types + +use std::error; +use std::fmt; + +use axum::response::IntoResponse; +use http::StatusCode; + +/// The timeout elapsed. +#[derive(Debug, Default)] +pub(crate) struct Elapsed; + +impl Elapsed { + /// Construct a new elapsed error + pub(crate) fn new() -> Self { + Elapsed {} + } +} + +impl fmt::Display for Elapsed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("request timed out") + } +} + +impl IntoResponse for Elapsed { + fn into_response(self) -> axum::response::Response { + (StatusCode::REQUEST_TIMEOUT, self.to_string()).into_response() + } +} + +impl error::Error for Elapsed {} diff --git a/apollo-router/src/plugins/traffic_shaping/timeout/future.rs b/apollo-router/src/plugins/traffic_shaping/timeout/future.rs new file mode 100644 index 0000000000..1c03105a0d --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/timeout/future.rs @@ -0,0 +1,54 @@ +//! Future types + +use std::future::Future; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use pin_project_lite::pin_project; +use tokio::time::Sleep; + +use super::error::Elapsed; + +pin_project! { + /// [`Timeout`] response future + /// + /// [`Timeout`]: crate::timeout::Timeout + #[derive(Debug)] + pub(crate) struct ResponseFuture { + #[pin] + response: T, + #[pin] + sleep: Pin>, + } +} + +impl ResponseFuture { + pub(crate) fn new(response: T, sleep: Pin>) -> Self { + ResponseFuture { response, sleep } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + // First, try polling the future + match this.response.poll(cx) { + Poll::Ready(v) => return Poll::Ready(v.map_err(Into::into)), + Poll::Pending => {} + } + + // Now check the sleep + match Pin::new(&mut this.sleep).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => Poll::Ready(Err(Elapsed::new().into())), + } + } +} diff --git a/apollo-router/src/plugins/traffic_shaping/timeout/layer.rs b/apollo-router/src/plugins/traffic_shaping/timeout/layer.rs new file mode 100644 index 0000000000..8dab116c11 --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/timeout/layer.rs @@ -0,0 +1,26 @@ +use std::time::Duration; + +use tower::Layer; + +use super::Timeout; + +/// Applies a timeout to requests via the supplied inner service. +#[derive(Debug, Clone)] +pub(crate) struct TimeoutLayer { + timeout: Duration, +} + +impl TimeoutLayer { + /// Create a timeout from a duration + pub(crate) fn new(timeout: Duration) -> Self { + TimeoutLayer { timeout } + } +} + +impl Layer for TimeoutLayer { + type Service = Timeout; + + fn layer(&self, service: S) -> Self::Service { + Timeout::new(service, self.timeout) + } +} diff --git a/apollo-router/src/plugins/traffic_shaping/timeout/mod.rs b/apollo-router/src/plugins/traffic_shaping/timeout/mod.rs new file mode 100644 index 0000000000..b766437ea5 --- /dev/null +++ b/apollo-router/src/plugins/traffic_shaping/timeout/mod.rs @@ -0,0 +1,93 @@ +//! This is a modified Timeout service copy/pasted from the tower codebase. +//! This Timeout is also checking if we do not timeout on the `poll_ready` and not only on the `call` part +//! Middleware that applies a timeout to requests. +//! +//! If the response does not complete within the specified timeout, the response +//! will be aborted. + +pub(crate) mod error; +pub(crate) mod future; +mod layer; + +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; +use std::time::Duration; + +use futures::Future; +use tokio::time::Sleep; +use tower::Service; + +use self::future::ResponseFuture; +pub(crate) use self::layer::TimeoutLayer; +pub(crate) use crate::plugins::traffic_shaping::timeout::error::Elapsed; + +/// Applies a timeout to requests. +#[derive(Debug)] +pub(crate) struct Timeout { + inner: T, + timeout: Duration, + sleep: Option>>, +} + +// ===== impl Timeout ===== + +impl Timeout { + /// Creates a new [`Timeout`] + pub(crate) fn new(inner: T, timeout: Duration) -> Self { + Timeout { + inner, + timeout, + sleep: None, + } + } +} + +impl Service for Timeout +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + type Error = tower::BoxError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.sleep.is_none() { + self.sleep = Some(Box::pin(tokio::time::sleep(self.timeout))); + } + match self.inner.poll_ready(cx) { + Poll::Pending => {} + Poll::Ready(r) => return Poll::Ready(r.map_err(Into::into)), + }; + + // Checking if we don't timeout on `poll_ready` + if Pin::new( + &mut self + .sleep + .as_mut() + .expect("we can unwrap because we set it just before"), + ) + .poll(cx) + .is_ready() + { + tracing::trace!("timeout exceeded."); + self.sleep = None; + + return Poll::Ready(Err(Elapsed::new().into())); + } + + Poll::Pending + } + + fn call(&mut self, request: Request) -> Self::Future { + let response = self.inner.call(request); + + ResponseFuture::new( + response, + self.sleep + .take() + .expect("poll_ready must been called before"), + ) + } +} diff --git a/docs/source/configuration/traffic-shaping.mdx b/docs/source/configuration/traffic-shaping.mdx index 00c85dacdc..2a2216d2b8 100644 --- a/docs/source/configuration/traffic-shaping.mdx +++ b/docs/source/configuration/traffic-shaping.mdx @@ -5,11 +5,13 @@ title: Traffic shaping in the Apollo Router The Apollo Router supports the following types of traffic shaping between itself and your subgraphs: - **Sub-query deduplication** - Whenever the router is sending multiple identical in-flight query operations to a subgraph, it can consolidate them into a single request. - - Mutation operations are never deduplicated. - - Only in-flight requests are deduplicated. + - Mutation operations are never deduplicated. + - Only in-flight requests are deduplicated. - **Variable deduplication** - If a request to a subgraph includes multiple GraphQL variables with the same value, the router can replace those with a single variable. - **Compression** - The router can compress request bodies to subgraphs (along with response bodies to clients) with a supported algorithm - - The router currently supports `gzip`, `br`, and `deflate`. + - The router currently supports `gzip`, `br`, and `deflate`. +- **Global rate limiting** - If you want to rate limit requests to subgraphs or to the router itself. +- **Timeout**: - Set a timeout to subgraphs and router requests. Each of these optimizations can reduce network bandwidth and CPU usage for your subgraphs. @@ -20,13 +22,22 @@ To enable traffic shaping, add the `traffic_shaping` plugin to your [YAML config ```yaml title="router.yaml" traffic_shaping: variables_deduplication: true # Enable the variable deduplication optimization. + router: # Rules applied to requests from clients to the router + global_rate_limit: # Accept a maximum of 10 requests per 5 secs. Excess requests must be rejected. + capacity: 10 + interval: 5s # Must not be greater than 18_446_744_073_709_551_615 milliseconds and not less than 0 milliseconds + timeout: 50s # If a request to the router takes more than 50secs then cancel the request (30 sec by default) all: query_deduplication: true # Enable query deduplication for all subgraphs. compression: br # Enable brotli compression for all subgraphs. - subgraphs: + subgraphs: # Rules applied to requests from the router to individual subgraphs products: query_deduplication: false # Disable query for the products subgraph. compression: gzip # Enable gzip compression only for the products subgraph. + global_rate_limit: # Accept a maximum of 10 requests per 5 secs from the router. Excess requests must be rejected. + capacity: 10 + interval: 5s # Must not be greater than 18_446_744_073_709_551_615 milliseconds and not less than 0 milliseconds + timeout: 50s # If a request to the subgraph 'products' takes more than 50secs then cancel the request (30 sec by default) ``` Any configuration under the `subgraphs` key takes precedence over configuration under the `all` key. In the example above, query deduplication is enabled for all subgraphs _except_ the `products` subgraph.