Skip to content

Commit

Permalink
Refactor endpoints to be Smithy-native
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Aug 17, 2022
1 parent fd98756 commit 9a5fa97
Show file tree
Hide file tree
Showing 9 changed files with 520 additions and 83 deletions.
1 change: 1 addition & 0 deletions aws/rust-runtime/aws-endpoint/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ repository = "https://github.com/awslabs/smithy-rs"

[dependencies]
aws-smithy-http = { path = "../../../rust-runtime/aws-smithy-http" }
aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types"}
aws-types = { path = "../aws-types" }
http = "0.2.3"
regex = { version = "1.5.5", default-features = false, features = ["std"] }
Expand Down
209 changes: 179 additions & 30 deletions aws/rust-runtime/aws-endpoint/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,79 @@ pub mod partition;
pub use partition::Partition;
#[doc(hidden)]
pub use partition::PartitionResolver;
use std::collections::HashMap;

use aws_smithy_http::endpoint::EndpointPrefix;
use aws_smithy_http::endpoint::Error as EndpointError;
use aws_smithy_http::endpoint::{apply_endpoint, EndpointPrefix, ResolveEndpoint};
use aws_smithy_http::middleware::MapRequest;
use aws_smithy_http::operation::Request;
use aws_smithy_http::property_bag::PropertyBag;
use aws_smithy_types::endpoint::Endpoint as SmithyEndpoint;
use aws_smithy_types::Document;
use aws_types::region::{Region, SigningRegion};
use aws_types::SigningService;
use http::header::HeaderName;
use http::{HeaderValue, Uri};
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::str::FromStr;
use std::sync::Arc;

pub use aws_types::endpoint::{AwsEndpoint, BoxError, CredentialScope, ResolveAwsEndpoint};

