diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index cbc9f1f0ed..c1936e7a7d 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -11,6 +11,44 @@ # meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"} # author = "rcoh" +[[aws-sdk-rust]] +message = """ +Add configurable stalled-stream protection for downloads. + +When making HTTP calls, +it's possible for a connection to 'stall out' and emit no more data due to server-side issues. +In the event this happens, it's desirable for the stream to error out as quickly as possible. +While timeouts can protect you from this issue, they aren't adaptive to the amount of data +being sent and so must be configured specifically for each use case. When enabled, stalled-stream +protection will ensure that bad streams error out quickly, regardless of the amount of data being +downloaded. + +Protection is enabled by default for all clients but can be configured or disabled. +See [this discussion](https://github.com/awslabs/aws-sdk-rust/discussions/956) for more details. +""" +references = ["smithy-rs#3202"] +meta = { "breaking" = true, "tada" = true, "bug" = false } +author = "Velfi" + +[[smithy-rs]] +message = """ +Add configurable stalled-stream protection for downloads. + +When making HTTP calls, +it's possible for a connection to 'stall out' and emit no more data due to server-side issues. +In the event this happens, it's desirable for the stream to error out as quickly as possible. +While timeouts can protect you from this issue, they aren't adaptive to the amount of data +being sent and so must be configured specifically for each use case. When enabled, stalled-stream +protection will ensure that bad streams error out quickly, regardless of the amount of data being +downloaded. + +Protection is enabled by default for all clients but can be configured or disabled. +See [this discussion](https://github.com/awslabs/aws-sdk-rust/discussions/956) for more details. +""" +references = ["smithy-rs#3202"] +meta = { "breaking" = true, "tada" = true, "bug" = false, "target" = "client" } +author = "Velfi" + [[aws-sdk-rust]] message = "Make certain types for EMR Serverless optional. Previously, they defaulted to 0, but this created invalid requests." references = ["smithy-rs#3217"] diff --git a/aws/rust-runtime/aws-config/external-types.toml b/aws/rust-runtime/aws-config/external-types.toml index 19c1f6aab8..38a34bd786 100644 --- a/aws/rust-runtime/aws-config/external-types.toml +++ b/aws/rust-runtime/aws-config/external-types.toml @@ -14,15 +14,16 @@ allowed_external_types = [ "aws_smithy_runtime_api::box_error::BoxError", "aws_smithy_runtime::client::identity::cache::IdentityCache", "aws_smithy_runtime::client::identity::cache::lazy::LazyCacheBuilder", + "aws_smithy_runtime_api::client::behavior_version::BehaviorVersion", "aws_smithy_runtime_api::client::dns::ResolveDns", "aws_smithy_runtime_api::client::dns::SharedDnsResolver", "aws_smithy_runtime_api::client::http::HttpClient", "aws_smithy_runtime_api::client::http::SharedHttpClient", "aws_smithy_runtime_api::client::identity::ResolveCachedIdentity", "aws_smithy_runtime_api::client::identity::ResolveIdentity", - "aws_smithy_runtime_api::client::behavior_version::BehaviorVersion", "aws_smithy_runtime_api::client::orchestrator::HttpResponse", "aws_smithy_runtime_api::client::result::SdkError", + "aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig", "aws_smithy_types::body::SdkBody", "aws_smithy_types::retry", "aws_smithy_types::retry::*", diff --git a/aws/rust-runtime/aws-config/src/lib.rs b/aws/rust-runtime/aws-config/src/lib.rs index 73d5106e9d..e9b2d1890f 100644 --- a/aws/rust-runtime/aws-config/src/lib.rs +++ b/aws/rust-runtime/aws-config/src/lib.rs @@ -134,6 +134,7 @@ pub mod retry; mod sensitive_command; #[cfg(feature = "sso")] pub mod sso; +pub mod stalled_stream_protection; pub(crate) mod standard_property; pub mod sts; pub mod timeout; @@ -216,6 +217,7 @@ mod loader { use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion; use aws_smithy_runtime_api::client::http::HttpClient; use aws_smithy_runtime_api::client::identity::{ResolveCachedIdentity, SharedIdentityCache}; + use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; use aws_smithy_runtime_api::shared::IntoShared; use aws_smithy_types::retry::RetryConfig; use aws_smithy_types::timeout::TimeoutConfig; @@ -259,6 +261,7 @@ mod loader { use_fips: Option, use_dual_stack: Option, time_source: Option, + stalled_stream_protection_config: Option, env: Option, fs: Option, behavior_version: Option, @@ -611,6 +614,39 @@ mod loader { self } + /// Override the [`StalledStreamProtectionConfig`] used to build [`SdkConfig`](aws_types::SdkConfig). + /// + /// This configures stalled stream protection. When enabled, download streams + /// that stop (stream no data) for longer than a configured grace period will return an error. + /// + /// By default, streams that transmit less than one byte per-second for five seconds will + /// be cancelled. + /// + /// _Note_: When an override is provided, the default implementation is replaced. + /// + /// # Examples + /// ```no_run + /// # async fn create_config() { + /// use aws_config::stalled_stream_protection::StalledStreamProtectionConfig; + /// use std::time::Duration; + /// let config = aws_config::from_env() + /// .stalled_stream_protection( + /// StalledStreamProtectionConfig::enabled() + /// .grace_period(Duration::from_secs(1)) + /// .build() + /// ) + /// .load() + /// .await; + /// # } + /// ``` + pub fn stalled_stream_protection( + mut self, + stalled_stream_protection_config: StalledStreamProtectionConfig, + ) -> Self { + self.stalled_stream_protection_config = Some(stalled_stream_protection_config); + self + } + /// Set configuration for all sub-loaders (credentials, region etc.) /// /// Update the `ProviderConfig` used for all nested loaders. This can be used to override @@ -757,6 +793,7 @@ mod loader { builder.set_endpoint_url(self.endpoint_url); builder.set_use_fips(use_fips); builder.set_use_dual_stack(use_dual_stack); + builder.set_stalled_stream_protection(self.stalled_stream_protection_config); builder.build() } } diff --git a/aws/rust-runtime/aws-config/src/stalled_stream_protection.rs b/aws/rust-runtime/aws-config/src/stalled_stream_protection.rs new file mode 100644 index 0000000000..1e47905305 --- /dev/null +++ b/aws/rust-runtime/aws-config/src/stalled_stream_protection.rs @@ -0,0 +1,9 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Stalled stream protection configuration + +// Re-export from aws-smithy-types +pub use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; diff --git a/aws/rust-runtime/aws-types/external-types.toml b/aws/rust-runtime/aws-types/external-types.toml index 9e0184826b..1e4e0c863d 100644 --- a/aws/rust-runtime/aws-types/external-types.toml +++ b/aws/rust-runtime/aws-types/external-types.toml @@ -5,11 +5,12 @@ allowed_external_types = [ "aws_smithy_async::rt::sleep::SharedAsyncSleep", "aws_smithy_async::time::SharedTimeSource", "aws_smithy_async::time::TimeSource", + "aws_smithy_runtime_api::client::behavior_version::BehaviorVersion", "aws_smithy_runtime_api::client::http::HttpClient", "aws_smithy_runtime_api::client::http::SharedHttpClient", "aws_smithy_runtime_api::client::identity::ResolveCachedIdentity", "aws_smithy_runtime_api::client::identity::SharedIdentityCache", - "aws_smithy_runtime_api::client::behavior_version::BehaviorVersion", + "aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig", "aws_smithy_runtime_api::http::headers::Headers", "aws_smithy_types::config_bag::storable::Storable", "aws_smithy_types::config_bag::storable::StoreReplace", diff --git a/aws/rust-runtime/aws-types/src/sdk_config.rs b/aws/rust-runtime/aws-types/src/sdk_config.rs index 6ea10806f1..1da2bfe657 100644 --- a/aws/rust-runtime/aws-types/src/sdk_config.rs +++ b/aws/rust-runtime/aws-types/src/sdk_config.rs @@ -21,6 +21,7 @@ use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion; use aws_smithy_runtime_api::client::http::HttpClient; pub use aws_smithy_runtime_api::client::http::SharedHttpClient; use aws_smithy_runtime_api::client::identity::{ResolveCachedIdentity, SharedIdentityCache}; +pub use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; use aws_smithy_runtime_api::shared::IntoShared; pub use aws_smithy_types::retry::RetryConfig; pub use aws_smithy_types::timeout::TimeoutConfig; @@ -60,6 +61,7 @@ pub struct SdkConfig { sleep_impl: Option, time_source: Option, timeout_config: Option, + stalled_stream_protection_config: Option, http_client: Option, use_fips: Option, use_dual_stack: Option, @@ -82,6 +84,7 @@ pub struct Builder { sleep_impl: Option, time_source: Option, timeout_config: Option, + stalled_stream_protection_config: Option, http_client: Option, use_fips: Option, use_dual_stack: Option, @@ -567,10 +570,82 @@ impl Builder { use_dual_stack: self.use_dual_stack, time_source: self.time_source, behavior_version: self.behavior_version, + stalled_stream_protection_config: self.stalled_stream_protection_config, } } } +impl Builder { + /// Set the [`StalledStreamProtectionConfig`] to configure protection for stalled streams. + /// + /// This configures stalled stream protection. When enabled, download streams + /// that stall (stream no data) for longer than a configured grace period will return an error. + /// + /// _Note:_ Stalled stream protection requires both a sleep implementation and a time source + /// in order to work. When enabling stalled stream protection, make sure to set + /// - A sleep impl with [Self::sleep_impl] or [Self::set_sleep_impl]. + /// - A time source with [Self::time_source] or [Self::set_time_source]. + /// + /// # Examples + /// ```rust + /// use std::time::Duration; + /// use aws_types::SdkConfig; + /// pub use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; + /// + /// let stalled_stream_protection_config = StalledStreamProtectionConfig::enabled() + /// .grace_period(Duration::from_secs(1)) + /// .build(); + /// let config = SdkConfig::builder() + /// .stalled_stream_protection(stalled_stream_protection_config) + /// .build(); + /// ``` + pub fn stalled_stream_protection( + mut self, + stalled_stream_protection_config: StalledStreamProtectionConfig, + ) -> Self { + self.set_stalled_stream_protection(Some(stalled_stream_protection_config)); + self + } + + /// Set the [`StalledStreamProtectionConfig`] to configure protection for stalled streams. + /// + /// This configures stalled stream protection. When enabled, download streams + /// that stall (stream no data) for longer than a configured grace period will return an error. + /// + /// By default, streams that transmit less than one byte per-second for five seconds will + /// be cancelled. + /// + /// _Note:_ Stalled stream protection requires both a sleep implementation and a time source + /// in order to work. When enabling stalled stream protection, make sure to set + /// - A sleep impl with [Self::sleep_impl] or [Self::set_sleep_impl]. + /// - A time source with [Self::time_source] or [Self::set_time_source]. + /// + /// # Examples + /// ```rust + /// use std::time::Duration; + /// use aws_types::sdk_config::{SdkConfig, Builder}; + /// pub use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; + /// + /// fn set_stalled_stream_protection(builder: &mut Builder) { + /// let stalled_stream_protection_config = StalledStreamProtectionConfig::enabled() + /// .grace_period(Duration::from_secs(1)) + /// .build(); + /// builder.set_stalled_stream_protection(Some(stalled_stream_protection_config)); + /// } + /// + /// let mut builder = SdkConfig::builder(); + /// set_stalled_stream_protection(&mut builder); + /// let config = builder.build(); + /// ``` + pub fn set_stalled_stream_protection( + &mut self, + stalled_stream_protection_config: Option, + ) -> &mut Self { + self.stalled_stream_protection_config = stalled_stream_protection_config; + self + } +} + impl SdkConfig { /// Configured region pub fn region(&self) -> Option<&Region> { @@ -633,6 +708,11 @@ impl SdkConfig { self.use_dual_stack } + /// Configured stalled stream protection + pub fn stalled_stream_protection(&self) -> Option { + self.stalled_stream_protection_config.clone() + } + /// Behavior major version configured for this client pub fn behavior_version(&self) -> Option { self.behavior_version.clone() @@ -668,6 +748,7 @@ impl SdkConfig { use_fips: self.use_fips, use_dual_stack: self.use_dual_stack, behavior_version: self.behavior_version, + stalled_stream_protection_config: self.stalled_stream_protection_config, } } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt index bce62f58aa..69a1299d73 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/SdkConfigDecorator.kt @@ -79,6 +79,10 @@ class GenericSmithySdkConfigSettings : ClientCodegenDecorator { ${section.serviceConfigBuilder}.set_http_client(${section.sdkConfig}.http_client()); ${section.serviceConfigBuilder}.set_time_source(${section.sdkConfig}.time_source()); ${section.serviceConfigBuilder}.set_behavior_version(${section.sdkConfig}.behavior_version()); + // setting `None` here removes the default + if let Some(config) = ${section.sdkConfig}.stalled_stream_protection() { + ${section.serviceConfigBuilder}.set_stalled_stream_protection(Some(config)); + } if let Some(cache) = ${section.sdkConfig}.identity_cache() { ${section.serviceConfigBuilder}.set_identity_cache(cache); diff --git a/aws/sdk/integration-tests/dynamodb/tests/endpoints.rs b/aws/sdk/integration-tests/dynamodb/tests/endpoints.rs index 2f0180b449..e430d60d43 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/endpoints.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/endpoints.rs @@ -17,6 +17,9 @@ async fn expect_uri( let conf = customize( aws_sdk_dynamodb::config::Builder::from(&conf) .credentials_provider(Credentials::for_tests()) + .stalled_stream_protection( + aws_sdk_dynamodb::config::StalledStreamProtectionConfig::disabled(), + ) .http_client(http_client), ) .build(); diff --git a/aws/sdk/integration-tests/dynamodb/tests/retries-with-client-rate-limiting.rs b/aws/sdk/integration-tests/dynamodb/tests/retries-with-client-rate-limiting.rs index f565e887c5..834ec850f2 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/retries-with-client-rate-limiting.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/retries-with-client-rate-limiting.rs @@ -3,7 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_sdk_dynamodb::config::{Credentials, Region, SharedAsyncSleep}; +use aws_sdk_dynamodb::config::{ + Credentials, Region, SharedAsyncSleep, StalledStreamProtectionConfig, +}; use aws_sdk_dynamodb::{config::retry::RetryConfig, error::ProvideErrorMetadata}; use aws_smithy_async::test_util::instant_time_and_sleep; use aws_smithy_async::time::SharedTimeSource; @@ -65,6 +67,7 @@ async fn test_adaptive_retries_with_no_throttling_errors() { let http_client = StaticReplayClient::new(events); let config = aws_sdk_dynamodb::Config::builder() + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .credentials_provider(Credentials::for_tests()) .region(Region::new("us-east-1")) .retry_config( @@ -120,6 +123,7 @@ async fn test_adaptive_retries_with_throttling_errors() { let http_client = StaticReplayClient::new(events); let config = aws_sdk_dynamodb::Config::builder() + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .credentials_provider(Credentials::for_tests()) .region(Region::new("us-east-1")) .retry_config( diff --git a/aws/sdk/integration-tests/dynamodb/tests/shared-config.rs b/aws/sdk/integration-tests/dynamodb/tests/shared-config.rs index 0ce9d0c9de..02786fe0af 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/shared-config.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/shared-config.rs @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -use aws_sdk_dynamodb::config::{Credentials, Region}; +use aws_sdk_dynamodb::config::{Credentials, Region, StalledStreamProtectionConfig}; use aws_smithy_runtime::client::http::test_util::capture_request; use http::Uri; @@ -12,6 +12,7 @@ use http::Uri; async fn shared_config_testbed() { let shared_config = aws_types::SdkConfig::builder() .region(Region::new("us-east-4")) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .build(); let (http_client, request) = capture_request(None); let conf = aws_sdk_dynamodb::config::Builder::from(&shared_config) diff --git a/aws/sdk/integration-tests/dynamodb/tests/timeouts.rs b/aws/sdk/integration-tests/dynamodb/tests/timeouts.rs index abd63673a5..d7a41750dc 100644 --- a/aws/sdk/integration-tests/dynamodb/tests/timeouts.rs +++ b/aws/sdk/integration-tests/dynamodb/tests/timeouts.rs @@ -6,13 +6,12 @@ use std::time::Duration; use aws_credential_types::provider::SharedCredentialsProvider; -use aws_credential_types::Credentials; +use aws_sdk_dynamodb::config::{Credentials, Region, StalledStreamProtectionConfig}; use aws_sdk_dynamodb::error::SdkError; use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep, Sleep}; use aws_smithy_runtime::client::http::test_util::NeverClient; use aws_smithy_types::retry::RetryConfig; use aws_smithy_types::timeout::TimeoutConfig; -use aws_types::region::Region; use aws_types::SdkConfig; #[derive(Debug, Clone)] @@ -36,9 +35,10 @@ async fn api_call_timeout_retries() { .build(), ) .retry_config(RetryConfig::standard()) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .sleep_impl(SharedAsyncSleep::new(InstantSleep)) .build(); - let client = aws_sdk_dynamodb::Client::from_conf(aws_sdk_dynamodb::Config::new(&conf)); + let client = aws_sdk_dynamodb::Client::new(&conf); let resp = client .list_tables() .send() @@ -68,6 +68,7 @@ async fn no_retries_on_operation_timeout() { .operation_timeout(Duration::new(123, 0)) .build(), ) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .retry_config(RetryConfig::standard()) .sleep_impl(SharedAsyncSleep::new(InstantSleep)) .build(); diff --git a/aws/sdk/integration-tests/no-default-features/tests/client-construction.rs b/aws/sdk/integration-tests/no-default-features/tests/client-construction.rs index 68c59b4368..481077d429 100644 --- a/aws/sdk/integration-tests/no-default-features/tests/client-construction.rs +++ b/aws/sdk/integration-tests/no-default-features/tests/client-construction.rs @@ -7,7 +7,7 @@ use aws_sdk_s3::config::IdentityCache; use aws_sdk_s3::config::{ retry::RetryConfig, timeout::TimeoutConfig, BehaviorVersion, Config, Credentials, Region, - SharedAsyncSleep, Sleep, + SharedAsyncSleep, Sleep, StalledStreamProtectionConfig, }; use aws_sdk_s3::primitives::SdkBody; use aws_smithy_runtime::client::http::test_util::infallible_client_fn; @@ -143,6 +143,7 @@ async fn test_time_source_for_identity_cache() { .identity_cache(IdentityCache::lazy().build()) .credentials_provider(Credentials::for_tests()) .retry_config(RetryConfig::disabled()) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .timeout_config(TimeoutConfig::disabled()) .behavior_version(BehaviorVersion::latest()) .build(); @@ -160,6 +161,7 @@ async fn behavior_mv_from_aws_config() { .credentials_provider(Credentials::for_tests()) .identity_cache(IdentityCache::no_cache()) .timeout_config(TimeoutConfig::disabled()) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .region(Region::new("us-west-2")) .load() .await; @@ -183,6 +185,7 @@ async fn behavior_mv_from_client_construction() { .retry_config(RetryConfig::disabled()) .identity_cache(IdentityCache::no_cache()) .timeout_config(TimeoutConfig::disabled()) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .region(Region::new("us-west-2")) .build(); let s3_client = aws_sdk_s3::Client::from_conf( diff --git a/aws/sdk/integration-tests/s3/Cargo.toml b/aws/sdk/integration-tests/s3/Cargo.toml index 19602531be..80ba48df85 100644 --- a/aws/sdk/integration-tests/s3/Cargo.toml +++ b/aws/sdk/integration-tests/s3/Cargo.toml @@ -21,7 +21,6 @@ aws-credential-types = { path = "../../build/aws-sdk/sdk/aws-credential-types", aws-http = { path = "../../build/aws-sdk/sdk/aws-http" } aws-runtime = { path = "../../build/aws-sdk/sdk/aws-runtime", features = ["test-util"] } aws-sdk-s3 = { path = "../../build/aws-sdk/sdk/s3", features = ["test-util", "behavior-version-latest"] } -# aws-sdk-sts = { path = "../../build/aws-sdk/sdk/sts" } aws-smithy-async = { path = "../../build/aws-sdk/sdk/aws-smithy-async", features = ["test-util", "rt-tokio"] } aws-smithy-http = { path = "../../build/aws-sdk/sdk/aws-smithy-http" } aws-smithy-protocol-test = { path = "../../build/aws-sdk/sdk/aws-smithy-protocol-test" } diff --git a/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs b/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs index 08f3765928..2c8c4f4b0a 100644 --- a/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs +++ b/aws/sdk/integration-tests/s3/tests/alternative-async-runtime.rs @@ -5,7 +5,7 @@ use aws_config::retry::RetryConfig; use aws_credential_types::provider::SharedCredentialsProvider; -use aws_sdk_s3::config::{Credentials, Region}; +use aws_sdk_s3::config::{Credentials, Region, StalledStreamProtectionConfig}; use aws_sdk_s3::types::{ CompressionType, CsvInput, CsvOutput, ExpressionType, FileHeaderInfo, InputSerialization, OutputSerialization, @@ -147,6 +147,7 @@ async fn retry_test(sleep_impl: SharedAsyncSleep) -> Result<(), Box (impl Future, SocketAddr) { + use tokio::net::{TcpListener, TcpStream}; + use tokio::time::sleep; + + let listener = TcpListener::bind("0.0.0.0:0") + .await + .expect("socket is free"); + let bind_addr = listener.local_addr().unwrap(); + + async fn process_socket(socket: TcpStream) { + let mut buf = BytesMut::new(); + let mut time_to_stall = false; + + loop { + if time_to_stall { + debug!("faulty server has read partial request, now getting stuck"); + break; + } + + match socket.try_read_buf(&mut buf) { + Ok(0) => { + unreachable!( + "The connection will be closed before this branch is ever reached" + ); + } + Ok(n) => { + debug!("read {n} bytes from the socket"); + + // Check to see if we've received some headers + if buf.len() >= 128 { + let s = String::from_utf8_lossy(&buf); + debug!("{s}"); + + time_to_stall = true; + } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + debug!("reading would block, sleeping for 1ms and then trying again"); + sleep(Duration::from_millis(1)).await; + } + Err(e) => { + panic!("{e}") + } + } + } + + loop { + tokio::task::yield_now().await + } + } + + let fut = async move { + loop { + let (socket, addr) = listener + .accept() + .await + .expect("listener can accept new connections"); + debug!("server received new connection from {addr:?}"); + let start = std::time::Instant::now(); + process_socket(socket).await; + debug!( + "connection to {addr:?} closed after {:.02?}", + start.elapsed() + ); + } + }; + + (fut, bind_addr) +} + +#[tokio::test] +async fn test_explicitly_configured_stalled_stream_protection_for_downloads() { + // We spawn a faulty server that will close the connection after + // writing half of the response body. + let (server, server_addr) = start_faulty_download_server().await; + let _ = tokio::spawn(server); + + let conf = Config::builder() + .credentials_provider(Credentials::for_tests()) + .region(Region::new("us-east-1")) + .endpoint_url(format!("http://{server_addr}")) + .stalled_stream_protection( + StalledStreamProtectionConfig::enabled() + // Fail stalled streams immediately + .grace_period(Duration::from_secs(0)) + .build(), + ) + .build(); + let client = Client::from_conf(conf); + + let res = client + .get_object() + .bucket("a-test-bucket") + .key("stalled-stream-test.txt") + .send() + .await + .unwrap(); + + let err = res + .body + .collect() + .await + .expect_err("download stream stalled out"); + let err = err.source().expect("inner error exists"); + assert_eq!( + err.to_string(), + "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" + ); +} + +#[tokio::test] +async fn test_stalled_stream_protection_for_downloads_can_be_disabled() { + // We spawn a faulty server that will close the connection after + // writing half of the response body. + let (server, server_addr) = start_faulty_download_server().await; + let _ = tokio::spawn(server); + + let conf = Config::builder() + .credentials_provider(Credentials::for_tests()) + .region(Region::new("us-east-1")) + .endpoint_url(format!("http://{server_addr}")) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) + .build(); + let client = Client::from_conf(conf); + + let res = client + .get_object() + .bucket("a-test-bucket") + .key("stalled-stream-test.txt") + .send() + .await + .unwrap(); + + let timeout_duration = Duration::from_secs(2); + match tokio::time::timeout(timeout_duration, res.body.collect()).await { + Ok(_) => panic!("stalled stream protection kicked in but it shouldn't have"), + // If timeout elapses, then stalled stream protection didn't end the stream early. + Err(elapsed) => assert_eq!("deadline has elapsed".to_owned(), elapsed.to_string()), + } +} + +// This test will always take as long as whatever grace period is set by default. +#[tokio::test] +async fn test_stalled_stream_protection_for_downloads_is_enabled_by_default() { + // We spawn a faulty server that will close the connection after + // writing half of the response body. + let (server, server_addr) = start_faulty_download_server().await; + let _ = tokio::spawn(server); + + // Stalled stream protection should be enabled by default. + let sdk_config = aws_config::from_env() + .credentials_provider(Credentials::for_tests()) + .region(Region::new("us-east-1")) + .endpoint_url(format!("http://{server_addr}")) + .load() + .await; + let client = Client::new(&sdk_config); + + let res = client + .get_object() + .bucket("a-test-bucket") + .key("stalled-stream-test.txt") + .send() + .await + .unwrap(); + + let start = std::time::Instant::now(); + let err = res + .body + .collect() + .await + .expect_err("download stream stalled out"); + let err = err.source().expect("inner error exists"); + assert_eq!( + err.to_string(), + "minimum throughput was specified at 1 B/s, but throughput of 0 B/s was observed" + ); + // 1s check interval + 5s grace period + assert_eq!(start.elapsed().as_secs(), 6); +} + +async fn start_faulty_download_server() -> (impl Future, SocketAddr) { + use tokio::net::{TcpListener, TcpStream}; + use tokio::time::sleep; + + let listener = TcpListener::bind("0.0.0.0:0") + .await + .expect("socket is free"); + let bind_addr = listener.local_addr().unwrap(); + + async fn process_socket(socket: TcpStream) { + let mut buf = BytesMut::new(); + let response: &[u8] = br#"HTTP/1.1 200 OK +x-amz-request-id: 4B4NGF0EAWN0GE63 +content-length: 12 +etag: 3e25960a79dbc69b674cd4ec67a72c62 +content-type: application/octet-stream +server: AmazonS3 +content-encoding: +last-modified: Tue, 21 Jun 2022 16:29:14 GMT +date: Tue, 21 Jun 2022 16:29:23 GMT +x-amz-id-2: kPl+IVVZAwsN8ePUyQJZ40WD9dzaqtr4eNESArqE68GSKtVvuvCTDe+SxhTT+JTUqXB1HL4OxNM= +accept-ranges: bytes + +"#; + let mut time_to_respond = false; + + loop { + match socket.try_read_buf(&mut buf) { + Ok(0) => { + unreachable!( + "The connection will be closed before this branch is ever reached" + ); + } + Ok(n) => { + debug!("read {n} bytes from the socket"); + + // Check for CRLF to see if we've received the entire HTTP request. + if buf.ends_with(b"\r\n\r\n") { + time_to_respond = true; + } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + debug!("reading would block, sleeping for 1ms and then trying again"); + sleep(Duration::from_millis(1)).await; + } + Err(e) => { + panic!("{e}") + } + } + + if socket.writable().await.is_ok() && time_to_respond { + // The content length is 12 but we'll only write 5 bytes + socket.try_write(response).unwrap(); + // We break from the R/W loop after sending a partial response in order to + // close the connection early. + debug!("faulty server has written partial response, now getting stuck"); + break; + } + } + + loop { + tokio::task::yield_now().await + } + } + + let fut = async move { + loop { + let (socket, addr) = listener + .accept() + .await + .expect("listener can accept new connections"); + debug!("server received new connection from {addr:?}"); + let start = std::time::Instant::now(); + process_socket(socket).await; + debug!( + "connection to {addr:?} closed after {:.02?}", + start.elapsed() + ); + } + }; + + (fut, bind_addr) +} diff --git a/aws/sdk/integration-tests/timestreamquery/tests/endpoint_disco.rs b/aws/sdk/integration-tests/timestreamquery/tests/endpoint_disco.rs index c5a22c5d1b..d78ca76266 100644 --- a/aws/sdk/integration-tests/timestreamquery/tests/endpoint_disco.rs +++ b/aws/sdk/integration-tests/timestreamquery/tests/endpoint_disco.rs @@ -9,7 +9,7 @@ use aws_smithy_runtime::client::http::test_util::dvr::ReplayingClient; async fn do_endpoint_discovery() { use aws_credential_types::provider::SharedCredentialsProvider; use aws_sdk_timestreamquery as query; - use aws_sdk_timestreamquery::config::Credentials; + use aws_sdk_timestreamquery::config::{Credentials, StalledStreamProtectionConfig}; use aws_smithy_async::rt::sleep::SharedAsyncSleep; use aws_smithy_async::test_util::controlled_time_and_sleep; use aws_smithy_async::time::{SharedTimeSource, TimeSource}; @@ -31,6 +31,7 @@ async fn do_endpoint_discovery() { .credentials_provider(SharedCredentialsProvider::new( Credentials::for_tests_with_session_token(), )) + .stalled_stream_protection(StalledStreamProtectionConfig::disabled()) .time_source(SharedTimeSource::new(ts.clone())) .build(); let conf = query::config::Builder::from(&config) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt index 08814fadac..9f75a10ddc 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/RustClientCodegenPlugin.kt @@ -21,6 +21,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.RequiredCusto import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointParamsDecorator import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointsDecorator import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.config.StalledStreamProtectionDecorator import software.amazon.smithy.rust.codegen.client.testutil.ClientDecoratableBuildPlugin import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.NonExhaustive import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider @@ -68,6 +69,7 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() { HttpConnectorConfigDecorator(), SensitiveOutputDecorator(), IdempotencyTokenDecorator(), + StalledStreamProtectionDecorator(), *decorator, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt index 501e4c7dd4..4d4d32a8dd 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationRuntimePluginGenerator.kt @@ -50,11 +50,12 @@ class OperationRuntimePluginGenerator( operationStructName: String, customizations: List, ) { + val layerName = operationShape.id.name.dq() writer.rustTemplate( """ impl #{RuntimePlugin} for $operationStructName { fn config(&self) -> #{Option}<#{FrozenLayer}> { - let mut cfg = #{Layer}::new(${operationShape.id.name.dq()}); + let mut cfg = #{Layer}::new($layerName); cfg.store_put(#{SharedRequestSerializer}::new(${operationStructName}RequestSerializer)); cfg.store_put(#{SharedResponseDeserializer}::new(${operationStructName}ResponseDeserializer)); @@ -68,11 +69,12 @@ class OperationRuntimePluginGenerator( } fn runtime_components(&self, _: &#{RuntimeComponentsBuilder}) -> #{Cow}<'_, #{RuntimeComponentsBuilder}> { - #{Cow}::Owned( - #{RuntimeComponentsBuilder}::new(${operationShape.id.name.dq()}) + ##[allow(unused_mut)] + let mut rcb = #{RuntimeComponentsBuilder}::new($layerName) #{interceptors} - #{retry_classifiers} - ) + #{retry_classifiers}; + + #{Cow}::Owned(rcb) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt new file mode 100644 index 0000000000..03164cbfab --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/config/StalledStreamProtectionConfigCustomization.kt @@ -0,0 +1,126 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators.config + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.configReexport +import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization +import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope +import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization + +class StalledStreamProtectionDecorator : ClientCodegenDecorator { + override val name: String = "StalledStreamProtection" + override val order: Byte = 0 + + override fun configCustomizations( + codegenContext: ClientCodegenContext, + baseCustomizations: List, + ): List { + return baseCustomizations + StalledStreamProtectionConfigCustomization(codegenContext) + } + + override fun operationCustomizations( + codegenContext: ClientCodegenContext, + operation: OperationShape, + baseCustomizations: List, + ): List { + return baseCustomizations + StalledStreamProtectionOperationCustomization(codegenContext) + } +} + +/** + * Add a `stalled_stream_protection` field to Service config. + */ +class StalledStreamProtectionConfigCustomization(codegenContext: ClientCodegenContext) : NamedCustomization() { + private val rc = codegenContext.runtimeConfig + private val codegenScope = arrayOf( + *preludeScope, + "StalledStreamProtectionConfig" to configReexport(RuntimeType.smithyRuntimeApi(rc).resolve("client::stalled_stream_protection::StalledStreamProtectionConfig")), + ) + + override fun section(section: ServiceConfig): Writable { + return when (section) { + ServiceConfig.ConfigImpl -> writable { + rustTemplate( + """ + /// Return a reference to the stalled stream protection configuration contained in this config, if any. + pub fn stalled_stream_protection(&self) -> #{Option}<&#{StalledStreamProtectionConfig}> { + self.config.load::<#{StalledStreamProtectionConfig}>() + } + """, + *codegenScope, + ) + } + ServiceConfig.BuilderImpl -> writable { + rustTemplate( + """ + /// Set the [`StalledStreamProtectionConfig`](#{StalledStreamProtectionConfig}) + /// to configure protection for stalled streams. + pub fn stalled_stream_protection( + mut self, + stalled_stream_protection_config: #{StalledStreamProtectionConfig} + ) -> Self { + self.set_stalled_stream_protection(#{Some}(stalled_stream_protection_config)); + self + } + """, + *codegenScope, + ) + + rustTemplate( + """ + /// Set the [`StalledStreamProtectionConfig`](#{StalledStreamProtectionConfig}) + /// to configure protection for stalled streams. + pub fn set_stalled_stream_protection( + &mut self, + stalled_stream_protection_config: #{Option}<#{StalledStreamProtectionConfig}> + ) -> &mut Self { + self.config.store_or_unset(stalled_stream_protection_config); + self + } + """, + *codegenScope, + ) + } + + else -> emptySection + } + } +} + +class StalledStreamProtectionOperationCustomization( + codegenContext: ClientCodegenContext, +) : OperationCustomization() { + private val rc = codegenContext.runtimeConfig + + override fun section(section: OperationSection): Writable = writable { + when (section) { + is OperationSection.AdditionalInterceptors -> { + val stalledStreamProtectionModule = RuntimeType.smithyRuntime(rc).resolve("client::stalled_stream_protection") + section.registerInterceptor(rc, this) { + // Currently, only response bodies are protected/supported because + // we can't count on hyper to poll a request body on wake. + rustTemplate( + """ + #{StalledStreamProtectionInterceptor}::new(#{Kind}::ResponseBody) + """, + *preludeScope, + "StalledStreamProtectionInterceptor" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptor"), + "Kind" to stalledStreamProtectionModule.resolve("StalledStreamProtectionInterceptorKind"), + ) + } + } + else -> { } + } + } +} diff --git a/rust-runtime/aws-smithy-runtime-api/src/client.rs b/rust-runtime/aws-smithy-runtime-api/src/client.rs index ce011bd98a..cb5a439522 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client.rs @@ -117,4 +117,7 @@ pub mod runtime_components; pub mod runtime_plugin; pub mod behavior_version; + pub mod ser_de; + +pub mod stalled_stream_protection; diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs new file mode 100644 index 0000000000..25c9c5c67d --- /dev/null +++ b/rust-runtime/aws-smithy-runtime-api/src/client/stalled_stream_protection.rs @@ -0,0 +1,109 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#![allow(missing_docs)] + +//! Stalled stream protection. +//! +//! When enabled, upload and download streams that stall (stream no data) for +//! longer than a configured grace period will return an error. + +use aws_smithy_types::config_bag::{Storable, StoreReplace}; +use std::time::Duration; + +const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(5); + +/// Configuration for stalled stream protection. +/// +/// When enabled, download streams that stall out will be cancelled. +#[derive(Clone, Debug)] +pub struct StalledStreamProtectionConfig { + is_enabled: bool, + grace_period: Duration, +} + +impl StalledStreamProtectionConfig { + /// Create a new config that enables stalled stream protection. + pub fn enabled() -> Builder { + Builder { + is_enabled: Some(true), + grace_period: None, + } + } + + /// Create a new config that disables stalled stream protection. + pub fn disabled() -> Self { + Self { + is_enabled: false, + grace_period: DEFAULT_GRACE_PERIOD, + } + } + + /// Return whether stalled stream protection is enabled. + pub fn is_enabled(&self) -> bool { + self.is_enabled + } + + /// Return the grace period for stalled stream protection. + /// + /// When a stream stalls for longer than this grace period, the stream will + /// return an error. + pub fn grace_period(&self) -> Duration { + self.grace_period + } +} + +#[derive(Clone, Debug)] +pub struct Builder { + is_enabled: Option, + grace_period: Option, +} + +impl Builder { + /// Set the grace period for stalled stream protection. + pub fn grace_period(mut self, grace_period: Duration) -> Self { + self.grace_period = Some(grace_period); + self + } + + /// Set the grace period for stalled stream protection. + pub fn set_grace_period(&mut self, grace_period: Option) -> &mut Self { + self.grace_period = grace_period; + self + } + + /// Set whether stalled stream protection is enabled. + pub fn is_enabled(mut self, is_enabled: bool) -> Self { + self.is_enabled = Some(is_enabled); + self + } + + /// Set whether stalled stream protection is enabled. + pub fn set_is_enabled(&mut self, is_enabled: Option) -> &mut Self { + self.is_enabled = is_enabled; + self + } + + /// Build the config. + pub fn build(self) -> StalledStreamProtectionConfig { + StalledStreamProtectionConfig { + is_enabled: self.is_enabled.unwrap_or_default(), + grace_period: self.grace_period.unwrap_or(DEFAULT_GRACE_PERIOD), + } + } +} + +impl From for Builder { + fn from(config: StalledStreamProtectionConfig) -> Self { + Builder { + is_enabled: Some(config.is_enabled), + grace_period: Some(config.grace_period), + } + } +} + +impl Storable for StalledStreamProtectionConfig { + type Storer = StoreReplace; +} diff --git a/rust-runtime/aws-smithy-runtime/src/client.rs b/rust-runtime/aws-smithy-runtime/src/client.rs index 4457ef2ac8..10591578fc 100644 --- a/rust-runtime/aws-smithy-runtime/src/client.rs +++ b/rust-runtime/aws-smithy-runtime/src/client.rs @@ -41,3 +41,6 @@ pub mod identity; /// Interceptors for Smithy clients. pub mod interceptors; + +/// Stalled stream protection for clients +pub mod stalled_stream_protection; diff --git a/rust-runtime/aws-smithy-runtime/src/client/defaults.rs b/rust-runtime/aws-smithy-runtime/src/client/defaults.rs index 6e78d00cc5..ca2d06a387 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/defaults.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/defaults.rs @@ -23,11 +23,13 @@ use aws_smithy_runtime_api::client::runtime_components::{ use aws_smithy_runtime_api::client::runtime_plugin::{ Order, SharedRuntimePlugin, StaticRuntimePlugin, }; +use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; use aws_smithy_runtime_api::shared::IntoShared; use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer}; use aws_smithy_types::retry::RetryConfig; use aws_smithy_types::timeout::TimeoutConfig; use std::borrow::Cow; +use std::time::Duration; fn default_plugin(name: &'static str, components_fn: CompFn) -> StaticRuntimePlugin where @@ -164,6 +166,59 @@ pub fn default_identity_cache_plugin() -> Option { ) } +/// Runtime plugin that sets the default stalled stream protection config. +/// +/// By default, when throughput falls below 1/Bs for more than 5 seconds, the +/// stream is cancelled. +pub fn default_stalled_stream_protection_config_plugin() -> Option { + Some( + default_plugin( + "default_stalled_stream_protection_config_plugin", + |components| { + components.with_config_validator(SharedConfigValidator::base_client_config_fn( + validate_stalled_stream_protection_config, + )) + }, + ) + .with_config(layer("default_stalled_stream_protection_config", |layer| { + layer.store_put( + StalledStreamProtectionConfig::enabled() + .grace_period(Duration::from_secs(5)) + .build(), + ); + })) + .into_shared(), + ) +} + +fn validate_stalled_stream_protection_config( + components: &RuntimeComponentsBuilder, + cfg: &ConfigBag, +) -> Result<(), BoxError> { + if let Some(stalled_stream_protection_config) = cfg.load::() { + if stalled_stream_protection_config.is_enabled() { + if components.sleep_impl().is_none() { + return Err( + "An async sleep implementation is required for stalled stream protection to work. \ + Please provide a `sleep_impl` on the config, or disable stalled stream protection.".into()); + } + + if components.time_source().is_none() { + return Err( + "A time source is required for stalled stream protection to work.\ + Please provide a `time_source` on the config, or disable stalled stream protection.".into()); + } + } + + Ok(()) + } else { + Err( + "The default stalled stream protection config was removed, and no other config was put in its place." + .into(), + ) + } +} + /// Arguments for the [`default_plugins`] method. /// /// This is a struct to enable adding new parameters in the future without breaking the API. @@ -208,6 +263,7 @@ pub fn default_plugins( default_sleep_impl_plugin(), default_time_source_plugin(), default_timeout_config_plugin(), + default_stalled_stream_protection_config_plugin(), ] .into_iter() .flatten() diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs index 2ddcc6c3f7..006de72cfd 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput.rs @@ -7,20 +7,22 @@ //! //! If data is being streamed too slowly, this body type will emit an error next time it's polled. +/// An implementation of v0.4 `http_body::Body` for `MinimumThroughputBody` and related code. +pub mod http_body_0_4_x; + +/// Options for a [`MinimumThroughputBody`]. +pub mod options; +mod throughput; + use aws_smithy_async::rt::sleep::Sleep; use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep}; use aws_smithy_async::time::{SharedTimeSource, TimeSource}; use aws_smithy_runtime_api::box_error::BoxError; use aws_smithy_runtime_api::shared::IntoShared; +use options::MinimumThroughputBodyOptions; use std::fmt; -use std::time::Duration; use throughput::{Throughput, ThroughputLogs}; -/// An implementation of v0.4 `http_body::Body` for `MinimumThroughputBody` and related code. -pub mod http_body_0_4_x; - -mod throughput; - pin_project_lite::pin_project! { /// A body-wrapping type that ensures data is being streamed faster than some lower limit. /// @@ -28,11 +30,13 @@ pin_project_lite::pin_project! { pub struct MinimumThroughputBody { async_sleep: SharedAsyncSleep, time_source: SharedTimeSource, - minimum_throughput: Throughput, + options: MinimumThroughputBodyOptions, throughput_logs: ThroughputLogs, #[pin] sleep_fut: Option, #[pin] + grace_period_fut: Option, + #[pin] inner: B, } } @@ -46,26 +50,26 @@ impl MinimumThroughputBody { time_source: impl TimeSource + 'static, async_sleep: impl AsyncSleep + 'static, body: B, - (bytes_read, per_time_elapsed): (u64, Duration), + options: MinimumThroughputBodyOptions, ) -> Self { - let minimum_throughput = Throughput::new(bytes_read as f64, per_time_elapsed); Self { throughput_logs: ThroughputLogs::new( // Never keep more than 10KB of logs in memory. This currently // equates to 426 logs. (NUMBER_OF_LOGS_IN_ONE_KB * 10.0) as usize, - minimum_throughput.per_time_elapsed(), + options.minimum_throughput().per_time_elapsed(), ), async_sleep: async_sleep.into_shared(), time_source: time_source.into_shared(), - minimum_throughput, inner: body, sleep_fut: None, + grace_period_fut: None, + options, } } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] enum Error { ThroughputBelowMinimum { expected: Throughput, diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs index 46b401a305..b8f2ffd3ec 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/http_body_0_4_x.rs @@ -25,10 +25,6 @@ where // Attempt to read the data from the inner body, then update the // throughput logs. let mut this = self.as_mut().project(); - // Push a start log if we haven't already done so. - if this.throughput_logs.is_empty() { - this.throughput_logs.push((now, 0)); - } let poll_res = match this.inner.poll_data(cx) { Poll::Ready(Some(Ok(bytes))) => { this.throughput_logs.push((now, bytes.len() as u64)); @@ -38,17 +34,21 @@ where // If we've read all the data or an error occurred, then return that result. res => return res, }; + // Push a start log if we haven't already done so. + if this.throughput_logs.is_empty() { + this.throughput_logs.push((now, 0)); + } // Check the sleep future to see if it needs refreshing. let mut sleep_fut = this.sleep_fut.take().unwrap_or_else(|| { this.async_sleep - .sleep(this.minimum_throughput.per_time_elapsed()) + .sleep(this.options.minimum_throughput().per_time_elapsed()) }); if let Poll::Ready(()) = pin!(&mut sleep_fut).poll(cx) { // Whenever the sleep future expires, we replace it. sleep_fut = this .async_sleep - .sleep(this.minimum_throughput.per_time_elapsed()); + .sleep(this.options.minimum_throughput().per_time_elapsed()); // We also schedule a wake up for current task to ensure that // it gets polled at least one more time. @@ -56,19 +56,33 @@ where }; this.sleep_fut.replace(sleep_fut); - // Calculate the current throughput and emit an error if it's too low. + // Calculate the current throughput and emit an error if it's too low and + // the grace period has elapsed. let actual_throughput = this.throughput_logs.calculate_throughput(now); let is_below_minimum_throughput = actual_throughput - .map(|t| t < self.minimum_throughput) + .map(|t| t < this.options.minimum_throughput()) .unwrap_or_default(); if is_below_minimum_throughput { - Poll::Ready(Some(Err(Box::new(Error::ThroughputBelowMinimum { - expected: self.minimum_throughput, - actual: actual_throughput.unwrap(), - })))) + // Check the grace period future to see if it needs creating. + let mut grace_period_fut = this + .grace_period_fut + .take() + .unwrap_or_else(|| this.async_sleep.sleep(this.options.grace_period())); + if let Poll::Ready(()) = pin!(&mut grace_period_fut).poll(cx) { + // The grace period has ended! + return Poll::Ready(Some(Err(Box::new(Error::ThroughputBelowMinimum { + expected: self.options.minimum_throughput(), + actual: actual_throughput.unwrap(), + })))); + }; + this.grace_period_fut.replace(grace_period_fut); } else { - poll_res + // Ensure we don't have an active grace period future if we're not + // currently below the minimum throughput. + let _ = this.grace_period_fut.take(); } + + poll_res } fn poll_trailers( @@ -84,6 +98,7 @@ where #[cfg(all(test, feature = "connector-hyper-0-14-x"))] mod test { use super::{super::Throughput, Error, MinimumThroughputBody}; + use crate::client::http::body::minimum_throughput::options::MinimumThroughputBodyOptions; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep, ManualTimeSource}; use aws_smithy_types::body::SdkBody; @@ -129,7 +144,7 @@ mod test { time_source.clone(), async_sleep.clone(), NeverBody, - (1, Duration::from_secs(1)), + Default::default(), ); time_source.advance(Duration::from_secs(1)); let actual_err = body.data().await.expect("next chunk exists").unwrap_err(); @@ -164,7 +179,7 @@ mod test { Lazy::new(|| (1..=255).flat_map(|_| b"00000000").copied().collect()); fn eight_byte_per_second_stream_with_minimum_throughput_timeout( - minimum_throughput: (u64, Duration), + minimum_throughput: Throughput, ) -> ( impl Future>, ManualTimeSource, @@ -187,18 +202,20 @@ mod test { time_source, async_sleep, body, - minimum_throughput, + MinimumThroughputBodyOptions::builder() + .minimum_throughput(minimum_throughput) + .build(), )) }); (body.collect(), time_source, async_sleep) } - async fn expect_error(minimum_throughput: (u64, Duration)) { + async fn expect_error(minimum_throughput: Throughput) { let (res, ..) = eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput); let expected_err = Error::ThroughputBelowMinimum { - expected: minimum_throughput.into(), + expected: minimum_throughput, actual: Throughput::new(8.889, Duration::from_secs(1)), }; match res.await { @@ -219,11 +236,11 @@ mod test { #[tokio::test] async fn test_throughput_timeout_less_than() { - let minimum_throughput = (9, Duration::from_secs(1)); + let minimum_throughput = Throughput::new_bytes_per_second(9.0); expect_error(minimum_throughput).await; } - async fn expect_success(minimum_throughput: (u64, Duration)) { + async fn expect_success(minimum_throughput: Throughput) { let (res, time_source, async_sleep) = eight_byte_per_second_stream_with_minimum_throughput_timeout(minimum_throughput); match res.await { @@ -238,13 +255,13 @@ mod test { #[tokio::test] async fn test_throughput_timeout_equal_to() { - let minimum_throughput = (32, Duration::from_secs(4)); + let minimum_throughput = Throughput::new(32.0, Duration::from_secs(4)); expect_success(minimum_throughput).await; } #[tokio::test] async fn test_throughput_timeout_greater_than() { - let minimum_throughput = (20, Duration::from_secs(3)); + let minimum_throughput = Throughput::new(20.0, Duration::from_secs(3)); expect_success(minimum_throughput).await; } @@ -279,11 +296,12 @@ mod test { #[tokio::test] async fn test_throughput_timeout_shrinking_sine_wave() { - // Minimum throughput per second will be approx. half of the BYTE_COUNT_UPPER_LIMIT. - let minimum_throughput = ( - BYTE_COUNT_UPPER_LIMIT as u64 / 2 + 2, - Duration::from_secs(1), - ); + let options = MinimumThroughputBodyOptions::builder() + // Minimum throughput per second will be approx. half of the BYTE_COUNT_UPPER_LIMIT. + .minimum_throughput(Throughput::new_bytes_per_second( + BYTE_COUNT_UPPER_LIMIT / 2.0 + 2.0, + )) + .build(); let (time_source, async_sleep) = instant_time_and_sleep(UNIX_EPOCH); let time_clone = time_source.clone(); @@ -301,7 +319,7 @@ mod test { time_source, async_sleep, body, - minimum_throughput, + options.clone(), )) }) .collect(); diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs new file mode 100644 index 0000000000..8618707f73 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/options.rs @@ -0,0 +1,161 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use super::Throughput; +use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; +use std::time::Duration; + +/// A collection of options for configuring a [`MinimumThroughputBody`](super::MinimumThroughputBody). +#[derive(Debug, Clone)] +pub struct MinimumThroughputBodyOptions { + /// The minimum throughput that is acceptable. + minimum_throughput: Throughput, + /// The 'grace period' after which the minimum throughput will be enforced. + /// + /// If this is set to 0, the minimum throughput will be enforced immediately. + /// + /// If this is set to a positive value, whenever throughput is below the minimum throughput, + /// a timer is started. If the timer expires before throughput rises above the minimum, + /// an error is emitted. + grace_period: Duration, + /// The interval at which the throughput is checked. + check_interval: Duration, +} + +impl MinimumThroughputBodyOptions { + /// Create a new builder. + pub fn builder() -> MinimumThroughputBodyOptionsBuilder { + Default::default() + } + + /// Convert this struct into a builder. + pub fn to_builder(self) -> MinimumThroughputBodyOptionsBuilder { + MinimumThroughputBodyOptionsBuilder::new() + .minimum_throughput(self.minimum_throughput) + .grace_period(self.grace_period) + .check_interval(self.check_interval) + } + + /// The throughput check grace period. + /// + /// If throughput is below the minimum for longer than this period, an error is emitted. + /// + /// If this is set to 0, the minimum throughput will be enforced immediately. + pub fn grace_period(&self) -> Duration { + self.grace_period + } + + /// The minimum acceptable throughput + pub fn minimum_throughput(&self) -> Throughput { + self.minimum_throughput + } + + /// The rate at which the throughput is checked. + /// + /// The actual rate throughput is checked may be higher than this value, + /// but it will never be lower. + pub fn check_interval(&self) -> Duration { + self.check_interval + } +} + +impl Default for MinimumThroughputBodyOptions { + fn default() -> Self { + Self { + minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT, + grace_period: DEFAULT_GRACE_PERIOD, + check_interval: DEFAULT_CHECK_INTERVAL, + } + } +} + +/// A builder for [`MinimumThroughputBodyOptions`] +#[derive(Debug, Default, Clone)] +pub struct MinimumThroughputBodyOptionsBuilder { + minimum_throughput: Option, + check_interval: Option, + grace_period: Option, +} + +const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(1); +const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(0); +const DEFAULT_MINIMUM_THROUGHPUT: Throughput = Throughput { + bytes_read: 1.0, + per_time_elapsed: Duration::from_secs(1), +}; + +impl MinimumThroughputBodyOptionsBuilder { + /// Create a new `MinimumThroughputBodyOptionsBuilder`. + pub fn new() -> Self { + Default::default() + } + + /// Set the amount of time that throughput my fall below minimum before an error is emitted. + /// + /// If throughput rises above the minimum, the timer is reset. + pub fn grace_period(mut self, grace_period: Duration) -> Self { + self.set_grace_period(Some(grace_period)); + self + } + + /// Set the amount of time that throughput my fall below minimum before an error is emitted. + /// + /// If throughput rises above the minimum, the timer is reset. + pub fn set_grace_period(&mut self, grace_period: Option) -> &mut Self { + self.grace_period = grace_period; + self + } + + /// Set the minimum allowable throughput. + pub fn minimum_throughput(mut self, minimum_throughput: Throughput) -> Self { + self.set_minimum_throughput(Some(minimum_throughput)); + self + } + + /// Set the minimum allowable throughput. + pub fn set_minimum_throughput(&mut self, minimum_throughput: Option) -> &mut Self { + self.minimum_throughput = minimum_throughput; + self + } + + /// Set the rate at which throughput is checked. + /// + /// Defaults to 1 second. + pub fn check_interval(mut self, check_interval: Duration) -> Self { + self.set_check_interval(Some(check_interval)); + self + } + + /// Set the rate at which throughput is checked. + /// + /// Defaults to 1 second. + pub fn set_check_interval(&mut self, check_interval: Option) -> &mut Self { + self.check_interval = check_interval; + self + } + + /// Build this builder, producing a [`MinimumThroughputBodyOptions`]. + /// + /// Unset fields will be set with defaults. + pub fn build(self) -> MinimumThroughputBodyOptions { + MinimumThroughputBodyOptions { + grace_period: self.grace_period.unwrap_or(DEFAULT_GRACE_PERIOD), + minimum_throughput: self + .minimum_throughput + .unwrap_or(DEFAULT_MINIMUM_THROUGHPUT), + check_interval: self.check_interval.unwrap_or(DEFAULT_CHECK_INTERVAL), + } + } +} + +impl From for MinimumThroughputBodyOptions { + fn from(value: StalledStreamProtectionConfig) -> Self { + MinimumThroughputBodyOptions { + grace_period: value.grace_period(), + minimum_throughput: DEFAULT_MINIMUM_THROUGHPUT, + check_interval: DEFAULT_CHECK_INTERVAL, + } + } +} diff --git a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs index 3b610f857e..6186f8c9cc 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/http/body/minimum_throughput/throughput.rs @@ -8,13 +8,14 @@ use std::fmt; use std::time::{Duration, SystemTime}; #[derive(Debug, Clone, Copy)] -pub(super) struct Throughput { - bytes_read: f64, - per_time_elapsed: Duration, +pub struct Throughput { + pub(super) bytes_read: f64, + pub(super) per_time_elapsed: Duration, } impl Throughput { - pub(super) fn new(bytes_read: f64, per_time_elapsed: Duration) -> Self { + /// Create a new throughput with the given bytes read and time elapsed. + pub fn new(bytes_read: f64, per_time_elapsed: Duration) -> Self { debug_assert!( !bytes_read.is_nan(), "cannot create a throughput if bytes_read == NaN" @@ -34,6 +35,30 @@ impl Throughput { } } + /// Create a new throughput in bytes per second. + pub fn new_bytes_per_second(bytes: f64) -> Self { + Self { + bytes_read: bytes, + per_time_elapsed: Duration::from_secs(1), + } + } + + /// Create a new throughput in kilobytes per second. + pub fn new_kilobytes_per_second(kilobytes: f64) -> Self { + Self { + bytes_read: kilobytes * 1000.0, + per_time_elapsed: Duration::from_secs(1), + } + } + + /// Create a new throughput in megabytes per second. + pub fn new_megabytes_per_second(megabytes: f64) -> Self { + Self { + bytes_read: megabytes * 1000.0 * 1000.0, + per_time_elapsed: Duration::from_secs(1), + } + } + pub(super) fn per_time_elapsed(&self) -> Duration { self.per_time_elapsed } diff --git a/rust-runtime/aws-smithy-runtime/src/client/identity/cache/lazy.rs b/rust-runtime/aws-smithy-runtime/src/client/identity/cache/lazy.rs index 1ed06cba18..bd1d284017 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/identity/cache/lazy.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/identity/cache/lazy.rs @@ -689,6 +689,7 @@ mod tests { .unwrap(); let (cache, _) = test_cache(BUFFER_TIME_NO_JITTER, Vec::new()); + #[allow(clippy::disallowed_methods)] let far_future = SystemTime::now() + Duration::from_secs(10_000); // Resolver A and B both return an identical identity type with different tokens with an expiration diff --git a/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs new file mode 100644 index 0000000000..3e07b3f0b8 --- /dev/null +++ b/rust-runtime/aws-smithy-runtime/src/client/stalled_stream_protection.rs @@ -0,0 +1,138 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use crate::client::http::body::minimum_throughput::MinimumThroughputBody; +use aws_smithy_async::rt::sleep::SharedAsyncSleep; +use aws_smithy_async::time::SharedTimeSource; +use aws_smithy_runtime_api::box_error::BoxError; +use aws_smithy_runtime_api::client::interceptors::context::{ + BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, +}; +use aws_smithy_runtime_api::client::interceptors::Intercept; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; +use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig; +use aws_smithy_types::body::SdkBody; +use aws_smithy_types::config_bag::ConfigBag; +use std::mem; + +/// Adds stalled stream protection when sending requests and/or receiving responses. +#[derive(Debug)] +pub struct StalledStreamProtectionInterceptor { + enable_for_request_body: bool, + enable_for_response_body: bool, +} + +/// Stalled stream protection can be enable for request bodies, response bodies, +/// or both. +pub enum StalledStreamProtectionInterceptorKind { + /// Enable stalled stream protection for request bodies. + RequestBody, + /// Enable stalled stream protection for response bodies. + ResponseBody, + /// Enable stalled stream protection for both request and response bodies. + RequestAndResponseBody, +} + +impl StalledStreamProtectionInterceptor { + /// Create a new stalled stream protection interceptor. + pub fn new(kind: StalledStreamProtectionInterceptorKind) -> Self { + use StalledStreamProtectionInterceptorKind::*; + let (enable_for_request_body, enable_for_response_body) = match kind { + RequestBody => (true, false), + ResponseBody => (false, true), + RequestAndResponseBody => (true, true), + }; + + Self { + enable_for_request_body, + enable_for_response_body, + } + } +} + +impl Intercept for StalledStreamProtectionInterceptor { + fn name(&self) -> &'static str { + "StalledStreamProtectionInterceptor" + } + + fn modify_before_transmit( + &self, + context: &mut BeforeTransmitInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + if self.enable_for_request_body { + if let Some(cfg) = cfg.load::() { + if cfg.is_enabled() { + let (async_sleep, time_source) = + get_runtime_component_deps(runtime_components)?; + tracing::trace!("adding stalled stream protection to request body"); + add_stalled_stream_protection_to_body( + context.request_mut().body_mut(), + cfg, + async_sleep, + time_source, + ); + } + } + } + + Ok(()) + } + + fn modify_before_deserialization( + &self, + context: &mut BeforeDeserializationInterceptorContextMut<'_>, + runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + if self.enable_for_response_body { + if let Some(cfg) = cfg.load::() { + if cfg.is_enabled() { + let (async_sleep, time_source) = + get_runtime_component_deps(runtime_components)?; + tracing::trace!("adding stalled stream protection to response body"); + add_stalled_stream_protection_to_body( + context.response_mut().body_mut(), + cfg, + async_sleep, + time_source, + ); + } + } + } + Ok(()) + } +} + +fn get_runtime_component_deps( + runtime_components: &RuntimeComponents, +) -> Result<(SharedAsyncSleep, SharedTimeSource), BoxError> { + let async_sleep = runtime_components.sleep_impl().ok_or( + "An async sleep implementation is required when stalled stream protection is enabled", + )?; + let time_source = runtime_components + .time_source() + .ok_or("A time source is required when stalled stream protection is enabled")?; + Ok((async_sleep, time_source)) +} + +fn add_stalled_stream_protection_to_body( + body: &mut SdkBody, + cfg: &StalledStreamProtectionConfig, + async_sleep: SharedAsyncSleep, + time_source: SharedTimeSource, +) { + let cfg = cfg.clone(); + let it = mem::replace(body, SdkBody::taken()); + let it = it.map_preserve_contents(move |body| { + let cfg = cfg.clone(); + let async_sleep = async_sleep.clone(); + let time_source = time_source.clone(); + let mtb = MinimumThroughputBody::new(time_source, async_sleep, body, cfg.into()); + SdkBody::from_body_0_4(mtb) + }); + let _ = mem::replace(body, it); +}