Skip to content

Commit

Permalink
Use a single struct to represent the RequestExtension (#978)
Browse files Browse the repository at this point in the history
  • Loading branch information
crisidev authored Dec 15, 2021
1 parent d5750d4 commit 64ccdcc
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,9 @@ class ServerProtocolTestGenerator(
rustWriter.rust(
"""
let extensions = http_request.extensions().expect("unable to extract http request extensions");
let namespace = extensions.get::<aws_smithy_http_server::ExtensionNamespace>().expect("extension ExtensionNamespace not found");
assert_eq!(**namespace, ${operationShape.id.getNamespace().dq()});
let operation_name = extensions.get::<aws_smithy_http_server::ExtensionOperationName>().expect("extension ExtensionOperationName not found");
assert_eq!(**operation_name, ${operationSymbol.name.dq()});
let request_extensions = extensions.get::<aws_smithy_http_server::RequestExtensions>().expect("extension RequestExtensions not found");
assert_eq!(request_extensions.namespace, ${operationShape.id.getNamespace().dq()});
assert_eq!(request_extensions.operation_name, ${operationSymbol.name.dq()});
""".trimIndent()
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,7 @@ private class ServerHttpProtocolImplGenerator(
val operationName = symbolProvider.toSymbol(operationShape).name
return """
let extensions = req.extensions_mut().ok_or(#{SmithyHttpServer}::rejection::ExtensionsAlreadyExtracted)?;
extensions.insert(#{SmithyHttpServer}::ExtensionNamespace::new(${namespace.dq()}));
extensions.insert(#{SmithyHttpServer}::ExtensionOperationName::new(${operationName.dq()}));
extensions.insert(#{SmithyHttpServer}::RequestExtensions::new(${namespace.dq()}, ${operationName.dq()}));
""".trimIndent()
}

Expand Down Expand Up @@ -640,11 +639,13 @@ private class ServerHttpProtocolImplGenerator(
} else if (targetMapValue.isSetShape) {
QueryParamsTargetMapValueType.SET
} else {
throw ExpectationNotMetException("""
throw ExpectationNotMetException(
"""
@httpQueryParams trait applied to non-supported target
$targetMapValue of type ${targetMapValue.type}
""".trimIndent(),
targetMapValue.sourceLocation)
targetMapValue.sourceLocation
)
}

private fun serverRenderQueryStringParser(writer: RustWriter, operationShape: OperationShape) {
Expand All @@ -661,23 +662,25 @@ private class ServerHttpProtocolImplGenerator(
return
}

fun HttpBindingDescriptor.queryParamsBindingTargetMapValueType(): QueryParamsTargetMapValueType {
fun HttpBindingDescriptor.queryParamsBindingTargetMapValueType(): QueryParamsTargetMapValueType {
check(this.location == HttpLocation.QUERY_PARAMS)
val queryParamsTarget = model.expectShape(this.member.target)
val mapTarget = queryParamsTarget.asMapShape().get()
return queryParamsTargetMapValueType(model.expectShape(mapTarget.value.target))
}

with(writer) {
rustTemplate("""
rustTemplate(
"""
let query_string = request.uri().query().ok_or(#{SmithyHttpServer}::rejection::MissingQueryString)?;
let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(&str, &str)>>(query_string)?;
""".trimIndent(),
*codegenScope
)

if (queryParamsBinding != null) {
rustTemplate("let mut query_params: #{HashMap}<String, " +
rustTemplate(
"let mut query_params: #{HashMap}<String, " +
"${queryParamsBinding.queryParamsBindingTargetMapValueType().asRustType().render()}> = #{HashMap}::new();",
"HashMap" to RustType.HashMap.RuntimeType,
)
Expand All @@ -694,15 +697,17 @@ private class ServerHttpProtocolImplGenerator(
rustBlock("for (k, v) in pairs") {
queryBindingsTargettingSimple.forEach {
val deserializer = generateParsePercentEncodedStrFn(it)
rustTemplate("""
rustTemplate(
"""
if !seen_${it.memberName.toSnakeCase()} && k == "${it.locationName}" {
input = input.${it.member.setterName()}(
#{deserializer}(v)?
);
seen_${it.memberName.toSnakeCase()} = true;
}
""".trimIndent(),
"deserializer" to deserializer)
""".trimIndent(),
"deserializer" to deserializer
)
}
queryBindingsTargettingCollection.forEach {
rustBlock("if k == ${it.locationName.dq()}") {
Expand All @@ -714,9 +719,12 @@ private class ServerHttpProtocolImplGenerator(
// `<_>::from()` is necessary to convert the `&str` into:
// * the Rust enum in case the `string` shape has the `enum` trait; or
// * `String` in case it doesn't.
rustTemplate("""
rustTemplate(
"""
let v = <_>::from(#{PercentEncoding}::percent_decode_str(v).decode_utf8()?.as_ref());
""".trimIndent(), *codegenScope)
""".trimIndent(),
*codegenScope
)
}
memberShape.isTimestampShape -> {
val index = HttpBindingIndex.of(model)
Expand All @@ -727,18 +735,22 @@ private class ServerHttpProtocolImplGenerator(
protocol.defaultTimestampFormat,
)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
rustTemplate("""
rustTemplate(
"""
let v = #{PercentEncoding}::percent_decode_str(v).decode_utf8()?;
let v = #{DateTime}::from_str(&v, #{format})?;
""".trimIndent(),
""".trimIndent(),
*codegenScope,
"format" to timestampFormatType,
)
}
else -> { // Number or boolean.
rust("""
rust(
"""
let v = <_ as #T>::parse_smithy_primitive(v)?;
""".trimIndent(), CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse"))
""".trimIndent(),
CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")
)
}
}
rust("${it.memberName.toSnakeCase()}.push(v);")
Expand All @@ -750,10 +762,12 @@ private class ServerHttpProtocolImplGenerator(
QueryParamsTargetMapValueType.STRING -> {
rust("query_params.entry(String::from(k)).or_insert_with(|| String::from(v));")
} else -> {
rustTemplate("""
rustTemplate(
"""
let entry = query_params.entry(String::from(k)).or_default();
entry.push(String::from(v));
""".trimIndent())
""".trimIndent()
)
}
}
}
Expand All @@ -762,9 +776,11 @@ private class ServerHttpProtocolImplGenerator(
rust("input = input.${queryParamsBinding.member.setterName()}(Some(query_params));")
}
queryBindingsTargettingCollection.forEach {
rustTemplate("""
rustTemplate(
"""
input = input.${it.member.setterName()}(Some(${it.memberName.toSnakeCase()}));
""".trimIndent())
""".trimIndent()
)
}
}
}
Expand Down Expand Up @@ -810,10 +826,11 @@ private class ServerHttpProtocolImplGenerator(
// `<_>::from()` is necessary to convert the `&str` into:
// * the Rust enum in case the `string` shape has the `enum` trait; or
// * `String` in case it doesn't.
rustTemplate("""
rustTemplate(
"""
let value = <_>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
Ok(Some(value))
""".trimIndent(),
""".trimIndent(),
*codegenScope,
)
}
Expand Down
6 changes: 0 additions & 6 deletions rust-runtime/aws-smithy-http-server/rustfmt.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
edition = "2018"
max_width = 120
# The "Default" setting has a heuristic which splits lines too aggresively.
# We are willing to revisit this setting in future versions of rustfmt.
# Bugs:
# * https://github.com/rust-lang/rustfmt/issues/3119
# * https://github.com/rust-lang/rustfmt/issues/3120
use_small_heuristics = "Max"
# Prevent carriage returns
newline_style = "Unix"
30 changes: 22 additions & 8 deletions rust-runtime/aws-smithy-http-server/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,29 @@ use async_trait::async_trait;
use axum_core::extract::{FromRequest, RequestParts};
use std::ops::Deref;

/// Extension type used to store the Smithy model namespace.
#[derive(Debug, Clone)]
pub struct ExtensionNamespace(&'static str);
impl_extension_new_and_deref!(ExtensionNamespace);
/// Extension type used to store Smithy request information.
#[derive(Debug, Clone, Default, Copy)]
pub struct RequestExtensions {
/// Smithy model namespace.
pub namespace: &'static str,
/// Smithy operation name.
pub operation_name: &'static str,
}

/// Extension type used to store the Smithy operation name.
#[derive(Debug, Clone)]
pub struct ExtensionOperationName(&'static str);
impl_extension_new_and_deref!(ExtensionOperationName);
impl RequestExtensions {
/// Generates a new `RequestExtensions`.
pub fn new(namespace: &'static str, operation_name: &'static str) -> Self {
Self {
namespace,
operation_name,
}
}

/// Returns the current operation formatted as <namespace>#<operation_name>.
pub fn operation(&self) -> String {
format!("{}#{}", self.namespace, self.operation_name)
}
}

/// Extension type used to store the type of user defined error returned by an operation.
/// These are modeled errors, defined in the Smithy model.
Expand Down
4 changes: 1 addition & 3 deletions rust-runtime/aws-smithy-http-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ pub use self::body::{boxed, to_boxed, Body, BoxBody, HttpBody};
#[doc(inline)]
pub use self::error::Error;
#[doc(inline)]
pub use self::extension::{
Extension, ExtensionModeledError, ExtensionNamespace, ExtensionOperationName, ExtensionRejection,
};
pub use self::extension::{Extension, ExtensionModeledError, ExtensionRejection, RequestExtensions};
#[doc(inline)]
pub use self::routing::Router;
#[doc(inline)]
Expand Down
21 changes: 14 additions & 7 deletions rust-runtime/aws-smithy-http-server/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,19 @@ macro_rules! opaque_future {

pub use opaque_future;

/// Implements `Deref` for all `Extension` holding a `&'static, str`.
macro_rules! impl_deref {
($name:ident) => {
impl Deref for $name {
type Target = &'static str;

fn deref(&self) -> &Self::Target {
&self.0
}
}
};
}

/// Implements `new` for all `Extension` holding a `&'static, str`.
macro_rules! impl_extension_new_and_deref {
($name:ident) => {
Expand All @@ -242,12 +255,6 @@ macro_rules! impl_extension_new_and_deref {
}
}

impl Deref for $name {
type Target = &'static str;

fn deref(&self) -> &Self::Target {
&self.0
}
}
impl_deref!($name);
};
}
39 changes: 31 additions & 8 deletions rust-runtime/aws-smithy-http-server/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ pub struct Router<B = Body> {

impl<B> Clone for Router<B> {
fn clone(&self) -> Self {
Self { routes: self.routes.clone() }
Self {
routes: self.routes.clone(),
}
}
}

Expand All @@ -66,7 +68,9 @@ where
/// all requests.
#[doc(hidden)]
pub fn new() -> Self {
Self { routes: Default::default() }
Self {
routes: Default::default(),
}
}

/// Add a route to the router.
Expand Down Expand Up @@ -107,9 +111,15 @@ where
NewResBody: HttpBody<Data = bytes::Bytes> + Send + 'static,
NewResBody::Error: Into<BoxError>,
{
let layer = ServiceBuilder::new().layer_fn(Route::new).layer(MapResponseBodyLayer::new(boxed)).layer(layer);
let routes =
self.routes.into_iter().map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec)).collect();
let layer = ServiceBuilder::new()
.layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(boxed))
.layer(layer);
let routes = self
.routes
.into_iter()
.map(|(route, request_spec)| (Layer::layer(&layer, route), request_spec))
.collect();
Router { routes }
}
}
Expand Down Expand Up @@ -142,8 +152,17 @@ where
}
}

let status_code = if method_not_allowed { StatusCode::METHOD_NOT_ALLOWED } else { StatusCode::NOT_FOUND };
RouterFuture::from_response(Response::builder().status(status_code).body(crate::body::empty()).unwrap())
let status_code = if method_not_allowed {
StatusCode::METHOD_NOT_ALLOWED
} else {
StatusCode::NOT_FOUND
};
RouterFuture::from_response(
Response::builder()
.status(status_code)
.body(crate::body::empty())
.unwrap(),
)
}
}

Expand Down Expand Up @@ -201,7 +220,11 @@ mod tests {
(
RequestSpec::from_parts(
Method::GET,
vec![PathSegment::Literal(String::from("a")), PathSegment::Label, PathSegment::Label],
vec![
PathSegment::Literal(String::from("a")),
PathSegment::Label,
PathSegment::Label,
],
vec![],
),
"A",
Expand Down
Loading

0 comments on commit 64ccdcc

Please sign in to comment.