Skip to content

Commit

Permalink
Add support for the awsQueryCompatible trait (#2398)
Browse files Browse the repository at this point in the history
* Add support for the awsQueryCompatible trait

This commit adds support for the awsQueryCompatible trait. This allows
services already supporting custom error codes through the AWS Query
protocol with the awsQueryError trait to continue supporting them after
the services switch to the AWS JSON 1.0 protocol.

* Add copyright header

* Fix clippy warning for clippy::manual-map

* Update CHANGELOG.next.toml

* Update CHANGELOG.next.toml

* Update CHANGELOG.next.toml

* Remove unused variables from `errorScope`

This commit addresses #2398 (comment)

* Reorder arguments for test verification

This commit addresses #2398 (comment)

---------

Co-authored-by: Yuki Saito <[email protected]>
  • Loading branch information
2 people authored and Velfi committed Feb 23, 2023
1 parent 2a4702b commit adc8540
Show file tree
Hide file tree
Showing 8 changed files with 416 additions and 1 deletion.
41 changes: 41 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,44 @@ message = "Support for constraint traits on member shapes (constraint trait prec
references = ["smithy-rs#1969"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "server" }
author = "drganjoo"

[[smithy-rs]]
message = """
Add support for the `awsQueryCompatible` trait. This allows services to continue supporting a custom error code (via the `awsQueryError` trait) when the services migrate their protocol from `awsQuery` to `awsJson1_0` annotated with `awsQueryCompatible`.
<details>
<summary>Click to expand for more details...</summary>
After the migration, services will include an additional header `x-amzn-query-error` in their responses whose value is in the form of `<error code>;<error type>`. An example response looks something like
```
HTTP/1.1 400
x-amzn-query-error: AWS.SimpleQueueService.NonExistentQueue;Sender
Date: Wed, 08 Sep 2021 23:46:52 GMT
Content-Type: application/x-amz-json-1.0
Content-Length: 163
{
"__type": "com.amazonaws.sqs#QueueDoesNotExist",
"message": "some user-visible message"
}
```
`<error code>` is `AWS.SimpleQueueService.NonExistentQueue` and `<error type>` is `Sender`.
If an operation results in an error that causes a service to send back the response above, you can access `<error code>` and `<error type>` as follows:
```rust
match client.some_operation().send().await {
Ok(_) => { /* success */ }
Err(sdk_err) => {
let err = sdk_err.into_service_error();
assert_eq!(
error.meta().code(),
Some("AWS.SimpleQueueService.NonExistentQueue"),
);
assert_eq!(error.meta().extra("type"), Some("Sender"));
}
}
</details>
```
"""
references = ["smithy-rs#2398"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "ysaito1001"
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@ package software.amazon.smithy.rust.codegen.client.smithy.protocols

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.AwsQueryCompatibleTrait
import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryCompatible
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsQueryProtocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Ec2QueryProtocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
Expand All @@ -25,6 +28,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolLoader
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.util.hasTrait

class ClientProtocolLoader(supportedProtocols: ProtocolMap<ClientProtocolGenerator, ClientCodegenContext>) :
ProtocolLoader<ClientProtocolGenerator, ClientCodegenContext>(supportedProtocols) {
Expand Down Expand Up @@ -57,12 +61,20 @@ private val CLIENT_PROTOCOL_SUPPORT = ProtocolSupport(

private class ClientAwsJsonFactory(private val version: AwsJsonVersion) :
ProtocolGeneratorFactory<HttpBoundProtocolGenerator, ClientCodegenContext> {
override fun protocol(codegenContext: ClientCodegenContext): Protocol = AwsJson(codegenContext, version)
override fun protocol(codegenContext: ClientCodegenContext): Protocol =
if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) {
AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version))
} else {
AwsJson(codegenContext, version)
}

override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): HttpBoundProtocolGenerator =
HttpBoundProtocolGenerator(codegenContext, protocol(codegenContext))

override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT

private fun compatibleWithAwsQuery(serviceShape: ServiceShape, version: AwsJsonVersion) =
serviceShape.hasTrait<AwsQueryCompatibleTrait>() && version == AwsJsonVersion.Json10
}

