diff --git a/Cargo.lock b/Cargo.lock index 894d2c607690..d5bfbf953070 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1859,6 +1859,7 @@ dependencies = [ "frontend", "futures", "human-panic", + "humantime", "lazy_static", "meta-client", "meta-srv", @@ -10976,11 +10977,13 @@ dependencies = [ "datatypes", "derive_builder 0.12.0", "futures", + "futures-util", "hashbrown 0.14.5", "headers 0.3.9", "hostname", "http 0.2.12", "http-body 0.4.6", + "humantime", "humantime-serde", "hyper 0.14.30", "influxdb_line_protocol", diff --git a/src/cmd/Cargo.toml b/src/cmd/Cargo.toml index 95937567c371..c1f20cc9c526 100644 --- a/src/cmd/Cargo.toml +++ b/src/cmd/Cargo.toml @@ -53,6 +53,7 @@ flow.workspace = true frontend = { workspace = true, default-features = false } futures.workspace = true human-panic = "2.0" +humantime.workspace = true lazy_static.workspace = true meta-client.workspace = true meta-srv.workspace = true diff --git a/src/cmd/src/cli/database.rs b/src/cmd/src/cli/database.rs index eb5647699ef0..d313e93acf7c 100644 --- a/src/cmd/src/cli/database.rs +++ b/src/cmd/src/cli/database.rs @@ -12,11 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::time::Duration; + use base64::engine::general_purpose; use base64::Engine; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use humantime::format_duration; use serde_json::Value; use servers::http::greptime_result_v1::GreptimedbV1Response; +use servers::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT; use servers::http::GreptimeQueryOutput; use snafu::ResultExt; @@ -26,10 +30,16 @@ pub(crate) struct DatabaseClient { addr: String, catalog: String, auth_header: Option, + timeout: Option, } impl DatabaseClient { - pub fn new(addr: String, catalog: String, auth_basic: Option) -> Self { + pub fn new( + addr: String, + catalog: String, + auth_basic: Option, + timeout: Option, + ) -> Self { let auth_header = if let Some(basic) = auth_basic { let encoded = general_purpose::STANDARD.encode(basic); Some(format!("basic {}", encoded)) @@ -41,6 +51,7 @@ impl DatabaseClient { addr, catalog, auth_header, + timeout, } } @@ -62,6 +73,12 @@ impl DatabaseClient { if let Some(ref auth) = self.auth_header { request = request.header("Authorization", auth); } + if let Some(ref timeout) = self.timeout { + request = request.header( + GREPTIME_DB_HEADER_TIMEOUT, + format_duration(*timeout).to_string(), + ); + } let response = request.send().await.with_context(|_| HttpQuerySqlSnafu { reason: format!("bad url: {}", url), diff --git a/src/cmd/src/cli/export.rs b/src/cmd/src/cli/export.rs index ee5f5329cdd5..760fdbbc02f6 100644 --- a/src/cmd/src/cli/export.rs +++ b/src/cmd/src/cli/export.rs @@ -15,6 +15,7 @@ use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; use async_trait::async_trait; use clap::{Parser, ValueEnum}; @@ -83,14 +84,22 @@ pub struct ExportCommand { /// The basic authentication for connecting to the server #[clap(long)] auth_basic: Option, + + /// The timeout of invoking the database. + #[clap(long, value_parser = humantime::parse_duration)] + timeout: Option, } impl ExportCommand { pub async fn build(&self, guard: Vec) -> Result { let (catalog, schema) = database::split_database(&self.database)?; - let database_client = - DatabaseClient::new(self.addr.clone(), catalog.clone(), self.auth_basic.clone()); + let database_client = DatabaseClient::new( + self.addr.clone(), + catalog.clone(), + self.auth_basic.clone(), + self.timeout, + ); Ok(Instance::new( Box::new(Export { diff --git a/src/cmd/src/cli/import.rs b/src/cmd/src/cli/import.rs index b1d27fb0e058..908fc944bd03 100644 --- a/src/cmd/src/cli/import.rs +++ b/src/cmd/src/cli/import.rs @@ -14,6 +14,7 @@ use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; use async_trait::async_trait; use clap::{Parser, ValueEnum}; @@ -68,13 +69,21 @@ pub struct ImportCommand { /// The basic authentication for connecting to the server #[clap(long)] auth_basic: Option, + + /// The timeout of invoking the database. + #[clap(long, value_parser = humantime::parse_duration)] + timeout: Option, } impl ImportCommand { pub async fn build(&self, guard: Vec) -> Result { let (catalog, schema) = database::split_database(&self.database)?; - let database_client = - DatabaseClient::new(self.addr.clone(), catalog.clone(), self.auth_basic.clone()); + let database_client = DatabaseClient::new( + self.addr.clone(), + catalog.clone(), + self.auth_basic.clone(), + self.timeout, + ); Ok(Instance::new( Box::new(Import { diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index a2803ae03572..47bbfc4f7382 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -52,11 +52,13 @@ datafusion-expr.workspace = true datatypes.workspace = true derive_builder.workspace = true futures = "0.3" +futures-util.workspace = true hashbrown = "0.14" headers = "0.3" hostname = "0.3" http = "0.2" http-body = "0.4" +humantime.workspace = true humantime-serde.workspace = true hyper = { version = "0.14", features = ["full"] } influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 12ac06db9070..34cff5de6fe4 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -45,7 +45,6 @@ use serde_json::Value; use snafu::{ensure, ResultExt}; use tokio::sync::oneshot::{self, Sender}; use tokio::sync::Mutex; -use tower::timeout::TimeoutLayer; use tower::ServiceBuilder; use tower_http::decompression::RequestDecompressionLayer; use tower_http::trace::TraceLayer; @@ -101,6 +100,9 @@ pub mod greptime_result_v1; pub mod influxdb_result_v1; pub mod json_result; pub mod table_result; +mod timeout; + +pub(crate) use timeout::DynamicTimeoutLayer; #[cfg(any(test, feature = "testing"))] pub mod test_helpers; @@ -704,7 +706,7 @@ impl HttpServer { pub fn build(&self, router: Router) -> Router { let timeout_layer = if self.options.timeout != Duration::default() { - Some(ServiceBuilder::new().layer(TimeoutLayer::new(self.options.timeout))) + Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout))) } else { info!("HTTP server timeout is disabled"); None @@ -997,10 +999,12 @@ mod test { use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{StringVector, UInt32Vector}; + use header::constants::GREPTIME_DB_HEADER_TIMEOUT; use query::parser::PromQuery; use query::query_engine::DescribeResult; use session::context::QueryContextRef; use tokio::sync::mpsc; + use tokio::time::Instant; use super::*; use crate::error::Error; @@ -1062,8 +1066,8 @@ mod test { } } - fn timeout() -> TimeoutLayer { - TimeoutLayer::new(Duration::from_millis(10)) + fn timeout() -> DynamicTimeoutLayer { + DynamicTimeoutLayer::new(Duration::from_millis(10)) } async fn forever() { @@ -1102,6 +1106,16 @@ mod test { let client = TestClient::new(app); let res = client.get("/test/timeout").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); + + let now = Instant::now(); + let res = client + .get("/test/timeout") + .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms") + .send() + .await; + assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); + let elapsed = now.elapsed(); + assert!(elapsed > Duration::from_millis(15)); } #[tokio::test] diff --git a/src/servers/src/http/header.rs b/src/servers/src/http/header.rs index 16962a56395a..bf5b0a4ebc14 100644 --- a/src/servers/src/http/header.rs +++ b/src/servers/src/http/header.rs @@ -39,6 +39,7 @@ pub mod constants { // LEGACY HEADERS - KEEP IT UNMODIFIED pub const GREPTIME_DB_HEADER_FORMAT: &str = "x-greptime-format"; + pub const GREPTIME_DB_HEADER_TIMEOUT: &str = "x-greptime-timeout"; pub const GREPTIME_DB_HEADER_EXECUTION_TIME: &str = "x-greptime-execution-time"; pub const GREPTIME_DB_HEADER_METRICS: &str = "x-greptime-metrics"; pub const GREPTIME_DB_HEADER_NAME: &str = "x-greptime-db-name"; diff --git a/src/servers/src/http/timeout.rs b/src/servers/src/http/timeout.rs new file mode 100644 index 000000000000..7a42918124d4 --- /dev/null +++ b/src/servers/src/http/timeout.rs @@ -0,0 +1,144 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use axum::body::Body; +use axum::http::Request; +use axum::response::Response; +use pin_project::pin_project; +use tokio::time::Sleep; +use tower::timeout::error::Elapsed; +use tower::{BoxError, Layer, Service}; + +use crate::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT; + +/// [`Timeout`] response future +/// +/// [`Timeout`]: crate::timeout::Timeout +/// +/// Modified from https://github.com/tower-rs/tower/blob/8b84b98d93a2493422a0ecddb6251f292a904cff/tower/src/timeout/future.rs +#[derive(Debug)] +#[pin_project] +pub struct ResponseFuture { + #[pin] + response: T, + #[pin] + sleep: Sleep, +} + +impl ResponseFuture { + pub(crate) fn new(response: T, sleep: Sleep) -> Self { + ResponseFuture { response, sleep } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // First, try polling the future + match this.response.poll(cx) { + Poll::Ready(v) => return Poll::Ready(v.map_err(Into::into)), + Poll::Pending => {} + } + + // Now check the sleep + match this.sleep.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => Poll::Ready(Err(Elapsed::new().into())), + } + } +} + +/// Applies a timeout to requests via the supplied inner service. +/// +/// Modified from https://github.com/tower-rs/tower/blob/8b84b98d93a2493422a0ecddb6251f292a904cff/tower/src/timeout/layer.rs +#[derive(Debug, Clone)] +pub struct DynamicTimeoutLayer { + default_timeout: Duration, +} + +impl DynamicTimeoutLayer { + /// Create a timeout from a duration + pub fn new(default_timeout: Duration) -> Self { + DynamicTimeoutLayer { default_timeout } + } +} + +impl Layer for DynamicTimeoutLayer { + type Service = DynamicTimeout; + + fn layer(&self, service: S) -> Self::Service { + DynamicTimeout::new(service, self.default_timeout) + } +} + +/// Modified from https://github.com/tower-rs/tower/blob/8b84b98d93a2493422a0ecddb6251f292a904cff/tower/src/timeout/mod.rs +#[derive(Clone)] +pub struct DynamicTimeout { + inner: S, + default_timeout: Duration, +} + +impl DynamicTimeout { + /// Create a new [`DynamicTimeout`] with the given timeout + pub fn new(inner: S, default_timeout: Duration) -> Self { + DynamicTimeout { + inner, + default_timeout, + } + } +} + +impl Service> for DynamicTimeout +where + S: Service, Response = Response> + Send + 'static, + S::Error: Into, +{ + type Response = S::Response; + type Error = BoxError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, request: Request) -> Self::Future { + let user_timeout = request + .headers() + .get(GREPTIME_DB_HEADER_TIMEOUT) + .and_then(|value| { + value + .to_str() + .ok() + .and_then(|value| humantime::parse_duration(value).ok()) + }); + let response = self.inner.call(request); + let sleep = tokio::time::sleep(user_timeout.unwrap_or(self.default_timeout)); + ResponseFuture::new(response, sleep) + } +}