Skip to content

Commit

Permalink
Add operation input/output wrapper conversion functions (#863)
Browse files Browse the repository at this point in the history
The functions allow to wrap the model types into the wrappers and unwrap
the wrappers into the model types.

The wrapper types have been made private.
  • Loading branch information
david-perez authored Nov 15, 2021
1 parent ada03d4 commit e6600aa
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class ServerProtocolTestGenerator(
}
rustTemplate(
"""
let output = super::$operationName(output);
let output = super::$operationName::Output(output);
use #{Axum}::response::IntoResponse;
let http_response = output.into_response();
""",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,22 @@ private class ServerHttpProtocolImplGenerator(
private val logger = Logger.getLogger(javaClass.name)
private val symbolProvider = codegenContext.symbolProvider
private val model = codegenContext.model
private val errorType = RuntimeType("error", null, "crate")
private val runtimeConfig = codegenContext.runtimeConfig
private val httpBindingResolver = protocol.httpBindingResolver
private val operationDeserModule = RustModule.private("operation_deser")
private val operationSerModule = RustModule.private("operation_ser")
private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType()

private val codegenScope = arrayOf(
"JsonObjectWriter" to smithyJson.member("serialize::JsonObjectWriter"),
"http" to RuntimeType.http,
"Bytes" to RuntimeType.Bytes,
"LazyStatic" to CargoDependency.LazyStatic.asType(),
"Regex" to CargoDependency.Regex.asType(),
"PercentEncoding" to CargoDependency.PercentEncoding.asType(),
"Axum" to CargoDependency.Axum.asType(),
"DateTime" to RuntimeType.DateTime(runtimeConfig),
"HttpBody" to CargoDependency.HttpBody.asType(),
"Hyper" to CargoDependency.Hyper.asType(),
"LazyStatic" to CargoDependency.LazyStatic.asType(),
"PercentEncoding" to CargoDependency.PercentEncoding.asType(),
"Regex" to CargoDependency.Regex.asType(),
"SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
"SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
"SdkBody" to RuntimeType.sdkBody(runtimeConfig),
"DateTime" to RuntimeType.DateTime(runtimeConfig)
"http" to RuntimeType.http,
)

override fun generateTraitImpls(operationWriter: RustWriter, operationShape: OperationShape) {
Expand Down Expand Up @@ -140,12 +135,11 @@ private class ServerHttpProtocolImplGenerator(
outputSymbol: Symbol,
operationShape: OperationShape
) {
val errorSymbol = operationShape.errorSymbol(symbolProvider)
// Implement Axum `FromRequest` trait for non streaming input types.
val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"
rustTemplate(
"""
pub struct $inputName(#{I});
struct $inputName(#{I});
##[#{Axum}::async_trait]
impl<B> #{Axum}::extract::FromRequest<B> for $inputName
where
Expand All @@ -159,55 +153,119 @@ private class ServerHttpProtocolImplGenerator(
#{SmithyHttpServer}::protocols::check_json_content_type(req)?;
Ok($inputName(#{parse_request}(req).await?))
}
}""",
}""".trimIndent(),
*codegenScope,
"I" to inputSymbol,
"parse_request" to serverParseRequest(operationShape)
)

// Implement Axum `IntoResponse` for non streaming output types.
val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
rustTemplate(
"""
pub struct $outputName(#{O});
##[#{Axum}::async_trait]
impl #{Axum}::response::IntoResponse for $outputName {
type Body = #{SmithyHttpServer}::Body;
type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;
fn into_response(self) -> #{http}::Response<Self::Body> {
match #{serialize_response}(&self.0) {
Ok(response) => response,
Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
}
}
}""",
*codegenScope,
"O" to outputSymbol,
"serialize_response" to serverSerializeResponse(operationShape)
)
val errorSymbol = operationShape.errorSymbol(symbolProvider)

val handleSerializeOutput = """
Ok(response) => response,
Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from output")
""".trimIndent()
if (operationShape.errors.isNotEmpty()) {
// Implement Axum `IntoResponse` for non streaming error types.
val errorName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_ERROR_WRAPPER_SUFFIX}"
// The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
// that can in turn be converted into a response.
rustTemplate(
"""
pub struct $errorName(#{E});
enum $outputName {
Output(#{O}),
Error(#{E})
}
##[#{Axum}::async_trait]
impl #{Axum}::response::IntoResponse for $errorName {
impl #{Axum}::response::IntoResponse for $outputName {
type Body = #{SmithyHttpServer}::Body;
type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;
fn into_response(self) -> #{http}::Response<Self::Body> {
match #{serialize_error}(&self.0) {
Ok(response) => response,
Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
match self {
Self::Output(o) => {
match #{serialize_response}(&o) {
$handleSerializeOutput
}
},
Self::Error(err) => {
match #{serialize_error}(&err) {
Ok(response) => response,
Err(e) => #{http}::Response::builder().body(Self::Body::from(e.to_string())).expect("unable to build response from error")
}
}
}
}
}""",
}""".trimIndent(),
*codegenScope,
"O" to outputSymbol,
"E" to errorSymbol,
"serialize_response" to serverSerializeResponse(operationShape),
"serialize_error" to serverSerializeError(operationShape)
)
} else {
// The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
// we control that can in turn be converted into a response.
rustTemplate(
"""
struct $outputName(#{O});
##[#{Axum}::async_trait]
impl #{Axum}::response::IntoResponse for $outputName {
type Body = #{SmithyHttpServer}::Body;
type BodyError = <Self::Body as #{SmithyHttpServer}::HttpBody>::Error;
fn into_response(self) -> #{http}::Response<Self::Body> {
match #{serialize_response}(&self.0) {
$handleSerializeOutput
}
}
}""".trimIndent(),
*codegenScope,
"O" to outputSymbol,
"serialize_response" to serverSerializeResponse(operationShape)
)
}

// Implement conversion function to "wrap" from the model operation output types.
if (operationShape.errors.isNotEmpty()) {
rustTemplate(
"""
impl From<Result<#{O}, #{E}>> for $outputName {
fn from(res: Result<#{O}, #{E}>) -> Self {
match res {
Ok(v) => Self::Output(v),
Err(e) => Self::Error(e),
}
}
}
""".trimIndent(),
"O" to outputSymbol,
"E" to errorSymbol
)
} else {
rustTemplate(
"""
impl From<#{O}> for $outputName {
fn from(o: #{O}) -> Self {
Self(o)
}
}
""".trimIndent(),
"O" to outputSymbol
)
}

// Implement conversion function to "unwrap" into the model operation input types.
rustTemplate(
"""
impl From<$inputName> for #{I} {
fn from(i: $inputName) -> Self {
i.0
}
}
""".trimIndent(),
"I" to inputSymbol
)
}

/*
Expand Down

0 comments on commit e6600aa

Please sign in to comment.