diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index b08ebefd4e6..289f8778ecd 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -177,3 +177,9 @@ message = "Source defaults from the default trait instead of implicitly based on references = ["smithy-rs#2985"] meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" } author = "rcoh" + +[[smithy-rs]] +message = "`StaticUriEndpointResolver`'s `uri` constructor now takes a `String` instead of a `Uri`." +references = ["TODO"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" } +author = "jdisanti" diff --git a/aws/rust-runtime/aws-config/Cargo.toml b/aws/rust-runtime/aws-config/Cargo.toml index 60fb0f52618..2e9d00f08a1 100644 --- a/aws/rust-runtime/aws-config/Cargo.toml +++ b/aws/rust-runtime/aws-config/Cargo.toml @@ -9,7 +9,7 @@ license = "Apache-2.0" repository = "https://github.com/awslabs/smithy-rs" [features] -client-hyper = ["aws-smithy-client/client-hyper"] +client-hyper = ["aws-smithy-client/client-hyper", "aws-smithy-runtime/connector-hyper"] rustls = ["aws-smithy-client/rustls", "client-hyper"] native-tls = [] allow-compilation = [] # our tests use `cargo test --all-features` and native-tls breaks CI @@ -27,7 +27,10 @@ aws-smithy-client = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-client", de aws-smithy-http = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http" } aws-smithy-http-tower = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-http-tower" } aws-smithy-json = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-json" } +aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client"] } +aws-smithy-runtime-api = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime-api", features = ["client"] } aws-smithy-types = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-types" } +aws-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-runtime" } aws-types = { path = "../../sdk/build/aws-sdk/sdk/aws-types" } hyper = { version = "0.14.26", default-features = false } time = { version = "0.3.4", features = ["parsing"] } @@ -48,6 +51,7 @@ hex = { version = "0.4.3", optional = true } zeroize = { version = "1", optional = true } [dev-dependencies] +aws-smithy-runtime = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-runtime", features = ["client", "test-util"] } futures-util = { version = "0.3.16", default-features = false } tracing-test = "0.2.1" tracing-subscriber = { version = "0.3.16", features = ["fmt", "json"] } diff --git a/aws/rust-runtime/aws-config/examples/imds.rs b/aws/rust-runtime/aws-config/examples/imds.rs index 3722835e59b..b1da233363f 100644 --- a/aws/rust-runtime/aws-config/examples/imds.rs +++ b/aws/rust-runtime/aws-config/examples/imds.rs @@ -12,8 +12,8 @@ async fn main() -> Result<(), Box> { use aws_config::imds::Client; - let imds = Client::builder().build().await?; + let imds = Client::builder().build(); let instance_id = imds.get("/latest/meta-data/instance-id").await?; - println!("current instance id: {}", instance_id); + println!("current instance id: {}", instance_id.as_ref()); Ok(()) } diff --git a/aws/rust-runtime/aws-config/src/ecs.rs b/aws/rust-runtime/aws-config/src/ecs.rs index 17fab949f5b..f5ca10e54c1 100644 --- a/aws/rust-runtime/aws-config/src/ecs.rs +++ b/aws/rust-runtime/aws-config/src/ecs.rs @@ -55,7 +55,7 @@ use aws_credential_types::provider::{self, error::CredentialsError, future, Prov use aws_smithy_client::erase::boxclone::BoxCloneService; use aws_smithy_http::endpoint::apply_endpoint; use aws_smithy_types::error::display::DisplayErrorContext; -use http::uri::{InvalidUri, Scheme}; +use http::uri::{InvalidUri, PathAndQuery, Scheme}; use http::{HeaderValue, Uri}; use tower::{Service, ServiceExt}; @@ -166,6 +166,15 @@ impl Provider { Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured, Err(err) => return Provider::InvalidConfiguration(err), }; + let path = uri.path().to_string(); + let endpoint = { + let mut parts = uri.into_parts(); + parts.path_and_query = Some(PathAndQuery::from_static("/")); + Uri::from_parts(parts) + } + .expect("parts will be valid") + .to_string(); + let http_provider = HttpCredentialProvider::builder() .configure(&provider_config) .connector_settings( @@ -174,7 +183,7 @@ impl Provider { .read_timeout(DEFAULT_READ_TIMEOUT) .build(), ) - .build("EcsContainer", uri); + .build("EcsContainer", &endpoint, path); Provider::Configured(http_provider) } diff --git a/aws/rust-runtime/aws-config/src/http_credential_provider.rs b/aws/rust-runtime/aws-config/src/http_credential_provider.rs index 2568cc435d2..c3b08b80e02 100644 --- a/aws/rust-runtime/aws-config/src/http_credential_provider.rs +++ b/aws/rust-runtime/aws-config/src/http_credential_provider.rs @@ -8,35 +8,44 @@ //! //! Future work will stabilize this interface and enable it to be used directly. +use crate::connector::expect_connector; +use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials}; +use crate::provider_config::ProviderConfig; use aws_credential_types::provider::{self, error::CredentialsError}; use aws_credential_types::Credentials; -use aws_smithy_client::erase::DynConnector; +use aws_sdk_sso::config::interceptors::InterceptorContext; use aws_smithy_client::http_connector::ConnectorSettings; use aws_smithy_http::body::SdkBody; -use aws_smithy_http::operation::{Operation, Request}; -use aws_smithy_http::response::ParseStrictResponse; -use aws_smithy_http::result::{SdkError, SdkSuccess}; -use aws_smithy_http::retry::ClassifyRetry; -use aws_smithy_types::retry::{ErrorKind, RetryKind}; - -use crate::connector::expect_connector; -use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials}; -use crate::provider_config::ProviderConfig; - -use bytes::Bytes; +use aws_smithy_http::result::SdkError; +use aws_smithy_runtime::client::connectors::adapter::DynConnectorAdapter; +use aws_smithy_runtime::client::orchestrator::operation::Operation; +use aws_smithy_runtime::client::retries::classifier::{ + HttpStatusCodeClassifier, SmithyErrorClassifier, +}; +use aws_smithy_runtime_api::client::connectors::SharedHttpConnector; +use aws_smithy_runtime_api::client::interceptors::context::Error; +use aws_smithy_runtime_api::client::orchestrator::{ + HttpResponse, OrchestratorError, SensitiveOutput, +}; +use aws_smithy_runtime_api::client::retries::{ClassifyRetry, RetryClassifiers, RetryReason}; +use aws_smithy_runtime_api::client::runtime_plugin::{SharedRuntimePlugin, StaticRuntimePlugin}; +use aws_smithy_types::config_bag::Layer; +use aws_smithy_types::retry::{ErrorKind, RetryConfig}; use http::header::{ACCEPT, AUTHORIZATION}; -use http::{HeaderValue, Response, Uri}; +use http::{HeaderValue, Response}; use std::time::Duration; -use tower::layer::util::Identity; const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2); +#[derive(Debug)] +struct OpInput { + auth: Option, +} + #[derive(Debug)] pub(crate) struct HttpCredentialProvider { - uri: Uri, - client: aws_smithy_client::Client, - provider_name: &'static str, + operation: Operation, } impl HttpCredentialProvider { @@ -45,34 +54,13 @@ impl HttpCredentialProvider { } pub(crate) async fn credentials(&self, auth: Option) -> provider::Result { - let credentials = self.client.call(self.operation(auth)).await; + let credentials = self.operation.invoke(OpInput { auth }).await; match credentials { Ok(creds) => Ok(creds), Err(SdkError::ServiceError(context)) => Err(context.into_err()), Err(other) => Err(CredentialsError::unhandled(other)), } } - - fn operation( - &self, - auth: Option, - ) -> Operation { - let mut http_req = http::Request::builder() - .uri(&self.uri) - .header(ACCEPT, "application/json"); - - if let Some(auth) = auth { - http_req = http_req.header(AUTHORIZATION, auth); - } - let http_req = http_req.body(SdkBody::empty()).expect("valid request"); - Operation::new( - Request::new(http_req), - CredentialsResponseParser { - provider_name: self.provider_name, - }, - ) - .with_retry_classifier(HttpCredentialRetryClassifier) - } } #[derive(Default)] @@ -92,7 +80,12 @@ impl Builder { self } - pub(crate) fn build(self, provider_name: &'static str, uri: Uri) -> HttpCredentialProvider { + pub(crate) fn build( + self, + provider_name: &'static str, + endpoint: &str, + path: impl Into, + ) -> HttpCredentialProvider { let provider_config = self.provider_config.unwrap_or_default(); let connector_settings = self.connector_settings.unwrap_or_else(|| { ConnectorSettings::builder() @@ -104,198 +97,243 @@ impl Builder { "The HTTP credentials provider", provider_config.connector(&connector_settings), ); - let mut client_builder = aws_smithy_client::Client::builder() - .connector(connector) - .middleware(Identity::new()); - client_builder.set_sleep_impl(provider_config.sleep()); - let client = client_builder.build(); - HttpCredentialProvider { - uri, - client, - provider_name, + + // The following errors are retryable: + // - Socket errors + // - Networking timeouts + // - 5xx errors + // - Non-parseable 200 responses. + let retry_classifiers = RetryClassifiers::new() + .with_classifier(HttpCredentialRetryClassifier) + // Socket errors and network timeouts + .with_classifier(SmithyErrorClassifier::::new()) + // 5xx errors + .with_classifier(HttpStatusCodeClassifier::default()); + + let mut builder = Operation::builder() + .service_name("HttpCredentialProvider") + .operation_name("LoadCredentials") + .http_connector(SharedHttpConnector::new(DynConnectorAdapter::new( + connector, + ))) + .endpoint_url(endpoint) + .no_auth() + .runtime_plugin(SharedRuntimePlugin::new( + StaticRuntimePlugin::new().with_config({ + let mut layer = Layer::new("SensitiveOutput"); + layer.store_put(SensitiveOutput); + layer.freeze() + }), + )); + if let Some(sleep_impl) = provider_config.sleep() { + builder = builder + .standard_retry(&RetryConfig::standard()) + .retry_classifiers(retry_classifiers) + .sleep_impl(sleep_impl); + } else { + builder = builder.no_retry(); } + let path = path.into(); + let operation = builder + .serializer(move |input: OpInput| { + let mut http_req = http::Request::builder() + .uri(path.clone()) + .header(ACCEPT, "application/json"); + if let Some(auth) = input.auth { + http_req = http_req.header(AUTHORIZATION, auth); + } + Ok(http_req.body(SdkBody::empty()).expect("valid request")) + }) + .deserializer(move |response| parse_response(provider_name, response)) + .build(); + HttpCredentialProvider { operation } } } -#[derive(Clone, Debug)] -struct CredentialsResponseParser { +fn parse_response( provider_name: &'static str, -} -impl ParseStrictResponse for CredentialsResponseParser { - type Output = provider::Result; - - fn parse(&self, response: &Response) -> Self::Output { - if !response.status().is_success() { - return Err(CredentialsError::provider_error(format!( + response: &Response, +) -> Result> { + if !response.status().is_success() { + return Err(OrchestratorError::operation( + CredentialsError::provider_error(format!( "Non-success status from HTTP credential provider: {:?}", response.status() - ))); - } - let str_resp = - std::str::from_utf8(response.body().as_ref()).map_err(CredentialsError::unhandled)?; - let json_creds = parse_json_credentials(str_resp).map_err(CredentialsError::unhandled)?; - match json_creds { - JsonCredentials::RefreshableCredentials(RefreshableCredentials { - access_key_id, - secret_access_key, - session_token, - expiration, - }) => Ok(Credentials::new( - access_key_id, - secret_access_key, - Some(session_token.to_string()), - Some(expiration), - self.provider_name, - )), - JsonCredentials::Error { code, message } => Err(CredentialsError::provider_error( - format!("failed to load credentials [{}]: {}", code, message), )), - } + )); } - - fn sensitive(&self) -> bool { - true + let resp_bytes = response.body().bytes().expect("non-streaming deserializer"); + let str_resp = std::str::from_utf8(resp_bytes) + .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?; + let json_creds = parse_json_credentials(str_resp) + .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?; + match json_creds { + JsonCredentials::RefreshableCredentials(RefreshableCredentials { + access_key_id, + secret_access_key, + session_token, + expiration, + }) => Ok(Credentials::new( + access_key_id, + secret_access_key, + Some(session_token.to_string()), + Some(expiration), + provider_name, + )), + JsonCredentials::Error { code, message } => Err(OrchestratorError::operation( + CredentialsError::provider_error(format!( + "failed to load credentials [{}]: {}", + code, message + )), + )), } } #[derive(Clone, Debug)] struct HttpCredentialRetryClassifier; -impl ClassifyRetry, SdkError> - for HttpCredentialRetryClassifier -{ - fn classify_retry( - &self, - response: Result<&SdkSuccess, &SdkError>, - ) -> RetryKind { - /* The following errors are retryable: - * - Socket errors - * - Networking timeouts - * - 5xx errors - * - Non-parseable 200 responses. - * */ - match response { - Ok(_) => RetryKind::Unnecessary, - // socket errors, networking timeouts - Err(SdkError::DispatchFailure(client_err)) - if client_err.is_timeout() || client_err.is_io() => - { - RetryKind::Error(ErrorKind::TransientError) - } - // non-parseable 200s - Err(SdkError::ServiceError(context)) - if matches!(context.err(), CredentialsError::Unhandled { .. }) - && context.raw().http().status().is_success() => - { - RetryKind::Error(ErrorKind::ServerError) - } - // 5xx errors - Err(SdkError::ResponseError(context)) - if context.raw().http().status().is_server_error() => - { - RetryKind::Error(ErrorKind::ServerError) - } - Err(SdkError::ServiceError(context)) - if context.raw().http().status().is_server_error() => - { - RetryKind::Error(ErrorKind::ServerError) +impl ClassifyRetry for HttpCredentialRetryClassifier { + fn name(&self) -> &'static str { + "HttpCredentialRetryClassifier" + } + + fn classify_retry(&self, ctx: &InterceptorContext) -> Option { + let output_or_error = ctx.output_or_error()?; + let error = match output_or_error { + Ok(_) => return None, + Err(err) => err, + }; + + // Retry non-parseable 200 responses + if let Some((err, status)) = error + .as_operation_error() + .and_then(|err| err.downcast_ref::()) + .zip(ctx.response().map(HttpResponse::status)) + { + if matches!(err, CredentialsError::Unhandled { .. }) && status.is_success() { + return Some(RetryReason::Error(ErrorKind::ServerError)); } - Err(_) => RetryKind::UnretryableFailure, } + + None } } #[cfg(test)] mod test { - use crate::http_credential_provider::{ - CredentialsResponseParser, HttpCredentialRetryClassifier, - }; + use super::*; use aws_credential_types::provider::error::CredentialsError; - use aws_credential_types::Credentials; + use aws_smithy_client::test_connection::TestConnection; use aws_smithy_http::body::SdkBody; - use aws_smithy_http::operation; - use aws_smithy_http::response::ParseStrictResponse; - use aws_smithy_http::result::{SdkError, SdkSuccess}; - use aws_smithy_http::retry::ClassifyRetry; - use aws_smithy_types::retry::{ErrorKind, RetryKind}; - use bytes::Bytes; + use aws_smithy_runtime_api::client::orchestrator::HttpRequest; + use http::{Request, Response, Uri}; + use std::time::SystemTime; - fn sdk_resp( - resp: http::Response<&'static str>, - ) -> Result, SdkError> { - let resp = resp.map(|data| Bytes::from_static(data.as_bytes())); - match (CredentialsResponseParser { - provider_name: "test", - }) - .parse(&resp) - { - Ok(creds) => Ok(SdkSuccess { - raw: operation::Response::new(resp.map(SdkBody::from)), - parsed: creds, - }), - Err(err) => Err(SdkError::service_error( - err, - operation::Response::new(resp.map(SdkBody::from)), - )), - } + async fn provide_creds( + connector: TestConnection, + ) -> Result { + let provider_config = ProviderConfig::default().with_http_connector(connector.clone()); + let provider = HttpCredentialProvider::builder() + .configure(&provider_config) + .build("test", "http://localhost:1234/", "/some-creds"); + provider.credentials(None).await } - #[test] - fn non_parseable_is_retriable() { - let bad_response = http::Response::builder() - .status(200) - .body("notjson") - .unwrap(); - - assert_eq!( - HttpCredentialRetryClassifier.classify_retry(sdk_resp(bad_response).as_ref()), - RetryKind::Error(ErrorKind::ServerError) - ); + fn successful_req_resp() -> (HttpRequest, HttpResponse) { + ( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(200) + .body(SdkBody::from( + r#"{ + "AccessKeyId" : "MUA...", + "SecretAccessKey" : "/7PC5om....", + "Token" : "AQoDY....=", + "Expiration" : "2016-02-25T06:03:31Z" + }"#, + )) + .unwrap(), + ) } - #[test] - fn ok_response_not_retriable() { - let ok_response = http::Response::builder() - .status(200) - .body( - r#" { - "AccessKeyId" : "MUA...", - "SecretAccessKey" : "/7PC5om....", - "Token" : "AQoDY....=", - "Expiration" : "2016-02-25T06:03:31Z" - }"#, - ) - .unwrap(); - let sdk_result = sdk_resp(ok_response); - + #[tokio::test] + async fn successful_response() { + let connector = TestConnection::new(vec![successful_req_resp()]); + let creds = provide_creds(connector.clone()).await.expect("success"); + assert_eq!("MUA...", creds.access_key_id()); + assert_eq!("/7PC5om....", creds.secret_access_key()); + assert_eq!(Some("AQoDY....="), creds.session_token()); assert_eq!( - HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()), - RetryKind::Unnecessary + Some(SystemTime::UNIX_EPOCH + Duration::from_secs(1456380211)), + creds.expiry() ); + connector.assert_requests_match(&[]); + } - assert!(sdk_result.is_ok(), "should be ok: {:?}", sdk_result) + #[tokio::test] + async fn retry_nonparseable_response() { + let connector = TestConnection::new(vec![ + ( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(200) + .body(SdkBody::from(r#"not json"#)) + .unwrap(), + ), + successful_req_resp(), + ]); + let creds = provide_creds(connector.clone()).await.expect("success"); + assert_eq!("MUA...", creds.access_key_id()); + connector.assert_requests_match(&[]); } - #[test] - fn explicit_error_not_retriable() { - let error_response = http::Response::builder() - .status(400) - .body(r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#) - .unwrap(); - let sdk_result = sdk_resp(error_response); - assert_eq!( - HttpCredentialRetryClassifier.classify_retry(sdk_result.as_ref()), - RetryKind::UnretryableFailure - ); - let sdk_error = sdk_result.expect_err("should be error"); + #[tokio::test] + async fn retry_error_code() { + let connector = TestConnection::new(vec![ + ( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(500) + .body(SdkBody::from(r#"it broke"#)) + .unwrap(), + ), + successful_req_resp(), + ]); + let creds = provide_creds(connector.clone()).await.expect("success"); + assert_eq!("MUA...", creds.access_key_id()); + connector.assert_requests_match(&[]); + } + #[tokio::test] + async fn explicit_error_not_retriable() { + let connector = TestConnection::new(vec![( + Request::builder() + .uri(Uri::from_static("http://localhost:1234/some-creds")) + .body(SdkBody::empty()) + .unwrap(), + Response::builder() + .status(400) + .body(SdkBody::from( + r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#, + )) + .unwrap(), + )]); + let err = provide_creds(connector.clone()) + .await + .expect_err("it should fail"); assert!( - matches!( - sdk_error, - SdkError::ServiceError(ref context) if matches!(context.err(), CredentialsError::ProviderError { .. }) - ), - "should be provider error: {}", - sdk_error + matches!(err, CredentialsError::ProviderError { .. }), + "should be CredentialsError::ProviderError: {err}", ); + connector.assert_requests_match(&[]); } } diff --git a/aws/rust-runtime/aws-config/src/imds/client.rs b/aws/rust-runtime/aws-config/src/imds/client.rs index f084bc9a832..b753353f05b 100644 --- a/aws/rust-runtime/aws-config/src/imds/client.rs +++ b/aws/rust-runtime/aws-config/src/imds/client.rs @@ -9,34 +9,47 @@ use crate::connector::expect_connector; use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode}; -use crate::imds::client::token::TokenMiddleware; +use crate::imds::client::token::TokenRuntimePlugin; use crate::provider_config::ProviderConfig; use crate::PKG_VERSION; -use aws_http::user_agent::{ApiMetadata, AwsUserAgent, UserAgentStage}; +use aws_http::user_agent::{ApiMetadata, AwsUserAgent}; +use aws_runtime::user_agent::UserAgentInterceptor; +use aws_sdk_sso::config::interceptors::InterceptorContext; +use aws_sdk_sso::config::{SharedAsyncSleep, SharedInterceptor}; +use aws_smithy_async::time::SharedTimeSource; +use aws_smithy_client::erase::DynConnector; use aws_smithy_client::http_connector::ConnectorSettings; -use aws_smithy_client::{erase::DynConnector, SdkSuccess}; -use aws_smithy_client::{retry, SdkError}; +use aws_smithy_client::SdkError; use aws_smithy_http::body::SdkBody; -use aws_smithy_http::endpoint::apply_endpoint; -use aws_smithy_http::operation; -use aws_smithy_http::operation::{Metadata, Operation}; -use aws_smithy_http::response::ParseStrictResponse; -use aws_smithy_http::retry::ClassifyRetry; -use aws_smithy_http_tower::map_request::{ - AsyncMapRequestLayer, AsyncMapRequestService, MapRequestLayer, MapRequestService, +use aws_smithy_http::result::ConnectorError; +use aws_smithy_runtime::client::connectors::adapter::DynConnectorAdapter; +use aws_smithy_runtime::client::orchestrator::operation::Operation; +use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy; +use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams; +use aws_smithy_runtime_api::client::connectors::SharedHttpConnector; +use aws_smithy_runtime_api::client::endpoint::{ + EndpointResolver, EndpointResolverParams, SharedEndpointResolver, }; -use aws_smithy_types::error::display::DisplayErrorContext; -use aws_smithy_types::retry::{ErrorKind, RetryKind}; +use aws_smithy_runtime_api::client::orchestrator::{ + Future, HttpResponse, OrchestratorError, SensitiveOutput, +}; +use aws_smithy_runtime_api::client::retries::{ + ClassifyRetry, RetryClassifiers, RetryReason, SharedRetryStrategy, +}; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder; +use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin}; +use aws_smithy_types::config_bag::{FrozenLayer, Layer}; +use aws_smithy_types::endpoint::Endpoint; +use aws_smithy_types::retry::{ErrorKind, RetryConfig}; use aws_smithy_types::timeout::TimeoutConfig; use aws_types::os_shim_internal::Env; -use bytes::Bytes; -use http::{Response, Uri}; +use http::Uri; use std::borrow::Cow; -use std::error::Error; +use std::error::Error as _; +use std::fmt; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tokio::sync::OnceCell; pub mod error; mod token; @@ -85,8 +98,7 @@ fn user_agent() -> AwsUserAgent { /// # async fn docs() { /// let client = Client::builder() /// .endpoint(Uri::from_static("http://customidms:456/")) -/// .build() -/// .await; +/// .build(); /// # } /// ``` /// @@ -104,7 +116,7 @@ fn user_agent() -> AwsUserAgent { /// ```no_run /// use aws_config::imds::client::{Client, EndpointMode}; /// # async fn docs() { -/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build().await; +/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build(); /// # } /// ``` /// @@ -123,49 +135,7 @@ fn user_agent() -> AwsUserAgent { /// #[derive(Clone, Debug)] pub struct Client { - inner: Arc, -} - -#[derive(Debug)] -struct ClientInner { - endpoint: Uri, - smithy_client: aws_smithy_client::Client, -} - -/// Client where build is sync, but usage is async -/// -/// Building an imds::Client is actually an async operation, however, for credentials and region -/// providers, we want build to always be a synchronous operation. This allows building to be deferred -/// and cached until request time. -#[derive(Debug)] -pub(super) struct LazyClient { - client: OnceCell>, - builder: Builder, -} - -impl LazyClient { - pub(super) fn from_ready_client(client: Client) -> Self { - Self { - client: OnceCell::from(Ok(client)), - // the builder will never be used in this case - builder: Builder::default(), - } - } - pub(super) async fn client(&self) -> Result<&Client, &BuildError> { - let builder = &self.builder; - self.client - // the clone will only happen once when we actually construct it for the first time, - // after that, we will use the cache. - .get_or_init(|| async { - let client = builder.clone().build().await; - if let Err(err) = &client { - tracing::warn!(err = %DisplayErrorContext(err), "failed to create IMDS client") - } - client - }) - .await - .as_ref() - } + operation: Operation, } impl Client { @@ -187,18 +157,16 @@ impl Client { /// ```no_run /// use aws_config::imds::client::Client; /// # async fn docs() { - /// let client = Client::builder().build().await.expect("valid client"); + /// let client = Client::builder().build(); /// let ami_id = client /// .get("/latest/meta-data/ami-id") /// .await /// .expect("failure communicating with IMDS"); /// # } /// ``` - pub async fn get(&self, path: &str) -> Result { - let operation = self.make_operation(path)?; - self.inner - .smithy_client - .call(operation) + pub async fn get(&self, path: impl Into) -> Result { + self.operation + .invoke(path.into()) .await .map_err(|err| match err { SdkError::ConstructionFailure(_) if err.source().is_some() => { @@ -213,76 +181,112 @@ impl Client { InnerImdsError::InvalidUtf8 => { ImdsError::unexpected("IMDS returned invalid UTF-8") } - InnerImdsError::BadStatus => { - ImdsError::error_response(context.into_raw().into_parts().0) - } + InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()), }, - SdkError::TimeoutError(_) - | SdkError::DispatchFailure(_) - | SdkError::ResponseError(_) => ImdsError::io_error(err), + // If the error source is an ImdsError, then we need to directly return that source. + // That way, the IMDS token provider's errors can become the top-level ImdsError. + // There is a unit test that checks the correct error is being extracted. + err @ SdkError::DispatchFailure(_) => match err.into_source() { + Ok(source) => match source.downcast::() { + Ok(source) => match source.into_source().downcast::() { + Ok(source) => *source, + Err(err) => ImdsError::unexpected(err), + }, + Err(err) => ImdsError::unexpected(err), + }, + Err(err) => ImdsError::unexpected(err), + }, + SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err), _ => ImdsError::unexpected(err), }) } +} - /// Creates a aws_smithy_http Operation to for `path` - /// - Convert the path to a URI - /// - Set the base endpoint on the URI - /// - Add a user agent - fn make_operation( - &self, - path: &str, - ) -> Result, ImdsError> { - let mut base_uri: Uri = path.parse().map_err(|_| { - ImdsError::unexpected("IMDS path was not a valid URI. Hint: does it begin with `/`?") - })?; - apply_endpoint(&mut base_uri, &self.inner.endpoint, None).map_err(ImdsError::unexpected)?; - let request = http::Request::builder() - .uri(base_uri) - .body(SdkBody::empty()) - .expect("valid request"); - let mut request = operation::Request::new(request); - request.properties_mut().insert(user_agent()); - Ok(Operation::new(request, ImdsGetResponseHandler) - .with_metadata(Metadata::new("get", "imds")) - .with_retry_classifier(ImdsResponseRetryClassifier)) +/// New-type around `String` that doesn't emit the string value in the `Debug` impl. +#[derive(Clone)] +pub struct SensitiveString(String); + +impl fmt::Debug for SensitiveString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("SensitiveString") + .field(&"** redacted **") + .finish() } } -/// IMDS Middleware -/// -/// The IMDS middleware includes a token-loader & a UserAgent stage -#[derive(Clone, Debug)] -struct ImdsMiddleware { - token_loader: TokenMiddleware, +impl AsRef for SensitiveString { + fn as_ref(&self) -> &str { + &self.0 + } } -impl tower::Layer for ImdsMiddleware { - type Service = AsyncMapRequestService, TokenMiddleware>; +impl From for SensitiveString { + fn from(value: String) -> Self { + Self(value) + } +} - fn layer(&self, inner: S) -> Self::Service { - AsyncMapRequestLayer::for_mapper(self.token_loader.clone()) - .layer(MapRequestLayer::for_mapper(UserAgentStage::new()).layer(inner)) +impl From for String { + fn from(value: SensitiveString) -> Self { + value.0 } } -#[derive(Copy, Clone)] -struct ImdsGetResponseHandler; +/// Runtime plugin that is used by both the IMDS client and the inner client that resolves +/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as +/// sensitive, configures user agent headers, and sets up retries and timeouts. +#[derive(Debug)] +struct ImdsCommonRuntimePlugin { + config: FrozenLayer, + components: RuntimeComponentsBuilder, +} -impl ParseStrictResponse for ImdsGetResponseHandler { - type Output = Result; +impl ImdsCommonRuntimePlugin { + fn new( + connector: DynConnector, + endpoint_resolver: ImdsEndpointResolver, + retry_config: &RetryConfig, + timeout_config: TimeoutConfig, + time_source: SharedTimeSource, + sleep_impl: Option, + ) -> Self { + let mut layer = Layer::new("ImdsCommonRuntimePlugin"); + layer.store_put(AuthSchemeOptionResolverParams::new(())); + layer.store_put(EndpointResolverParams::new(())); + layer.store_put(SensitiveOutput); + layer.store_put(timeout_config); + layer.store_put(user_agent()); - fn parse(&self, response: &Response) -> Self::Output { - if response.status().is_success() { - std::str::from_utf8(response.body().as_ref()) - .map(|data| data.to_string()) - .map_err(|_| InnerImdsError::InvalidUtf8) - } else { - Err(InnerImdsError::BadStatus) + Self { + config: layer.freeze(), + components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin") + .with_http_connector(Some(SharedHttpConnector::new(DynConnectorAdapter::new( + connector, + )))) + .with_endpoint_resolver(Some(SharedEndpointResolver::new(endpoint_resolver))) + .with_interceptor(SharedInterceptor::new(UserAgentInterceptor::new())) + .with_retry_classifiers(Some( + RetryClassifiers::new().with_classifier(ImdsResponseRetryClassifier), + )) + .with_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new( + retry_config, + )))) + .with_time_source(Some(time_source)) + .with_sleep_impl(sleep_impl), } } +} - fn sensitive(&self) -> bool { - true +impl RuntimePlugin for ImdsCommonRuntimePlugin { + fn config(&self) -> Option { + Some(self.config.clone()) + } + + fn runtime_components( + &self, + _current_components: &RuntimeComponentsBuilder, + ) -> Cow<'_, RuntimeComponentsBuilder> { + Cow::Borrowed(&self.components) } } @@ -415,15 +419,8 @@ impl Builder { self }*/ - pub(super) fn build_lazy(self) -> LazyClient { - LazyClient { - client: OnceCell::new(), - builder: self, - } - } - /// Build an IMDSv2 Client - pub async fn build(self) -> Result { + pub fn build(self) -> Client { let config = self.config.unwrap_or_default(); let timeout_config = TimeoutConfig::builder() .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT)) @@ -437,34 +434,46 @@ impl Builder { let endpoint_source = self .endpoint .unwrap_or_else(|| EndpointSource::Env(config.clone())); - let endpoint = endpoint_source.endpoint(self.mode_override).await?; - let retry_config = retry::Config::default() + let endpoint_resolver = ImdsEndpointResolver { + endpoint_source: Arc::new(endpoint_source), + mode_override: self.mode_override, + }; + let retry_config = RetryConfig::standard() .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS)); - let token_loader = token::TokenMiddleware::new( - connector.clone(), + let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new( + connector, + endpoint_resolver, + &retry_config, + timeout_config, config.time_source(), - endpoint.clone(), - self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL), - retry_config.clone(), - timeout_config.clone(), config.sleep(), - ); - let middleware = ImdsMiddleware { token_loader }; - let mut smithy_builder = aws_smithy_client::Client::builder() - .connector(connector.clone()) - .middleware(middleware) - .retry_config(retry_config) - .operation_timeout_config(timeout_config.into()); - smithy_builder.set_sleep_impl(config.sleep()); - let smithy_client = smithy_builder.build(); - - let client = Client { - inner: Arc::new(ClientInner { - endpoint, - smithy_client, - }), - }; - Ok(client) + )); + let operation = Operation::builder() + .service_name("imds") + .operation_name("get") + .runtime_plugin(common_plugin.clone()) + .runtime_plugin(SharedRuntimePlugin::new(TokenRuntimePlugin::new( + common_plugin, + config.time_source(), + self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL), + ))) + .serializer(|path| { + Ok(http::Request::builder() + .uri(path) + .body(SdkBody::empty()) + .expect("valid request")) + }) + .deserializer(|response| { + if response.status().is_success() { + std::str::from_utf8(response.body().bytes().expect("non-streaming response")) + .map(|data| SensitiveString::from(data.to_string())) + .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8)) + } else { + Err(OrchestratorError::operation(InnerImdsError::BadStatus)) + } + }) + .build(); + Client { operation } } } @@ -531,19 +540,22 @@ impl EndpointSource { } } -#[derive(Clone)] -struct ImdsResponseRetryClassifier; +#[derive(Clone, Debug)] +struct ImdsEndpointResolver { + endpoint_source: Arc, + mode_override: Option, +} -impl ImdsResponseRetryClassifier { - fn classify(response: &operation::Response) -> RetryKind { - let status = response.http().status(); - match status { - _ if status.is_server_error() => RetryKind::Error(ErrorKind::ServerError), - // 401 indicates that the token has expired, this is retryable - _ if status.as_u16() == 401 => RetryKind::Error(ErrorKind::ServerError), - // This catch-all includes successful responses that fail to parse. These should not be retried. - _ => RetryKind::UnretryableFailure, - } +impl EndpointResolver for ImdsEndpointResolver { + fn resolve_endpoint(&self, _: &EndpointResolverParams) -> Future { + let this = self.clone(); + Future::new(Box::pin(async move { + this.endpoint_source + .endpoint(this.mode_override) + .await + .map(|uri| Endpoint::builder().url(uri.to_string()).build()) + .map_err(|err| err.into()) + })) } } @@ -556,13 +568,35 @@ impl ImdsResponseRetryClassifier { /// - 403 (IMDS disabled): **Not Retryable** /// - 404 (Not found): **Not Retryable** /// - >=500 (server error): **Retryable** -impl ClassifyRetry, SdkError> for ImdsResponseRetryClassifier { - fn classify_retry(&self, response: Result<&SdkSuccess, &SdkError>) -> RetryKind { - match response { - Ok(_) => RetryKind::Unnecessary, - Err(SdkError::ResponseError(context)) => Self::classify(context.raw()), - Err(SdkError::ServiceError(context)) => Self::classify(context.raw()), - _ => RetryKind::UnretryableFailure, +#[derive(Clone, Debug)] +struct ImdsResponseRetryClassifier; + +impl ImdsResponseRetryClassifier { + fn classify(response: &HttpResponse) -> Option { + let status = response.status(); + match status { + _ if status.is_server_error() => Some(RetryReason::Error(ErrorKind::ServerError)), + // 401 indicates that the token has expired, this is retryable + _ if status.as_u16() == 401 => Some(RetryReason::Error(ErrorKind::ServerError)), + // This catch-all includes successful responses that fail to parse. These should not be retried. + _ => None, + } + } +} + +impl ClassifyRetry for ImdsResponseRetryClassifier { + fn name(&self) -> &'static str { + "ImdsResponseRetryClassifier" + } + + fn classify_retry(&self, ctx: &InterceptorContext) -> Option { + if let Some(response) = ctx.response() { + Self::classify(response) + } else { + // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default + // credentials provider chain to fail to provide credentials. + // Also don't retry non-responses. + None } } } @@ -571,14 +605,18 @@ impl ClassifyRetry, SdkError> for ImdsResponseRetryClassi pub(crate) mod test { use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier}; use crate::provider_config::ProviderConfig; + use aws_sdk_sso::config::interceptors::InterceptorContext; + use aws_sdk_sso::error::DisplayErrorContext; use aws_smithy_async::rt::sleep::TokioSleep; use aws_smithy_async::test_util::instant_time_and_sleep; use aws_smithy_client::erase::DynConnector; use aws_smithy_client::test_connection::{capture_request, TestConnection}; - use aws_smithy_client::{SdkError, SdkSuccess}; use aws_smithy_http::body::SdkBody; - use aws_smithy_http::operation; - use aws_smithy_types::retry::RetryKind; + use aws_smithy_http::result::ConnectorError; + use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs; + use aws_smithy_runtime_api::client::interceptors::context::{Input, Output}; + use aws_smithy_runtime_api::client::orchestrator::OrchestratorError; + use aws_smithy_runtime_api::client::retries::ClassifyRetry; use aws_types::os_shim_internal::{Env, Fs}; use http::header::USER_AGENT; use http::Uri; @@ -637,7 +675,7 @@ pub(crate) mod test { http::Response::builder().status(200).body(body).unwrap() } - pub(crate) async fn make_client(conn: &TestConnection) -> super::Client + pub(crate) fn make_client(conn: &TestConnection) -> super::Client where SdkBody: From, T: Send + 'static, @@ -650,8 +688,6 @@ pub(crate) mod test { .with_http_connector(DynConnector::new(conn.clone())), ) .build() - .await - .expect("valid client") } #[tokio::test] @@ -670,13 +706,13 @@ pub(crate) mod test { imds_response("output2"), ), ]); - let client = make_client(&connection).await; + let client = make_client(&connection); // load once let metadata = client.get("/latest/metadata").await.expect("failed"); - assert_eq!(metadata, "test-imds-output"); + assert_eq!("test-imds-output", metadata.as_ref()); // load again: the cached token should be used let metadata = client.get("/latest/metadata2").await.expect("failed"); - assert_eq!(metadata, "output2"); + assert_eq!("output2", metadata.as_ref()); connection.assert_requests_match(&[]); } @@ -710,17 +746,15 @@ pub(crate) mod test { ) .endpoint_mode(EndpointMode::IpV6) .token_ttl(Duration::from_secs(600)) - .build() - .await - .expect("valid client"); + .build(); let resp1 = client.get("/latest/metadata").await.expect("success"); // now the cached credential has expired time_source.advance(Duration::from_secs(600)); let resp2 = client.get("/latest/metadata").await.expect("success"); connection.assert_requests_match(&[]); - assert_eq!(resp1, "test-imds-output1"); - assert_eq!(resp2, "test-imds-output2"); + assert_eq!("test-imds-output1", resp1.as_ref()); + assert_eq!("test-imds-output2", resp2.as_ref()); } /// Tokens are refreshed up to 120 seconds early to avoid using an expired token. @@ -761,9 +795,7 @@ pub(crate) mod test { ) .endpoint_mode(EndpointMode::IpV6) .token_ttl(Duration::from_secs(600)) - .build() - .await - .expect("valid client"); + .build(); let resp1 = client.get("/latest/metadata").await.expect("success"); // now the cached credential has expired @@ -772,9 +804,9 @@ pub(crate) mod test { time_source.advance(Duration::from_secs(150)); let resp3 = client.get("/latest/metadata").await.expect("success"); connection.assert_requests_match(&[]); - assert_eq!(resp1, "test-imds-output1"); - assert_eq!(resp2, "test-imds-output2"); - assert_eq!(resp3, "test-imds-output3"); + assert_eq!("test-imds-output1", resp1.as_ref()); + assert_eq!("test-imds-output2", resp2.as_ref()); + assert_eq!("test-imds-output3", resp3.as_ref()); } /// 500 error during the GET should be retried @@ -795,8 +827,15 @@ pub(crate) mod test { imds_response("ok"), ), ]); - let client = make_client(&connection).await; - assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok"); + let client = make_client(&connection); + assert_eq!( + "ok", + client + .get("/latest/metadata") + .await + .expect("success") + .as_ref() + ); connection.assert_requests_match(&[]); // all requests should have a user agent header @@ -823,8 +862,15 @@ pub(crate) mod test { imds_response("ok"), ), ]); - let client = make_client(&connection).await; - assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok"); + let client = make_client(&connection); + assert_eq!( + "ok", + client + .get("/latest/metadata") + .await + .expect("success") + .as_ref() + ); connection.assert_requests_match(&[]); } @@ -850,8 +896,15 @@ pub(crate) mod test { imds_response("ok"), ), ]); - let client = make_client(&connection).await; - assert_eq!(client.get("/latest/metadata").await.expect("success"), "ok"); + let client = make_client(&connection); + assert_eq!( + "ok", + client + .get("/latest/metadata") + .await + .expect("success") + .as_ref() + ); connection.assert_requests_match(&[]); } @@ -863,7 +916,7 @@ pub(crate) mod test { token_request("http://169.254.169.254", 21600), http::Response::builder().status(403).body("").unwrap(), )]); - let client = make_client(&connection).await; + let client = make_client(&connection); let err = client.get("/latest/metadata").await.expect_err("no token"); assert_full_error_contains!(err, "forbidden"); connection.assert_requests_match(&[]); @@ -872,30 +925,18 @@ pub(crate) mod test { /// Successful responses should classify as `RetryKind::Unnecessary` #[test] fn successful_response_properly_classified() { - use aws_smithy_http::retry::ClassifyRetry; - + let mut ctx = InterceptorContext::new(Input::doesnt_matter()); + ctx.set_output_or_error(Ok(Output::doesnt_matter())); + ctx.set_response(imds_response("").map(|_| SdkBody::empty())); let classifier = ImdsResponseRetryClassifier; - fn response_200() -> operation::Response { - operation::Response::new(imds_response("").map(|_| SdkBody::empty())) - } - let success = SdkSuccess { - raw: response_200(), - parsed: (), - }; - assert_eq!( - RetryKind::Unnecessary, - classifier.classify_retry(Ok::<_, &SdkError<()>>(&success)) - ); + assert_eq!(None, classifier.classify_retry(&ctx)); // Emulate a failure to parse the response body (using an io error since it's easy to construct in a test) - let failure = SdkError::<()>::response_error( - io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse"), - response_200(), - ); - assert_eq!( - RetryKind::UnretryableFailure, - classifier.classify_retry(Err::<&SdkSuccess<()>, _>(&failure)) - ); + let mut ctx = InterceptorContext::new(Input::doesnt_matter()); + ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io( + io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(), + )))); + assert_eq!(None, classifier.classify_retry(&ctx)); } // since tokens are sent as headers, the tokens need to be valid header values @@ -905,7 +946,7 @@ pub(crate) mod test { token_request("http://169.254.169.254", 21600), token_response(21600, "replaced").map(|_| vec![1, 0]), )]); - let client = make_client(&connection).await; + let client = make_client(&connection); let err = client.get("/latest/metadata").await.expect_err("no token"); assert_full_error_contains!(err, "invalid token"); connection.assert_requests_match(&[]); @@ -926,7 +967,7 @@ pub(crate) mod test { .unwrap(), ), ]); - let client = make_client(&connection).await; + let client = make_client(&connection); let err = client.get("/latest/metadata").await.expect_err("no token"); assert_full_error_contains!(err, "invalid UTF-8"); connection.assert_requests_match(&[]); @@ -943,14 +984,20 @@ pub(crate) mod test { let client = Client::builder() // 240.* can never be resolved .endpoint(Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let now = SystemTime::now(); let resp = client .get("/latest/metadata") .await .expect_err("240.0.0.0 will never resolve"); + match resp { + err @ ImdsError::FailedToLoadToken(_) + if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok, + other => panic!( + "wrong error, expected construction failure with TimedOutError inside: {}", + DisplayErrorContext(&other) + ), + } let time_elapsed = now.elapsed().unwrap(); assert!( time_elapsed > Duration::from_secs(1), @@ -962,14 +1009,6 @@ pub(crate) mod test { "time_elapsed should be less than 2s but was {:?}", time_elapsed ); - match resp { - err @ ImdsError::FailedToLoadToken(_) - if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok, - other => panic!( - "wrong error, expected construction failure with TimedOutError inside: {}", - other - ), - } } #[derive(Debug, Deserialize)] @@ -983,8 +1022,10 @@ pub(crate) mod test { } #[tokio::test] - async fn config_tests() -> Result<(), Box> { - let test_cases = std::fs::read_to_string("test-data/imds-config/imds-tests.json")?; + async fn endpoint_config_tests() -> Result<(), Box> { + let _logs = capture_test_logs(); + + let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?; #[derive(Deserialize)] struct TestCases { tests: Vec, @@ -1014,24 +1055,22 @@ pub(crate) mod test { imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap()); } - let imds_client = imds_client.build().await; - let (uri, imds_client) = match (&test_case.result, imds_client) { - (Ok(uri), Ok(client)) => (uri, client), - (Err(test), Ok(_client)) => panic!( - "test should fail: {} but a valid client was made. {}", - test, test_case.docs - ), - (Err(substr), Err(err)) => { - assert_full_error_contains!(err, substr); - return; + let imds_client = imds_client.build(); + match &test_case.result { + Ok(uri) => { + // this request will fail, we just want to capture the endpoint configuration + let _ = imds_client.get("/hello").await; + assert_eq!(&watcher.expect_request().uri().to_string(), uri); + } + Err(expected) => { + let err = imds_client.get("/hello").await.expect_err("it should fail"); + let message = format!("{}", DisplayErrorContext(&err)); + assert!( + message.contains(expected), + "{}\nexpected error: {expected}\nactual error: {message}", + test_case.docs + ); } - (Ok(_uri), Err(e)) => panic!( - "a valid client should be made but: {}. {}", - e, test_case.docs - ), }; - // this request will fail, we just want to capture the endpoint configuration - let _ = imds_client.get("/hello").await; - assert_eq!(&watcher.expect_request().uri().to_string(), uri); } } diff --git a/aws/rust-runtime/aws-config/src/imds/client/error.rs b/aws/rust-runtime/aws-config/src/imds/client/error.rs index b9559486a6c..4d32aee0127 100644 --- a/aws/rust-runtime/aws-config/src/imds/client/error.rs +++ b/aws/rust-runtime/aws-config/src/imds/client/error.rs @@ -8,13 +8,14 @@ use aws_smithy_client::SdkError; use aws_smithy_http::body::SdkBody; use aws_smithy_http::endpoint::error::InvalidEndpointError; +use aws_smithy_runtime_api::client::orchestrator::HttpResponse; use std::error::Error; use std::fmt; /// Error context for [`ImdsError::FailedToLoadToken`] #[derive(Debug)] pub struct FailedToLoadToken { - source: SdkError, + source: SdkError, } impl FailedToLoadToken { @@ -23,7 +24,7 @@ impl FailedToLoadToken { matches!(self.source, SdkError::DispatchFailure(_)) } - pub(crate) fn into_source(self) -> SdkError { + pub(crate) fn into_source(self) -> SdkError { self.source } } @@ -76,7 +77,7 @@ pub enum ImdsError { } impl ImdsError { - pub(super) fn failed_to_load_token(source: SdkError) -> Self { + pub(super) fn failed_to_load_token(source: SdkError) -> Self { Self::FailedToLoadToken(FailedToLoadToken { source }) } diff --git a/aws/rust-runtime/aws-config/src/imds/client/token.rs b/aws/rust-runtime/aws-config/src/imds/client/token.rs index 41e96777b45..285e1ab7229 100644 --- a/aws/rust-runtime/aws-config/src/imds/client/token.rs +++ b/aws/rust-runtime/aws-config/src/imds/client/token.rs @@ -15,26 +15,31 @@ //! - Attach the token to the request in the `x-aws-ec2-metadata-token` header use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind}; -use crate::imds::client::ImdsResponseRetryClassifier; use aws_credential_types::cache::ExpiringCache; -use aws_http::user_agent::UserAgentStage; -use aws_smithy_async::rt::sleep::SharedAsyncSleep; +use aws_sdk_sso::config::RuntimeComponents; use aws_smithy_async::time::SharedTimeSource; -use aws_smithy_client::erase::DynConnector; -use aws_smithy_client::retry; use aws_smithy_http::body::SdkBody; -use aws_smithy_http::endpoint::apply_endpoint; -use aws_smithy_http::middleware::AsyncMapRequest; -use aws_smithy_http::operation; -use aws_smithy_http::operation::Operation; -use aws_smithy_http::operation::{Metadata, Request}; -use aws_smithy_http::response::ParseStrictResponse; -use aws_smithy_http_tower::map_request::MapRequestLayer; -use aws_smithy_types::timeout::TimeoutConfig; +use aws_smithy_runtime::client::orchestrator::operation::Operation; +use aws_smithy_runtime_api::box_error::BoxError; +use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver; +use aws_smithy_runtime_api::client::auth::{ + AuthScheme, AuthSchemeEndpointConfig, AuthSchemeId, SharedAuthScheme, + SharedAuthSchemeOptionResolver, Signer, +}; +use aws_smithy_runtime_api::client::identity::{ + Identity, IdentityResolver, SharedIdentityResolver, +}; +use aws_smithy_runtime_api::client::orchestrator::{ + Future, HttpRequest, HttpResponse, OrchestratorError, +}; +use aws_smithy_runtime_api::client::runtime_components::{ + GetIdentityResolver, RuntimeComponentsBuilder, +}; +use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin}; +use aws_smithy_types::config_bag::ConfigBag; use http::{HeaderValue, Uri}; -use std::fmt::{Debug, Formatter}; -use std::future::Future; -use std::pin::Pin; +use std::borrow::Cow; +use std::fmt; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -47,6 +52,7 @@ const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(120); const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl-seconds"; const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token"; +const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN); /// IMDS Token #[derive(Clone)] @@ -54,151 +60,215 @@ struct Token { value: HeaderValue, expiry: SystemTime, } +impl fmt::Debug for Token { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Token") + .field("value", &"** redacted **") + .field("expiry", &self.expiry) + .finish() + } +} -/// Token Middleware -/// -/// Token middleware will load/cache a token when required and handle caching/expiry. +/// Token Runtime Plugin /// -/// It will attach the token to the incoming request on the `x-aws-ec2-metadata-token` header. -#[derive(Clone)] -pub(super) struct TokenMiddleware { - client: Arc>>, - token_parser: GetTokenResponseHandler, - token: ExpiringCache, - time_source: SharedTimeSource, - endpoint: Uri, - token_ttl: Duration, +/// This runtime plugin wires up the necessary componetns to load/cache a token +/// when required and handle caching/expiry. This token will get attached to the +/// request to IMDS on the `x-aws-ec2-metadata-token` header. +#[derive(Debug)] +pub(super) struct TokenRuntimePlugin { + components: RuntimeComponentsBuilder, +} + +impl TokenRuntimePlugin { + pub(super) fn new( + common_plugin: SharedRuntimePlugin, + time_source: SharedTimeSource, + token_ttl: Duration, + ) -> Self { + Self { + components: RuntimeComponentsBuilder::new("TokenRuntimePlugin") + .with_auth_scheme(SharedAuthScheme::new(TokenAuthScheme::new())) + .with_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new( + StaticAuthSchemeOptionResolver::new(vec![IMDS_TOKEN_AUTH_SCHEME]), + ))) + .with_identity_resolver( + IMDS_TOKEN_AUTH_SCHEME, + SharedIdentityResolver::new(TokenResolver::new( + common_plugin, + time_source, + token_ttl, + )), + ), + } + } } -impl Debug for TokenMiddleware { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "ImdsTokenMiddleware") +impl RuntimePlugin for TokenRuntimePlugin { + fn runtime_components( + &self, + _current_components: &RuntimeComponentsBuilder, + ) -> Cow<'_, RuntimeComponentsBuilder> { + Cow::Borrowed(&self.components) } } -impl TokenMiddleware { - pub(super) fn new( - connector: DynConnector, +#[derive(Debug)] +struct TokenResolverInner { + cache: ExpiringCache, + refresh: Operation<(), Token, TokenError>, + time_source: SharedTimeSource, +} + +#[derive(Clone, Debug)] +struct TokenResolver { + inner: Arc, +} + +impl TokenResolver { + fn new( + common_plugin: SharedRuntimePlugin, time_source: SharedTimeSource, - endpoint: Uri, token_ttl: Duration, - retry_config: retry::Config, - timeout_config: TimeoutConfig, - sleep_impl: Option, ) -> Self { - let mut inner_builder = aws_smithy_client::Client::builder() - .connector(connector) - .middleware(MapRequestLayer::::default()) - .retry_config(retry_config) - .operation_timeout_config(timeout_config.into()); - inner_builder.set_sleep_impl(sleep_impl); - let inner_client = inner_builder.build(); - let client = Arc::new(inner_client); Self { - client, - token_parser: GetTokenResponseHandler { - time: time_source.clone(), - }, - token: ExpiringCache::new(TOKEN_REFRESH_BUFFER), - time_source, - endpoint, - token_ttl, + inner: Arc::new(TokenResolverInner { + cache: ExpiringCache::new(TOKEN_REFRESH_BUFFER), + refresh: Operation::builder() + .service_name("imds") + .operation_name("get-token") + .runtime_plugin(common_plugin) + .no_auth() + .serializer(move |_| { + Ok(http::Request::builder() + .method("PUT") + .uri(Uri::from_static("/latest/api/token")) + .header(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, token_ttl.as_secs()) + .body(SdkBody::empty()) + .expect("valid HTTP request")) + }) + .deserializer({ + let time_source = time_source.clone(); + move |response| { + let now = time_source.now(); + parse_token_response(response, now) + .map_err(OrchestratorError::operation) + } + }) + .build(), + time_source, + }), } } - async fn add_token(&self, request: Request) -> Result { - let preloaded_token = self - .token - .yield_or_clear_if_expired(self.time_source.now()) - .await; - let token = match preloaded_token { - Some(token) => Ok(token), - None => { - self.token - .get_or_load(|| async move { self.get_token().await }) - .await - } - }?; - request.augment(|mut request, _| { - request - .headers_mut() - .insert(X_AWS_EC2_METADATA_TOKEN, token.value); - Ok(request) - }) - } async fn get_token(&self) -> Result<(Token, SystemTime), ImdsError> { - let mut uri = Uri::from_static("/latest/api/token"); - apply_endpoint(&mut uri, &self.endpoint, None).map_err(ImdsError::unexpected)?; - let request = http::Request::builder() - .header( - X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, - self.token_ttl.as_secs(), - ) - .uri(uri) - .method("PUT") - .body(SdkBody::empty()) - .expect("valid HTTP request"); - let mut request = operation::Request::new(request); - request.properties_mut().insert(super::user_agent()); - - let operation = Operation::new(request, self.token_parser.clone()) - .with_retry_classifier(ImdsResponseRetryClassifier) - .with_metadata(Metadata::new("get-token", "imds")); - let response = self - .client - .call(operation) + self.inner + .refresh + .invoke(()) .await - .map_err(ImdsError::failed_to_load_token)?; - let expiry = response.expiry; - Ok((response, expiry)) + .map(|token| { + let expiry = token.expiry; + (token, expiry) + }) + .map_err(ImdsError::failed_to_load_token) } } -impl AsyncMapRequest for TokenMiddleware { - type Error = ImdsError; - type Future = Pin> + Send + 'static>>; - - fn name(&self) -> &'static str { - "attach_imds_token" +fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result { + match response.status().as_u16() { + 400 => return Err(TokenErrorKind::InvalidParameters.into()), + 403 => return Err(TokenErrorKind::Forbidden.into()), + _ => {} } + let value = HeaderValue::from_bytes(response.body().bytes().expect("non-streaming response")) + .map_err(|_| TokenErrorKind::InvalidToken)?; + let ttl: u64 = response + .headers() + .get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS) + .ok_or(TokenErrorKind::NoTtl)? + .to_str() + .map_err(|_| TokenErrorKind::InvalidTtl)? + .parse() + .map_err(|_parse_error| TokenErrorKind::InvalidTtl)?; + Ok(Token { + value, + expiry: now + Duration::from_secs(ttl), + }) +} - fn apply(&self, request: Request) -> Self::Future { +impl IdentityResolver for TokenResolver { + fn resolve_identity(&self, _config_bag: &ConfigBag) -> Future { let this = self.clone(); - Box::pin(async move { this.add_token(request).await }) + Future::new(Box::pin(async move { + let preloaded_token = this + .inner + .cache + .yield_or_clear_if_expired(this.inner.time_source.now()) + .await; + let token = match preloaded_token { + Some(token) => Ok(token), + None => { + this.inner + .cache + .get_or_load(|| { + let this = this.clone(); + async move { this.get_token().await } + }) + .await + } + }?; + + let expiry = token.expiry; + Ok(Identity::new(token, Some(expiry))) + })) } } -#[derive(Clone)] -struct GetTokenResponseHandler { - time: SharedTimeSource, +#[derive(Debug)] +struct TokenAuthScheme { + signer: TokenSigner, } -impl ParseStrictResponse for GetTokenResponseHandler { - type Output = Result; - - fn parse(&self, response: &http::Response) -> Self::Output { - match response.status().as_u16() { - 400 => return Err(TokenErrorKind::InvalidParameters.into()), - 403 => return Err(TokenErrorKind::Forbidden.into()), - _ => {} +impl TokenAuthScheme { + fn new() -> Self { + Self { + signer: TokenSigner, } - let value = HeaderValue::from_maybe_shared(response.body().clone()) - .map_err(|_| TokenErrorKind::InvalidToken)?; - let ttl: u64 = response - .headers() - .get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS) - .ok_or(TokenErrorKind::NoTtl)? - .to_str() - .map_err(|_| TokenErrorKind::InvalidTtl)? - .parse() - .map_err(|_parse_error| TokenErrorKind::InvalidTtl)?; - Ok(Token { - value, - expiry: self.time.now() + Duration::from_secs(ttl), - }) } +} + +impl AuthScheme for TokenAuthScheme { + fn scheme_id(&self) -> AuthSchemeId { + IMDS_TOKEN_AUTH_SCHEME + } + + fn identity_resolver( + &self, + identity_resolvers: &dyn GetIdentityResolver, + ) -> Option { + identity_resolvers.identity_resolver(IMDS_TOKEN_AUTH_SCHEME) + } + + fn signer(&self) -> &dyn Signer { + &self.signer + } +} + +#[derive(Debug)] +struct TokenSigner; - fn sensitive(&self) -> bool { - true +impl Signer for TokenSigner { + fn sign_http_request( + &self, + request: &mut HttpRequest, + identity: &Identity, + _auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>, + _runtime_components: &RuntimeComponents, + _config_bag: &ConfigBag, + ) -> Result<(), BoxError> { + let token = identity.data::().expect("correct type"); + request + .headers_mut() + .append(X_AWS_EC2_METADATA_TOKEN, token.value.clone()); + Ok(()) } } diff --git a/aws/rust-runtime/aws-config/src/imds/credentials.rs b/aws/rust-runtime/aws-config/src/imds/credentials.rs index ccc65eaf992..3bde8a4510f 100644 --- a/aws/rust-runtime/aws-config/src/imds/credentials.rs +++ b/aws/rust-runtime/aws-config/src/imds/credentials.rs @@ -9,8 +9,7 @@ //! This credential provider will NOT fallback to IMDSv1. Ensure that IMDSv2 is enabled on your instances. use super::client::error::ImdsError; -use crate::imds; -use crate::imds::client::LazyClient; +use crate::imds::{self, Client}; use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials}; use crate::provider_config::ProviderConfig; use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials}; @@ -50,7 +49,7 @@ impl StdError for ImdsCommunicationError { /// _Note: This credentials provider will NOT fallback to the IMDSv1 flow._ #[derive(Debug)] pub struct ImdsCredentialsProvider { - client: LazyClient, + client: Client, env: Env, profile: Option, time_source: SharedTimeSource, @@ -110,12 +109,7 @@ impl Builder { let env = provider_config.env(); let client = self .imds_override - .map(LazyClient::from_ready_client) - .unwrap_or_else(|| { - imds::Client::builder() - .configure(&provider_config) - .build_lazy() - }); + .unwrap_or_else(|| imds::Client::builder().configure(&provider_config).build()); ImdsCredentialsProvider { client, env, @@ -156,23 +150,14 @@ impl ImdsCredentialsProvider { } } - /// Load an inner IMDS client from the OnceCell - async fn client(&self) -> Result<&imds::Client, CredentialsError> { - self.client.client().await.map_err(|build_error| { - // need to format the build error since we don't own it and it can't be cloned - CredentialsError::invalid_configuration(format!("{}", build_error)) - }) - } - /// Retrieve the instance profile from IMDS async fn get_profile_uncached(&self) -> Result { match self - .client() - .await? + .client .get("/latest/meta-data/iam/security-credentials/") .await { - Ok(profile) => Ok(profile), + Ok(profile) => Ok(profile.as_ref().into()), Err(ImdsError::ErrorResponse(context)) if context.response().status().as_u16() == 404 => { @@ -223,9 +208,11 @@ impl ImdsCredentialsProvider { async fn retrieve_credentials(&self) -> provider::Result { if self.imds_disabled() { - tracing::debug!("IMDS disabled because $AWS_EC2_METADATA_DISABLED was set to `true`"); + tracing::debug!( + "IMDS disabled because AWS_EC2_METADATA_DISABLED env var was set to `true`" + ); return Err(CredentialsError::not_loaded( - "IMDS disabled by $AWS_ECS_METADATA_DISABLED", + "IMDS disabled by AWS_ECS_METADATA_DISABLED env var", )); } tracing::debug!("loading credentials from IMDS"); @@ -235,15 +222,14 @@ impl ImdsCredentialsProvider { }; tracing::debug!(profile = %profile, "loaded profile"); let credentials = self - .client() - .await? - .get(&format!( + .client + .get(format!( "/latest/meta-data/iam/security-credentials/{}", profile )) .await .map_err(CredentialsError::provider_error)?; - match parse_json_credentials(&credentials) { + match parse_json_credentials(credentials.as_ref()) { Ok(JsonCredentials::RefreshableCredentials(RefreshableCredentials { access_key_id, secret_access_key, @@ -296,19 +282,16 @@ impl ImdsCredentialsProvider { #[cfg(test)] mod test { - use std::time::{Duration, UNIX_EPOCH}; - + use super::*; use crate::imds::client::test::{ imds_request, imds_response, make_client, token_request, token_response, }; - use crate::imds::credentials::{ - ImdsCredentialsProvider, WARNING_FOR_EXTENDING_CREDENTIALS_EXPIRY, - }; use crate::provider_config::ProviderConfig; use aws_credential_types::provider::ProvideCredentials; use aws_smithy_async::test_util::instant_time_and_sleep; use aws_smithy_client::erase::DynConnector; use aws_smithy_client::test_connection::TestConnection; + use std::time::{Duration, UNIX_EPOCH}; use tracing_test::traced_test; const TOKEN_A: &str = "token_a"; @@ -338,7 +321,7 @@ mod test { ), ]); let client = ImdsCredentialsProvider::builder() - .imds_client(make_client(&connection).await) + .imds_client(make_client(&connection)) .build(); let creds1 = client.provide_credentials().await.expect("valid creds"); let creds2 = client.provide_credentials().await.expect("valid creds"); @@ -376,9 +359,7 @@ mod test { .with_time_source(time_source); let client = crate::imds::Client::builder() .configure(&provider_config) - .build() - .await - .expect("valid client"); + .build(); let provider = ImdsCredentialsProvider::builder() .configure(&provider_config) .imds_client(client) @@ -422,9 +403,7 @@ mod test { .with_time_source(time_source); let client = crate::imds::Client::builder() .configure(&provider_config) - .build() - .await - .expect("valid client"); + .build(); let provider = ImdsCredentialsProvider::builder() .configure(&provider_config) .imds_client(client) @@ -443,9 +422,7 @@ mod test { let client = crate::imds::Client::builder() // 240.* can never be resolved .endpoint(http::Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let expected = aws_credential_types::Credentials::for_tests(); let provider = ImdsCredentialsProvider::builder() .imds_client(client) @@ -463,18 +440,16 @@ mod test { let client = crate::imds::Client::builder() // 240.* can never be resolved .endpoint(http::Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let provider = ImdsCredentialsProvider::builder() .imds_client(client) // no fallback credentials provided .build(); let actual = provider.provide_credentials().await; - assert!(matches!( - actual, - Err(aws_credential_types::provider::error::CredentialsError::CredentialsNotLoaded(_)) - )); + assert!( + matches!(actual, Err(CredentialsError::CredentialsNotLoaded(_))), + "\nexpected: Err(CredentialsError::CredentialsNotLoaded(_))\nactual: {actual:?}" + ); } #[tokio::test] @@ -484,9 +459,7 @@ mod test { let client = crate::imds::Client::builder() // 240.* can never be resolved .endpoint(http::Uri::from_static("http://240.0.0.0")) - .build() - .await - .expect("valid client"); + .build(); let expected = aws_credential_types::Credentials::for_tests(); let provider = ImdsCredentialsProvider::builder() .imds_client(client) @@ -536,7 +509,7 @@ mod test { ), ]); let provider = ImdsCredentialsProvider::builder() - .imds_client(make_client(&connection).await) + .imds_client(make_client(&connection)) .build(); let creds1 = provider.provide_credentials().await.expect("valid creds"); assert_eq!(creds1.access_key_id(), "ASIARTEST"); diff --git a/aws/rust-runtime/aws-config/src/imds/region.rs b/aws/rust-runtime/aws-config/src/imds/region.rs index bc784f8d4f1..072dc97a873 100644 --- a/aws/rust-runtime/aws-config/src/imds/region.rs +++ b/aws/rust-runtime/aws-config/src/imds/region.rs @@ -8,8 +8,7 @@ //! Load region from IMDS from `/latest/meta-data/placement/region` //! This provider has a 5 second timeout. -use crate::imds; -use crate::imds::client::LazyClient; +use crate::imds::{self, Client}; use crate::meta::region::{future, ProvideRegion}; use crate::provider_config::ProviderConfig; use aws_smithy_types::error::display::DisplayErrorContext; @@ -22,7 +21,7 @@ use tracing::Instrument; /// This provider is included in the default region chain, so it does not need to be used manually. #[derive(Debug)] pub struct ImdsRegionProvider { - client: LazyClient, + client: Client, env: Env, } @@ -49,11 +48,10 @@ impl ImdsRegionProvider { tracing::debug!("not using IMDS to load region, IMDS is disabled"); return None; } - let client = self.client.client().await.ok()?; - match client.get(REGION_PATH).await { + match self.client.get(REGION_PATH).await { Ok(region) => { - tracing::debug!(region = %region, "loaded region from IMDS"); - Some(Region::new(region)) + tracing::debug!(region = %region.as_ref(), "loaded region from IMDS"); + Some(Region::new(String::from(region))) } Err(err) => { tracing::warn!(err = %DisplayErrorContext(&err), "failed to load region from IMDS"); @@ -99,12 +97,7 @@ impl Builder { let provider_config = self.provider_config.unwrap_or_default(); let client = self .imds_client_override - .map(LazyClient::from_ready_client) - .unwrap_or_else(|| { - imds::Client::builder() - .configure(&provider_config) - .build_lazy() - }); + .unwrap_or_else(|| imds::Client::builder().configure(&provider_config).build()); ImdsRegionProvider { client, env: provider_config.env(), diff --git a/aws/rust-runtime/aws-config/test-data/imds-config/imds-tests.json b/aws/rust-runtime/aws-config/test-data/imds-config/imds-endpoint-tests.json similarity index 100% rename from aws/rust-runtime/aws-config/test-data/imds-config/imds-tests.json rename to aws/rust-runtime/aws-config/test-data/imds-config/imds-endpoint-tests.json diff --git a/aws/rust-runtime/aws-runtime/src/user_agent.rs b/aws/rust-runtime/aws-runtime/src/user_agent.rs index 7310820f86c..1b6c3699809 100644 --- a/aws/rust-runtime/aws-runtime/src/user_agent.rs +++ b/aws/rust-runtime/aws-runtime/src/user_agent.rs @@ -82,25 +82,25 @@ impl Interceptor for UserAgentInterceptor { _runtime_components: &RuntimeComponents, cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - let api_metadata = cfg - .load::() - .ok_or(UserAgentInterceptorError::MissingApiMetadata)?; - // Allow for overriding the user agent by an earlier interceptor (so, for example, // tests can use `AwsUserAgent::for_tests()`) by attempting to grab one out of the // config bag before creating one. let ua: Cow<'_, AwsUserAgent> = cfg .load::() .map(Cow::Borrowed) + .map(Result::<_, UserAgentInterceptorError>::Ok) .unwrap_or_else(|| { + let api_metadata = cfg + .load::() + .ok_or(UserAgentInterceptorError::MissingApiMetadata)?; let mut ua = AwsUserAgent::new_from_environment(Env::real(), api_metadata.clone()); let maybe_app_name = cfg.load::(); if let Some(app_name) = maybe_app_name { ua.set_app_name(app_name.clone()); } - Cow::Owned(ua) - }); + Ok(Cow::Owned(ua)) + })?; let headers = context.request_mut().headers_mut(); let (user_agent, x_amz_user_agent) = header_values(&ua)?; @@ -250,4 +250,30 @@ mod tests { "`{error}` should contain message `This is a bug`" ); } + + #[test] + fn test_api_metadata_missing_with_ua_override() { + let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); + let mut context = context(); + + let mut layer = Layer::new("test"); + layer.store_put(AwsUserAgent::for_tests()); + let mut config = ConfigBag::of_layers(vec![layer]); + + let interceptor = UserAgentInterceptor::new(); + let mut ctx = Into::into(&mut context); + + interceptor + .modify_before_signing(&mut ctx, &rc, &mut config) + .expect("it should succeed"); + + let header = expect_header(&context, "user-agent"); + assert_eq!(AwsUserAgent::for_tests().ua_header(), header); + assert!(!header.contains("unused")); + + assert_eq!( + AwsUserAgent::for_tests().aws_ua_header(), + expect_header(&context, "x-amz-user-agent") + ); + } } diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt index b8502f0d350..b06444d15fe 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsPresigningDecorator.kt @@ -26,7 +26,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.docs -import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable @@ -173,7 +172,7 @@ class AwsPresignedFluentBuilderMethod( &self.handle.conf, self.config_override, ) - .with_client_plugin(#{SigV4PresigningRuntimePlugin}::new(presigning_config, #{payload_override})) + .with_client_plugin(#{SharedRuntimePlugin}::new(#{SigV4PresigningRuntimePlugin}::new(presigning_config, #{payload_override}))) #{alternate_presigning_serializer_registration}; let input = self.inner.build().map_err(#{SdkError}::construction_failure)?; @@ -193,6 +192,7 @@ class AwsPresignedFluentBuilderMethod( "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), "SharedInterceptor" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors") .resolve("SharedInterceptor"), + "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(runtimeConfig), "SigV4PresigningRuntimePlugin" to AwsRuntimeType.presigningInterceptor(runtimeConfig) .resolve("SigV4PresigningRuntimePlugin"), "StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"), @@ -224,7 +224,10 @@ class AwsPresignedFluentBuilderMethod( }, "alternate_presigning_serializer_registration" to writable { if (presignableOp.hasModelTransforms()) { - rust(".with_operation_plugin(AlternatePresigningSerializerRuntimePlugin)") + rustTemplate( + ".with_operation_plugin(#{SharedRuntimePlugin}::new(AlternatePresigningSerializerRuntimePlugin))", + "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(codegenContext.runtimeConfig), + ) } }, "payload_override" to writable { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt index 94c37d0d732..135e6814c12 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/HttpChecksumRequiredGenerator.kt @@ -19,6 +19,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.toType import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.inputShape @@ -38,7 +39,7 @@ class HttpChecksumRequiredGenerator( is OperationSection.AdditionalRuntimePlugins -> writable { section.addOperationRuntimePlugin(this) { rustTemplate( - "#{HttpChecksumRequiredRuntimePlugin}::new()", + "#{SharedRuntimePlugin}::new(#{HttpChecksumRequiredRuntimePlugin}::new())", "HttpChecksumRequiredRuntimePlugin" to InlineDependency.forRustFile( RustModule.pubCrate("client_http_checksum_required", parent = ClientRustModule.root), @@ -48,6 +49,7 @@ class HttpChecksumRequiredGenerator( CargoDependency.Http, CargoDependency.Md5, ).toType().resolve("HttpChecksumRequiredRuntimePlugin"), + "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(codegenContext.runtimeConfig), ) } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt index dd1bcfef08f..332d6baf000 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/IdempotencyTokenGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.toType import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.smithy.isOptional import software.amazon.smithy.rust.codegen.core.util.findMemberWithTrait @@ -48,6 +49,7 @@ class IdempotencyTokenGenerator( CargoDependency.smithyRuntimeApi(runtimeConfig), CargoDependency.smithyTypes(runtimeConfig), ).toType().resolve("IdempotencyTokenRuntimePlugin"), + "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(runtimeConfig), ) return when (section) { @@ -58,12 +60,14 @@ class IdempotencyTokenGenerator( // then we'll generate one and set it. rustTemplate( """ - #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { - let input: &mut #{Input} = input.downcast_mut().expect("correct type"); - if input.$memberName.is_none() { - input.$memberName = #{Some}(token_provider.make_idempotency_token()); - } - }) + #{SharedRuntimePlugin}::new( + #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { + let input: &mut #{Input} = input.downcast_mut().expect("correct type"); + if input.$memberName.is_none() { + input.$memberName = #{Some}(token_provider.make_idempotency_token()); + } + }) + ) """, *codegenScope, ) @@ -73,12 +77,14 @@ class IdempotencyTokenGenerator( // and set it. rustTemplate( """ - #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { - let input: &mut #{Input} = input.downcast_mut().expect("correct type"); - if input.$memberName.is_empty() { - input.$memberName = token_provider.make_idempotency_token(); - } - }) + #{SharedRuntimePlugin}::new( + #{IdempotencyTokenRuntimePlugin}::new(|token_provider, input| { + let input: &mut #{Input} = input.downcast_mut().expect("correct type"); + if input.$memberName.is_empty() { + input.$memberName = token_provider.make_idempotency_token(); + } + }) + ) """, *codegenScope, ) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt index a902c1432db..88a1a77c930 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SensitiveOutputDecorator.kt @@ -12,7 +12,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection import software.amazon.smithy.rust.codegen.client.smithy.generators.SensitiveIndex import software.amazon.smithy.rust.codegen.core.rustlang.Writable -import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt index 9a3793b7669..46008a3c26d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/OperationGenerator.kt @@ -141,14 +141,16 @@ open class OperationGenerator( client_config: &crate::config::Config, config_override: #{Option}, ) -> #{RuntimePlugins} { - let mut runtime_plugins = client_runtime_plugins.with_operation_plugin(Self::new()); + let mut runtime_plugins = client_runtime_plugins.with_operation_plugin(#{SharedRuntimePlugin}::new(Self::new())); #{additional_runtime_plugins} if let #{Some}(config_override) = config_override { for plugin in config_override.runtime_plugins.iter().cloned() { runtime_plugins = runtime_plugins.with_operation_plugin(plugin); } runtime_plugins = runtime_plugins.with_operation_plugin( - crate::config::ConfigOverrideRuntimePlugin::new(config_override, client_config.config.clone(), &client_config.runtime_components) + #{SharedRuntimePlugin}::new( + crate::config::ConfigOverrideRuntimePlugin::new(config_override, client_config.config.clone(), &client_config.runtime_components) + ) ); } runtime_plugins @@ -160,6 +162,7 @@ open class OperationGenerator( "OrchestratorError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::error::OrchestratorError"), "RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig), "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), + "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(runtimeConfig), "StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"), "invoke_with_stop_point" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::invoke_with_stop_point"), "additional_runtime_plugins" to writable { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt index b2927dca8e5..b43c3cbf4ac 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/FluentClientGenerator.kt @@ -481,12 +481,14 @@ private fun baseClientRuntimePluginsFn(runtimeConfig: RuntimeConfig): RuntimeTyp ::std::mem::swap(&mut config.runtime_plugins, &mut configured_plugins); let mut plugins = #{RuntimePlugins}::new() .with_client_plugin( - #{StaticRuntimePlugin}::new() - .with_config(config.config.clone()) - .with_runtime_components(config.runtime_components.clone()) + #{SharedRuntimePlugin}::new( + #{StaticRuntimePlugin}::new() + .with_config(config.config.clone()) + .with_runtime_components(config.runtime_components.clone()) + ) ) - .with_client_plugin(crate::config::ServiceRuntimePlugin::new(config)) - .with_client_plugin(#{NoAuthRuntimePlugin}::new()); + .with_client_plugin(#{SharedRuntimePlugin}::new(crate::config::ServiceRuntimePlugin::new(config))) + .with_client_plugin(#{SharedRuntimePlugin}::new(#{NoAuthRuntimePlugin}::new())); for plugin in configured_plugins { plugins = plugins.with_client_plugin(plugin); } @@ -497,6 +499,7 @@ private fun baseClientRuntimePluginsFn(runtimeConfig: RuntimeConfig): RuntimeTyp "RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig), "NoAuthRuntimePlugin" to RuntimeType.smithyRuntime(runtimeConfig) .resolve("client::auth::no_auth::NoAuthRuntimePlugin"), + "SharedRuntimePlugin" to RuntimeType.sharedRuntimePlugin(runtimeConfig), "StaticRuntimePlugin" to RuntimeType.smithyRuntimeApi(runtimeConfig) .resolve("client::runtime_plugin::StaticRuntimePlugin"), ) diff --git a/rust-runtime/aws-smithy-client/src/test_connection.rs b/rust-runtime/aws-smithy-client/src/test_connection.rs index b0600b369af..622f5fedcec 100644 --- a/rust-runtime/aws-smithy-client/src/test_connection.rs +++ b/rust-runtime/aws-smithy-client/src/test_connection.rs @@ -134,7 +134,7 @@ pub struct ValidateRequest { impl ValidateRequest { pub fn assert_matches(&self, ignore_headers: &[HeaderName]) { let (actual, expected) = (&self.actual, &self.expected); - assert_eq!(actual.uri(), expected.uri()); + assert_eq!(expected.uri(), actual.uri()); for (name, value) in expected.headers() { if !ignore_headers.contains(name) { let actual_header = actual diff --git a/rust-runtime/aws-smithy-http/src/result.rs b/rust-runtime/aws-smithy-http/src/result.rs index 8d60c8d2ed7..c86e9c5f88a 100644 --- a/rust-runtime/aws-smithy-http/src/result.rs +++ b/rust-runtime/aws-smithy-http/src/result.rs @@ -651,6 +651,11 @@ impl ConnectorError { } } + /// Grants ownership of this error's source. + pub fn into_source(self) -> BoxError { + self.source + } + /// Returns metadata about the connection /// /// If a connection was established and provided by the internal connector, a connection will diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs index 8de224cc80b..a034515562f 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/orchestrator.rs @@ -25,7 +25,8 @@ use aws_smithy_http::body::SdkBody; use aws_smithy_http::result::{ConnectorError, SdkError}; use aws_smithy_types::config_bag::{Storable, StoreReplace}; use bytes::Bytes; -use std::fmt::Debug; +use std::error::Error as StdError; +use std::fmt; use std::future::Future as StdFuture; use std::pin::Pin; @@ -244,6 +245,35 @@ impl OrchestratorError { } } +impl StdError for OrchestratorError +where + E: StdError + 'static, +{ + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(match &self.kind { + ErrorKind::Connector { source } => source as _, + ErrorKind::Operation { err } => err as _, + ErrorKind::Interceptor { source } => source as _, + ErrorKind::Response { source } => source.as_ref(), + ErrorKind::Timeout { source } => source.as_ref(), + ErrorKind::Other { source } => source.as_ref(), + }) + } +} + +impl fmt::Display for OrchestratorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self.kind { + ErrorKind::Connector { .. } => "connector error", + ErrorKind::Operation { .. } => "operation error", + ErrorKind::Interceptor { .. } => "interceptor error", + ErrorKind::Response { .. } => "response error", + ErrorKind::Timeout { .. } => "timeout", + ErrorKind::Other { .. } => "an unknown error occurred", + }) + } +} + fn convert_dispatch_error( err: BoxError, response: Option, @@ -262,7 +292,7 @@ fn convert_dispatch_error( impl From for OrchestratorError where - E: Debug + std::error::Error + 'static, + E: fmt::Debug + std::error::Error + 'static, { fn from(err: InterceptorError) -> Self { Self::interceptor(err) diff --git a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs index 9fec5519a73..111a5670b51 100644 --- a/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs +++ b/rust-runtime/aws-smithy-runtime-api/src/client/runtime_plugin.rs @@ -185,12 +185,12 @@ impl RuntimePlugin for StaticRuntimePlugin { self.runtime_components .as_ref() .map(Cow::Borrowed) - .unwrap_or_else(|| RuntimePlugin::runtime_components(self, _current_components)) + .unwrap_or_else(|| Cow::Borrowed(&EMPTY_RUNTIME_COMPONENTS_BUILDER)) } } macro_rules! insert_plugin { - ($vec:expr, $plugin:ident, $create_rp:expr) => {{ + ($vec:expr, $plugin:ident) => {{ // Insert the plugin in the correct order let mut insert_index = 0; let order = $plugin.order(); @@ -202,7 +202,7 @@ macro_rules! insert_plugin { break; } } - $vec.insert(insert_index, $create_rp); + $vec.insert(insert_index, $plugin); }}; } @@ -235,21 +235,13 @@ impl RuntimePlugins { Default::default() } - pub fn with_client_plugin(mut self, plugin: impl RuntimePlugin + 'static) -> Self { - insert_plugin!( - self.client_plugins, - plugin, - SharedRuntimePlugin::new(plugin) - ); + pub fn with_client_plugin(mut self, plugin: SharedRuntimePlugin) -> Self { + insert_plugin!(self.client_plugins, plugin); self } - pub fn with_operation_plugin(mut self, plugin: impl RuntimePlugin + 'static) -> Self { - insert_plugin!( - self.operation_plugins, - plugin, - SharedRuntimePlugin::new(plugin) - ); + pub fn with_operation_plugin(mut self, plugin: SharedRuntimePlugin) -> Self { + insert_plugin!(self.operation_plugins, plugin); self } @@ -274,7 +266,7 @@ mod tests { use crate::client::connectors::{HttpConnector, HttpConnectorFuture, SharedHttpConnector}; use crate::client::orchestrator::HttpRequest; use crate::client::runtime_components::RuntimeComponentsBuilder; - use crate::client::runtime_plugin::Order; + use crate::client::runtime_plugin::{Order, SharedRuntimePlugin}; use aws_smithy_http::body::SdkBody; use aws_smithy_types::config_bag::ConfigBag; use http::HeaderValue; @@ -287,7 +279,7 @@ mod tests { #[test] fn can_add_runtime_plugin_implementors_to_runtime_plugins() { - RuntimePlugins::new().with_client_plugin(SomeStruct); + RuntimePlugins::new().with_client_plugin(SharedRuntimePlugin::new(SomeStruct)); } #[test] @@ -307,7 +299,7 @@ mod tests { } fn insert_plugin(vec: &mut Vec, plugin: RP) { - insert_plugin!(vec, plugin, plugin); + insert_plugin!(vec, plugin); } let mut vec = Vec::new(); @@ -425,8 +417,8 @@ mod tests { // Emulate assembling a full runtime plugins list and using it to apply configuration let plugins = RuntimePlugins::new() // intentionally configure the plugins in the reverse order - .with_client_plugin(Plugin2) - .with_client_plugin(Plugin1); + .with_client_plugin(SharedRuntimePlugin::new(Plugin2)) + .with_client_plugin(SharedRuntimePlugin::new(Plugin1)); let mut cfg = ConfigBag::base(); let components = plugins.apply_client_configuration(&mut cfg).unwrap(); diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs index 85fdd345647..48703d0b6d8 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator.rs @@ -459,6 +459,7 @@ mod tests { use aws_smithy_runtime_api::client::orchestrator::HttpRequest; use aws_smithy_runtime_api::client::retries::SharedRetryStrategy; use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder; + use aws_smithy_runtime_api::client::runtime_plugin::SharedRuntimePlugin; use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, RuntimePlugins}; use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer}; use std::borrow::Cow; @@ -629,10 +630,10 @@ mod tests { let input = Input::doesnt_matter(); let runtime_plugins = RuntimePlugins::new() - .with_client_plugin(FailingInterceptorsClientRuntimePlugin::new()) - .with_operation_plugin(TestOperationRuntimePlugin::new()) - .with_operation_plugin(NoAuthRuntimePlugin::new()) - .with_operation_plugin(FailingInterceptorsOperationRuntimePlugin::new()); + .with_client_plugin(SharedRuntimePlugin::new(FailingInterceptorsClientRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(TestOperationRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(NoAuthRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(FailingInterceptorsOperationRuntimePlugin::new())); let actual = invoke("test", "test", input, &runtime_plugins) .await .expect_err("should error"); @@ -913,9 +914,9 @@ mod tests { let input = Input::doesnt_matter(); let runtime_plugins = RuntimePlugins::new() - .with_operation_plugin(TestOperationRuntimePlugin::new()) - .with_operation_plugin(NoAuthRuntimePlugin::new()) - .with_operation_plugin(InterceptorsTestOperationRuntimePlugin::new()); + .with_operation_plugin(SharedRuntimePlugin::new(TestOperationRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(NoAuthRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(InterceptorsTestOperationRuntimePlugin::new())); let actual = invoke("test", "test", input, &runtime_plugins) .await .expect_err("should error"); @@ -1155,8 +1156,8 @@ mod tests { async fn test_stop_points() { let runtime_plugins = || { RuntimePlugins::new() - .with_operation_plugin(TestOperationRuntimePlugin::new()) - .with_operation_plugin(NoAuthRuntimePlugin::new()) + .with_operation_plugin(SharedRuntimePlugin::new(TestOperationRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(NoAuthRuntimePlugin::new())) }; // StopPoint::None should result in a response getting set since orchestration doesn't stop @@ -1256,12 +1257,12 @@ mod tests { let interceptor = TestInterceptor::default(); let runtime_plugins = || { RuntimePlugins::new() - .with_operation_plugin(TestOperationRuntimePlugin::new()) - .with_operation_plugin(NoAuthRuntimePlugin::new()) - .with_operation_plugin(TestInterceptorRuntimePlugin { + .with_operation_plugin(SharedRuntimePlugin::new(TestOperationRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(NoAuthRuntimePlugin::new())) + .with_operation_plugin(SharedRuntimePlugin::new(TestInterceptorRuntimePlugin { builder: RuntimeComponentsBuilder::new("test") .with_interceptor(SharedInterceptor::new(interceptor.clone())), - }) + })) }; // StopPoint::BeforeTransmit will exit right before sending the request, so there should be no response diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs index 5257fa66675..ffb3ac92123 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/endpoints.rs @@ -24,21 +24,22 @@ use tracing::trace; /// An endpoint resolver that uses a static URI. #[derive(Clone, Debug)] pub struct StaticUriEndpointResolver { - endpoint: Uri, + endpoint: String, } impl StaticUriEndpointResolver { /// Create a resolver that resolves to `http://localhost:{port}`. pub fn http_localhost(port: u16) -> Self { Self { - endpoint: Uri::from_str(&format!("http://localhost:{port}")) - .expect("all u16 values are valid ports"), + endpoint: format!("http://localhost:{port}"), } } /// Create a resolver that resolves to the given URI. - pub fn uri(endpoint: Uri) -> Self { - Self { endpoint } + pub fn uri(endpoint: impl Into) -> Self { + Self { + endpoint: endpoint.into(), + } } } diff --git a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs index eb619328704..50740ae85ca 100644 --- a/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs +++ b/rust-runtime/aws-smithy-runtime/src/client/orchestrator/operation.rs @@ -32,7 +32,6 @@ use aws_smithy_runtime_api::client::ser_de::{ }; use aws_smithy_types::config_bag::{ConfigBag, Layer}; use aws_smithy_types::retry::RetryConfig; -use http::Uri; use std::borrow::Cow; use std::fmt; use std::marker::PhantomData; @@ -100,7 +99,7 @@ impl fmt::Debug for FnDeserializer { /// Orchestrates execution of a HTTP request without any modeled input or output. #[doc(hidden)] -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Operation { service_name: Cow<'static, str>, operation_name: Cow<'static, str>, @@ -108,6 +107,18 @@ pub struct Operation { _phantom: PhantomData<(I, O, E)>, } +// Manual Clone implementation needed to get rid of Clone bounds on I, O, and E +impl Clone for Operation { + fn clone(&self) -> Self { + Self { + service_name: self.service_name.clone(), + operation_name: self.operation_name.clone(), + runtime_plugins: self.runtime_plugins.clone(), + _phantom: self._phantom, + } + } +} + impl Operation<(), (), ()> { pub fn builder() -> OperationBuilder { OperationBuilder::new() @@ -186,7 +197,7 @@ impl OperationBuilder { self.config.store_put(EndpointResolverParams::new(())); self.runtime_components .set_endpoint_resolver(Some(SharedEndpointResolver::new( - StaticUriEndpointResolver::uri(Uri::try_from(url).expect("valid URI")), + StaticUriEndpointResolver::uri(url), ))); self } @@ -294,35 +305,49 @@ impl OperationBuilder { pub fn build(self) -> Operation { let service_name = self.service_name.expect("service_name required"); let operation_name = self.operation_name.expect("operation_name required"); - assert!( - self.runtime_components.http_connector().is_some(), - "a http_connector is required" - ); - assert!( - self.runtime_components.endpoint_resolver().is_some(), - "a endpoint_resolver is required" - ); - assert!( - self.runtime_components.retry_strategy().is_some(), - "a retry_strategy is required" - ); - assert!( - self.config.load::().is_some(), - "a serializer is required" - ); - assert!( - self.config.load::().is_some(), - "a deserializer is required" - ); - let mut runtime_plugins = RuntimePlugins::new().with_client_plugin( - StaticRuntimePlugin::new() - .with_config(self.config.freeze()) - .with_runtime_components(self.runtime_components), - ); + let mut runtime_plugins = + RuntimePlugins::new().with_client_plugin(SharedRuntimePlugin::new( + StaticRuntimePlugin::new() + .with_config(self.config.freeze()) + .with_runtime_components(self.runtime_components), + )); for runtime_plugin in self.runtime_plugins { runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin); } + #[cfg(debug_assertions)] + { + let mut config = ConfigBag::base(); + let components = runtime_plugins + .apply_client_configuration(&mut config) + .expect("the runtime plugins should succeed"); + + assert!( + components.http_connector().is_some(), + "a http_connector is required" + ); + assert!( + components.endpoint_resolver().is_some(), + "a endpoint_resolver is required" + ); + assert!( + components.retry_strategy().is_some(), + "a retry_strategy is required" + ); + assert!( + config.load::().is_some(), + "a serializer is required" + ); + assert!( + config.load::().is_some(), + "a deserializer is required" + ); + assert!( + config.load::().is_some(), + "endpoint resolver params are required" + ); + } + Operation { service_name, operation_name,