type AwsEndpointResolver = Arc<dyn ResolveAwsEndpoint>;
pub fn get_endpoint_resolver(properties: &PropertyBag) -> Option<&AwsEndpointResolver> {
properties.get()
#[doc(hidden)]
pub struct Params {
region: Option<Region>,
}

pub fn set_endpoint_resolver(properties: &mut PropertyBag, provider: AwsEndpointResolver) {
properties.insert(provider);
impl Params {
pub fn new(region: Option<Region>) -> Self {
Self { region }
}
}

#[doc(hidden)]
pub struct EndpointShim(Arc<dyn ResolveAwsEndpoint>);
impl EndpointShim {
pub fn from_resolver(resolver: impl ResolveAwsEndpoint + 'static) -> Self {
Self(Arc::new(resolver))
}

pub fn from_arc(arc: Arc<dyn ResolveAwsEndpoint>) -> Self {
Self(arc)
}
}
impl ResolveEndpoint<Params> for EndpointShim {
fn resolve_endpoint(
&self,
params: &Params,
) -> Result<SmithyEndpoint, aws_smithy_http::endpoint::Error> {
let aws_endpoint = self
.0
.resolve_endpoint(params.region.as_ref().unwrap_or(
/*EndpointError::message("no region in params")*/
&Region::from_static("us-east-1"),
))
.map_err(|err| EndpointError::message("failure resolving endpoint").with_cause(err))?;
let uri = aws_endpoint.endpoint().uri();
let mut auth_scheme = HashMap::from([("name".to_string(), Document::String("v4".into()))]);
if let Some(region) = aws_endpoint.credential_scope().region() {
auth_scheme.insert(
"signingScope".to_string(),
region.as_ref().to_string().into(),
);
}
if let Some(service) = aws_endpoint.credential_scope().service() {
auth_scheme.insert(
"signingName".to_string(),
service.as_ref().to_string().into(),
);
}
Ok(SmithyEndpoint::builder()
.url(uri.to_string())
.property("authSchemes", vec![Document::Object(auth_scheme)])
.build())
}
}

/// Middleware Stage to Add an Endpoint to a Request
Expand All @@ -56,37 +108,95 @@ impl Display for AwsEndpointStageError {
Debug::fmt(self, f)
}
}

impl Error for AwsEndpointStageError {}

impl MapRequest for AwsEndpointStage {
type Error = AwsEndpointStageError;

fn apply(&self, request: Request) -> Result<Request, Self::Error> {
request.augment(|mut http_req, props| {
let provider =
get_endpoint_resolver(props).ok_or(AwsEndpointStageError::NoEndpointResolver)?;
let region = props
.get::<Region>()
.ok_or(AwsEndpointStageError::NoRegion)?;
let endpoint = provider
.resolve_endpoint(region)
.map_err(AwsEndpointStageError::EndpointResolutionError)?;
tracing::debug!(endpoint = ?endpoint, base_region = ?region, "resolved endpoint");
let signing_region = endpoint
.credential_scope()
.region()
.cloned()
.unwrap_or_else(|| region.clone().into());
props.insert::<SigningRegion>(signing_region);
if let Some(signing_service) = endpoint.credential_scope().service() {
props.insert::<SigningService>(signing_service.clone());
/*
We want to return an endpoint resolution error if there was one. To do that, we
need to take ownership. But also, for retries, we need to be careful to ensure that the
property bag doesn't get broken across subsequent requests, so we need to remove it,
potentially return a (non retryable) endpoint resolution error
*/
let endpoint_result = props
.remove::<aws_smithy_http::endpoint::Result>()
.ok_or(AwsEndpointStageError::NoEndpointResolver)?;
let endpoint = endpoint_result.map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?;
props.insert::<aws_smithy_http::endpoint::Result>(Ok(endpoint));
// unwrap safety: we just put it in the bag, we can take it out again.
let endpoint = props
.get::<aws_smithy_http::endpoint::Result>()
.map(|res|res.as_ref().unwrap()).unwrap();
let (uri, signing_scope_override, signing_service_override) = smithy_to_aws(&endpoint)
.map_err(|err| AwsEndpointStageError::EndpointResolutionError(err))?;
tracing::debug!(endpoint = ?endpoint, base_region = ?signing_scope_override, "resolved endpoint");
apply_endpoint(http_req.uri_mut(), &uri, props.get::<EndpointPrefix>());
for (header_name, header_values) in endpoint.headers() {
http_req.headers_mut().remove(header_name);
for value in header_values {
http_req.headers_mut().insert(
HeaderName::from_str(header_name)
.map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?,
HeaderValue::from_str(value)
.map_err(|err|AwsEndpointStageError::EndpointResolutionError(err.into()))?,
);
}
}

if let Some(signing_scope) = signing_scope_override {
props.insert(signing_scope);
}
if let Some(signing_service) = signing_service_override {
props.insert(signing_service);
}
endpoint.set_endpoint(http_req.uri_mut(), props.get::<EndpointPrefix>());
Ok(http_req)
})
}
}

fn smithy_to_aws(
value: &SmithyEndpoint,
) -> Result<(Uri, Option<SigningRegion>, Option<SigningService>), Box<dyn Error + Send + Sync>> {
let uri: Uri = value.url().parse()?;
// look for v4 as an auth scheme
let auth_schemes = match value
.properties()
.get("authSchemes")
.ok_or("no auth schemes in metadata")?
{
Document::Array(schemes) => schemes,
_other => return Err("expected an array for authSchemes".into()),
};
let v4 = auth_schemes
.iter()
.flat_map(|doc| match doc {
Document::Object(map)
if map.get("name") == Some(&Document::String("v4".to_string())) =>
{
Some(map)
}
_ => None,
})
.next()
.ok_or("could not find v4 as an acceptable auth scheme")?;

let signing_scope = match v4.get("signingScope") {
Some(Document::String(s)) => Some(SigningRegion::from(Region::new(s.clone()))),
None => None,
_ => return Err("unexpected type".into()),
};
let signing_service = match v4.get("signingName") {
Some(Document::String(s)) => Some(SigningService::from(s.to_string())),
None => None,
_ => return Err("unexpected type".into()),
};
Ok((uri, signing_scope, signing_service))
}

#[cfg(test)]
mod test {
use std::sync::Arc;
Expand All @@ -95,13 +205,15 @@ mod test {
use http::Uri;

use aws_smithy_http::body::SdkBody;
use aws_smithy_http::endpoint::ResolveEndpoint;
use aws_smithy_http::middleware::MapRequest;
use aws_smithy_http::operation;
use aws_types::endpoint::CredentialScope;
use aws_types::region::{Region, SigningRegion};
use aws_types::SigningService;

use crate::partition::endpoint::{Metadata, Protocol, SignatureVersion};
use crate::{set_endpoint_resolver, AwsEndpointStage, CredentialScope};
use crate::{AwsEndpointStage, EndpointShim, Params};

#[test]
fn default_endpoint_updates_request() {
Expand All @@ -118,7 +230,10 @@ mod test {
let mut props = req.properties_mut();
props.insert(region.clone());
props.insert(SigningService::from_static("kinesis"));
set_endpoint_resolver(&mut props, provider);
props.insert(
EndpointShim::from_arc(provider)
.resolve_endpoint(&Params::new(Some(region.clone()))),
);
};
let req = AwsEndpointStage.apply(req).expect("should succeed");
assert_eq!(req.properties().get(), Some(&SigningRegion::from(region)));
Expand Down Expand Up @@ -151,9 +266,12 @@ mod test {
let mut req = operation::Request::new(req);
{
let mut props = req.properties_mut();
props.insert(region);
props.insert(SigningService::from_static("kinesis"));
set_endpoint_resolver(&mut props, provider);
props.insert(region.clone());
props.insert(SigningService::from_static("qldb"));
props.insert(
EndpointShim::from_arc(provider)
.resolve_endpoint(&Params::new(Some(region.clone()))),
);
};
let req = AwsEndpointStage.apply(req).expect("should succeed");
assert_eq!(
Expand All @@ -165,4 +283,35 @@ mod test {
Some(&SigningService::from_static("qldb-override"))
);
}

#[test]
fn supports_fallback_when_scope_is_unset() {
let provider = Arc::new(Metadata {
uri_template: "www.service.com",
protocol: Protocol::Http,
credential_scope: CredentialScope::builder().build(),
signature_versions: SignatureVersion::V4,
});
let req = http::Request::new(SdkBody::from(""));
let region = Region::new("us-east-1");
let mut req = operation::Request::new(req);
{
let mut props = req.properties_mut();
props.insert(region.clone());
props.insert(SigningService::from_static("qldb"));
props.insert(
EndpointShim::from_arc(provider)
.resolve_endpoint(&Params::new(Some(region.clone()))),
);
};
let req = AwsEndpointStage.apply(req).expect("should succeed");
assert_eq!(
req.properties().get(),
Some(&SigningRegion::from(Region::new("us-east-1")))
);
assert_eq!(
req.properties().get(),
Some(&SigningService::from_static("qldb"))
);
}
}
Loading

0 comments on commit 9a5fa97

Please sign in to comment.