private class ClientAwsQueryFactory : ProtocolGeneratorFactory<HttpBoundProtocolGenerator, ClientCodegenContext> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.client.smithy.protocols

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest

class AwsQueryCompatibleTest {
@Test
fun `aws-query-compatible json with aws query error should allow for retrieving error code and type from custom header`() {
val model = """
namespace test
use aws.protocols#awsJson1_0
use aws.protocols#awsQueryCompatible
use aws.protocols#awsQueryError
@awsQueryCompatible
@awsJson1_0
service TestService {
version: "2023-02-20",
operations: [SomeOperation]
}
operation SomeOperation {
input: SomeOperationInputOutput,
output: SomeOperationInputOutput,
errors: [InvalidThingException],
}
structure SomeOperationInputOutput {
a: String,
b: Integer
}
@awsQueryError(
code: "InvalidThing",
httpResponseCode: 400,
)
@error("client")
structure InvalidThingException {
message: String
}
""".asSmithyModel()

clientIntegrationTest(model) { clientCodegenContext, rustCrate ->
val moduleName = clientCodegenContext.moduleUseName()
rustCrate.integrationTest("should_parse_code_and_type_fields") {
rust(
"""
##[test]
fn should_parse_code_and_type_fields() {
use aws_smithy_http::response::ParseStrictResponse;
let response = http::Response::builder()
.header(
"x-amzn-query-error",
http::HeaderValue::from_static("AWS.SimpleQueueService.NonExistentQueue;Sender"),
)
.status(400)
.body(
r##"{
"__type": "com.amazonaws.sqs##QueueDoesNotExist",
"message": "Some user-visible message"
}"##,
)
.unwrap();
let some_operation = $moduleName::operation::SomeOperation::new();
let error = some_operation
.parse(&response.map(bytes::Bytes::from))
.err()
.unwrap();
assert_eq!(
Some("AWS.SimpleQueueService.NonExistentQueue"),
error.meta().code(),
);
assert_eq!(Some("Sender"), error.meta().extra("type"));
}
""",
)
}
}
}

@Test
fun `aws-query-compatible json without aws query error should allow for retrieving error code from payload`() {
val model = """
namespace test
use aws.protocols#awsJson1_0
use aws.protocols#awsQueryCompatible
@awsQueryCompatible
@awsJson1_0
service TestService {
version: "2023-02-20",
operations: [SomeOperation]
}
operation SomeOperation {
input: SomeOperationInputOutput,
output: SomeOperationInputOutput,
errors: [InvalidThingException],
}
structure SomeOperationInputOutput {
a: String,
b: Integer
}
@error("client")
structure InvalidThingException {
message: String
}
""".asSmithyModel()

clientIntegrationTest(model) { clientCodegenContext, rustCrate ->
val moduleName = clientCodegenContext.moduleUseName()
rustCrate.integrationTest("should_parse_code_from_payload") {
rust(
"""
##[test]
fn should_parse_code_from_payload() {
use aws_smithy_http::response::ParseStrictResponse;
let response = http::Response::builder()
.status(400)
.body(
r##"{
"__type": "com.amazonaws.sqs##QueueDoesNotExist",
"message": "Some user-visible message"
}"##,
)
.unwrap();
let some_operation = $moduleName::operation::SomeOperation::new();
let error = some_operation
.parse(&response.map(bytes::Bytes::from))
.err()
.unwrap();
assert_eq!(Some("QueueDoesNotExist"), error.meta().code());
assert_eq!(None, error.meta().extra("type"));
}
""",
)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ class InlineDependency(
CargoDependency.Http,
)

fun awsQueryCompatibleErrors(runtimeConfig: RuntimeConfig) =
forInlineableRustFile(
"aws_query_compatible_errors",
CargoDependency.smithyJson(runtimeConfig),
CargoDependency.Http,
)

fun idempotencyToken() =
forInlineableRustFile("idempotency_token", CargoDependency.FastRand)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null)
fun provideErrorMetadataTrait(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::metadata::ProvideErrorMetadata")
fun unhandledError(runtimeConfig: RuntimeConfig) = smithyTypes(runtimeConfig).resolve("error::Unhandled")
fun jsonErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.jsonErrors(runtimeConfig))
fun awsQueryCompatibleErrors(runtimeConfig: RuntimeConfig) = forInlineDependency(InlineDependency.awsQueryCompatibleErrors(runtimeConfig))
fun labelFormat(runtimeConfig: RuntimeConfig, func: String) = smithyHttp(runtimeConfig).resolve("label::$func")
fun operation(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("operation::Operation")
fun operationModule(runtimeConfig: RuntimeConfig) = smithyHttp(runtimeConfig).resolve("operation")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy.protocols

import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.StructuredDataSerializerGenerator

class AwsQueryCompatibleHttpBindingResolver(
private val awsQueryBindingResolver: AwsQueryBindingResolver,
private val awsJsonHttpBindingResolver: AwsJsonHttpBindingResolver,
) : HttpBindingResolver {
override fun httpTrait(operationShape: OperationShape): HttpTrait =
awsJsonHttpBindingResolver.httpTrait(operationShape)

override fun requestBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
awsJsonHttpBindingResolver.requestBindings(operationShape)

override fun responseBindings(operationShape: OperationShape): List<HttpBindingDescriptor> =
awsJsonHttpBindingResolver.responseBindings(operationShape)

override fun errorResponseBindings(errorShape: ToShapeId): List<HttpBindingDescriptor> =
awsJsonHttpBindingResolver.errorResponseBindings(errorShape)

override fun errorCode(errorShape: ToShapeId): String =
awsQueryBindingResolver.errorCode(errorShape)

override fun requestContentType(operationShape: OperationShape): String =
awsJsonHttpBindingResolver.requestContentType(operationShape)

override fun responseContentType(operationShape: OperationShape): String =
awsJsonHttpBindingResolver.requestContentType(operationShape)
}

class AwsQueryCompatible(
val codegenContext: CodegenContext,
private val awsJson: AwsJson,
) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig
private val errorScope = arrayOf(
"Bytes" to RuntimeType.Bytes,
"ErrorMetadataBuilder" to RuntimeType.errorMetadataBuilder(runtimeConfig),
"JsonError" to CargoDependency.smithyJson(runtimeConfig).toType()
.resolve("deserialize::error::DeserializeError"),
"Response" to RuntimeType.Http.resolve("Response"),
"json_errors" to RuntimeType.jsonErrors(runtimeConfig),
"aws_query_compatible_errors" to RuntimeType.awsQueryCompatibleErrors(runtimeConfig),
)
private val jsonDeserModule = RustModule.private("json_deser")

override val httpBindingResolver: HttpBindingResolver =
AwsQueryCompatibleHttpBindingResolver(
AwsQueryBindingResolver(codegenContext.model),
AwsJsonHttpBindingResolver(codegenContext.model, awsJson.version),
)

override val defaultTimestampFormat = awsJson.defaultTimestampFormat

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator =
awsJson.structuredDataParser(operationShape)

override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
awsJson.structuredDataSerializer(operationShape)

override fun parseHttpErrorMetadata(operationShape: OperationShape): RuntimeType =
RuntimeType.forInlineFun("parse_http_error_metadata", jsonDeserModule) {
rustTemplate(
"""
pub fn parse_http_error_metadata(response: &#{Response}<#{Bytes}>) -> Result<#{ErrorMetadataBuilder}, #{JsonError}> {
let mut builder =
#{json_errors}::parse_error_metadata(response.body(), response.headers())?;
if let Some((error_code, error_type)) =
#{aws_query_compatible_errors}::parse_aws_query_compatible_error(response.headers())
{
builder = builder.code(error_code);
builder = builder.custom("type", error_type);
}
Ok(builder)
}
""",
*errorScope,
)
}

override fun parseEventStreamErrorMetadata(operationShape: OperationShape): RuntimeType =
awsJson.parseEventStreamErrorMetadata(operationShape)
}
Loading

0 comments on commit adc8540

Please sign in to comment.