From 3722ea5f21a3d83e37ff744985a4716330ae965f Mon Sep 17 00:00:00 2001 From: Guy Margalit Date: Tue, 1 Feb 2022 13:31:45 +0200 Subject: [PATCH] Server streaming body Signed-off-by: Guy Margalit --- .../ServerOperationHandlerGenerator.kt | 11 +- .../protocol/ServerProtocolTestGenerator.kt | 126 ++++++++------- .../protocols/ServerHttpProtocolGenerator.kt | 151 ++++++++++-------- .../aws-smithy-http-server/Cargo.toml | 2 +- .../aws-smithy-http-server/src/lib.rs | 3 + .../aws-smithy-http-server/src/rejection.rs | 6 + .../aws-smithy-http/src/byte_stream.rs | 8 + 7 files changed, 175 insertions(+), 132 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt index da1a7153240..e8f24471dc5 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationHandlerGenerator.kt @@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro import software.amazon.smithy.rust.codegen.smithy.CodegenContext import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol +import software.amazon.smithy.rust.codegen.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.outputShape /** @@ -132,13 +134,18 @@ class ServerOperationHandlerGenerator( } else { symbolProvider.toSymbol(operation.outputShape(model)).fullName } + val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) { + "\n B: Into<#{SmithyHttpServer}::ByteStream>," + } else { + "" + } return """ $inputFn Fut: std::future::Future + Send, - B: $serverCrate::HttpBody + Send + 'static, + B: $serverCrate::HttpBody + Send + 'static, $streamingBodyTraitBounds B::Data: Send, B::Error: Into<$serverCrate::BoxError>, $serverCrate::rejection::SmithyRejection: From<::Error> - """ + """.trimIndent() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 00f2a79de19..888298000d1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -212,16 +212,16 @@ class ServerProtocolTestGenerator( rustTemplate( """ - ##[allow(unused_mut)] let mut http_request = http::Request::builder() - .uri("${httpRequestTestCase.uri}") - """, - *codegenScope + ##[allow(unused_mut)] let mut http_request = http::Request::builder() + .uri("${httpRequestTestCase.uri}") + """, + *codegenScope ) for (header in httpRequestTestCase.headers) { rust(".header(${header.key.dq()}, ${header.value.dq()})") } rustTemplate( - """ + """ .body(#{SmithyHttpServer}::Body::from(#{Bytes}::from_static(b${httpRequestTestCase.body.orNull()?.dq()}))) .unwrap(); """, @@ -363,9 +363,9 @@ class ServerProtocolTestGenerator( basicCheck( requireHeaders, rustWriter, - "required_headers", - actualExpression, - "require_headers" + "required_headers", + actualExpression, + "require_headers" ) } @@ -373,9 +373,9 @@ class ServerProtocolTestGenerator( basicCheck( forbidHeaders, rustWriter, - "forbidden_headers", + "forbidden_headers", actualExpression, - "forbid_headers" + "forbid_headers" ) } @@ -525,14 +525,10 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonSupportsNegativeInfinityFloatInputs", Action.Response), FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Request), FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Request), - FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Response), - FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Response), FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Request), FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Request), FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response), - FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Response), FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Request), - FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Response), FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request), FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request), @@ -603,56 +599,64 @@ class ServerProtocolTestGenerator( ).asObjectNode().get() ).build() private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase = - testCase.toBuilder().params( - Node.parse("""{ - "queryString": "Hello there", - "queryStringList": ["a", "b", "c"], - "queryStringSet": ["a", "b", "c"], - "queryByte": 1, - "queryShort": 2, - "queryInteger": 3, - "queryIntegerList": [1, 2, 3], - "queryIntegerSet": [1, 2, 3], - "queryLong": 4, - "queryFloat": 1.1, - "queryDouble": 1.1, - "queryDoubleList": [1.1, 2.1, 3.1], - "queryBoolean": true, - "queryBooleanList": [true, false, true], - "queryTimestamp": 1, - "queryTimestampList": [1, 2, 3], - "queryEnum": "Foo", - "queryEnumList": ["Foo", "Baz", "Bar"], - "queryParamsMapOfStringList": { - "String": ["Hello there"], - "StringList": ["a", "b", "c"], - "StringSet": ["a", "b", "c"], - "Byte": ["1"], - "Short": ["2"], - "Integer": ["3"], - "IntegerList": ["1", "2", "3"], - "IntegerSet": ["1", "2", "3"], - "Long": ["4"], - "Float": ["1.1"], - "Double": ["1.1"], - "DoubleList": ["1.1", "2.1", "3.1"], - "Boolean": ["true"], - "BooleanList": ["true", "false", "true"], - "Timestamp": ["1970-01-01T00:00:01Z"], - "TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"], - "Enum": ["Foo"], - "EnumList": ["Foo", "Baz", "Bar"] + testCase.toBuilder().params( + Node.parse( + """ + { + "queryString": "Hello there", + "queryStringList": ["a", "b", "c"], + "queryStringSet": ["a", "b", "c"], + "queryByte": 1, + "queryShort": 2, + "queryInteger": 3, + "queryIntegerList": [1, 2, 3], + "queryIntegerSet": [1, 2, 3], + "queryLong": 4, + "queryFloat": 1.1, + "queryDouble": 1.1, + "queryDoubleList": [1.1, 2.1, 3.1], + "queryBoolean": true, + "queryBooleanList": [true, false, true], + "queryTimestamp": 1, + "queryTimestampList": [1, 2, 3], + "queryEnum": "Foo", + "queryEnumList": ["Foo", "Baz", "Bar"], + "queryParamsMapOfStringList": { + "String": ["Hello there"], + "StringList": ["a", "b", "c"], + "StringSet": ["a", "b", "c"], + "Byte": ["1"], + "Short": ["2"], + "Integer": ["3"], + "IntegerList": ["1", "2", "3"], + "IntegerSet": ["1", "2", "3"], + "Long": ["4"], + "Float": ["1.1"], + "Double": ["1.1"], + "DoubleList": ["1.1", "2.1", "3.1"], + "Boolean": ["true"], + "BooleanList": ["true", "false", "true"], + "Timestamp": ["1970-01-01T00:00:01Z"], + "TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"], + "Enum": ["Foo"], + "EnumList": ["Foo", "Baz", "Bar"] + } } - }""".trimMargin()).asObjectNode().get() - ).build() + """.trimMargin() + ).asObjectNode().get() + ).build() private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase): HttpRequestTestCase = testCase.toBuilder().params( - Node.parse("""{ - "queryString": "%:/?#[]@!${'$'}&'()*+,;=😹", - "queryParamsMapOfStringList": { - "String": ["%:/?#[]@!${'$'}&'()*+,;=😹"] - } - }""".trimMargin()).asObjectNode().get() + Node.parse( + """ + { + "queryString": "%:/?#[]@!${'$'}&'()*+,;=😹", + "queryParamsMapOfStringList": { + "String": ["%:/?#[]@!${'$'}&'()*+,;=😹"] + } + } + """.trimMargin() + ).asObjectNode().get() ).build() // This test assumes that errors in responses are identified by an `X-Amzn-Errortype` header with the error shape name. // However, Smithy specifications for AWS protocols that serialize to JSON recommend that new server implementations diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt index cca36a0adf5..82c4d8a4acf 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpProtocolGenerator.kt @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols import software.amazon.smithy.aws.traits.protocols.RestJson1Trait import software.amazon.smithy.aws.traits.protocols.RestXmlTrait import software.amazon.smithy.codegen.core.Symbol -import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex import software.amazon.smithy.model.node.ExpectationNotMetException import software.amazon.smithy.model.shapes.CollectionShape @@ -55,6 +54,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData import software.amazon.smithy.rust.codegen.util.UNREACHABLE import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectTrait +import software.amazon.smithy.rust.codegen.util.findStreamingMember import software.amazon.smithy.rust.codegen.util.getTrait import software.amazon.smithy.rust.codegen.util.hasStreamingMember import software.amazon.smithy.rust.codegen.util.inputShape @@ -132,13 +132,12 @@ private class ServerHttpProtocolImplGenerator( } /* - * Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request - * and response bodies, that is, models without streaming traits - * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html). - * For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or - * deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize + * Generation of `FromRequest` and `IntoResponse`. + * For non-streaming request bodies, that is, models without streaming traits + * (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html) + * we require the HTTP body to be fully read in memory before parsing or deserialization. + * From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize * an HTTP response to `Bytes`. - * TODO Add support for streaming. * These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server. */ private fun RustWriter.renderTraits( @@ -147,38 +146,24 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape ) { val operationName = symbolProvider.toSymbol(operationShape).name - // Implement Axum `FromRequest` trait for input types. val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) { - // For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait. - // It will first offer the streaming input to the parser and potentially read the body into memory - // if an error occurred or if the streaming parser indicates that it needs the full data to proceed. - """ - async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts) -> Result { - todo!("Streaming support for input shapes is not yet supported in `smithy-rs`") - } - """.trimIndent() - } else { - """ - async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { - Ok($inputName(#{parse_request}(req).await?)) - } - """.trimIndent() - } + // Implement Axum `FromRequest` trait for input types. rustTemplate( """ pub struct $inputName(pub #{I}); ##[#{AsyncTrait}::async_trait] impl #{AxumCore}::extract::FromRequest for $inputName where - B: #{SmithyHttpServer}::HttpBody + Send, + B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> { type Rejection = #{SmithyRejection}; - $fromRequest + async fn from_request(req: &mut #{AxumCore}::extract::RequestParts) -> Result { + Ok($inputName(#{parse_request}(req).await?)) + } } """.trimIndent(), *codegenScope, @@ -187,21 +172,19 @@ private class ServerHttpProtocolImplGenerator( ) // Implement Axum `IntoResponse` for output types. + val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}" val errorSymbol = operationShape.errorSymbol(symbolProvider) - val httpExtensions = setHttpExtensions(operationShape) - // For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait. - // The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler. - val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")" + if (operationShape.errors.isNotEmpty()) { - val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { + // The output of fallible operations is a `Result` which we convert into an + // isomorphic `enum` type we control that can in turn be converted into a response. + val intoResponseImpl = """ let mut response = match self { Self::Output(o) => { - match #{serialize_response}(&o) { + match #{serialize_response}(o) { Ok(response) => response, Err(e) => { e.into_response() @@ -223,9 +206,7 @@ private class ServerHttpProtocolImplGenerator( $httpExtensions response """.trimIndent() - } - // The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control - // that can in turn be converted into a response. + rustTemplate( """ pub enum $outputName { @@ -246,27 +227,25 @@ private class ServerHttpProtocolImplGenerator( "serialize_error" to serverSerializeError(operationShape) ) } else { - val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) { - intoResponseStreaming - } else { + // The output of non-fallible operations is a model type which we convert into + // a "wrapper" unit `struct` type we control that can in turn be converted into a response. + val intoResponseImpl = """ - let mut response = match #{serialize_response}(&self.0) { + let mut response = match #{serialize_response}(self.0) { Ok(response) => response, Err(e) => e.into_response() }; $httpExtensions response """.trimIndent() - } - // The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type - // we control that can in turn be converted into a response. + rustTemplate( """ pub struct $outputName(pub #{O}); ##[#{AsyncTrait}::async_trait] impl #{AxumCore}::response::IntoResponse for $outputName { fn into_response(self) -> #{AxumCore}::response::Response { - $handleSerializeOutput + $intoResponseImpl } } """.trimIndent(), @@ -335,6 +314,7 @@ private class ServerHttpProtocolImplGenerator( val inputSymbol = symbolProvider.toSymbol(inputShape) val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT) val unusedVars = if (includedMembers.isEmpty()) "##[allow(unused_variables)] " else "" + return RuntimeType.forInlineFun(fnName, operationDeserModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) it.rustBlockTemplate( @@ -346,11 +326,11 @@ private class ServerHttpProtocolImplGenerator( #{SmithyRejection} > where - B: #{SmithyHttpServer}::HttpBody + Send, + B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)} B::Data: Send, B::Error: Into<#{SmithyHttpServer}::BoxError>, #{SmithyRejection}: From<::Error> - """, + """.trimIndent(), *codegenScope, "I" to inputSymbol, ) { @@ -371,8 +351,12 @@ private class ServerHttpProtocolImplGenerator( val outputSymbol = symbolProvider.toSymbol(outputShape) return RuntimeType.forInlineFun(fnName, operationSerModule) { Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it) + + // Note we only need to take ownership of the output in the case that it contains streaming members. + // However we currently always take ownership here, but worth noting in case in the future we want + // to generate different signatures for streaming vs non-streaming for some reason. it.rustBlockTemplate( - "pub fn $fnName(output: &#{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", + "pub fn $fnName(output: #{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>", *codegenScope, "O" to outputSymbol, ) { @@ -459,13 +443,6 @@ private class ServerHttpProtocolImplGenerator( operationShape: OperationShape, bindings: List, ) { - val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) - structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer -> - rust( - "let payload = #T(output)?;", - serializer - ) - } ?: rust("""let payload = "";""") // avoid non-usage warnings for response Attribute.AllowUnusedMut.render(this) rustTemplate("let mut builder = #{http}::Response::builder();", *codegenScope) @@ -477,6 +454,24 @@ private class ServerHttpProtocolImplGenerator( serializedValue(this) } } + val streamingMember = operationShape.outputShape(model).findStreamingMember(model) + if (streamingMember != null) { + val memberName = symbolProvider.toMemberName(streamingMember) + rustTemplate( + """ + let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName); + """, + *codegenScope, + ) + } else { + val structuredDataSerializer = protocol.structuredDataSerializer(operationShape) + structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer -> + rust( + "let payload = #T(&output)?;", + serializer + ) + } ?: rust("""let payload = "";""") + } rustTemplate( """ builder.body(#{SmithyHttpServer}::body::to_boxed(payload))? @@ -510,11 +505,13 @@ private class ServerHttpProtocolImplGenerator( } val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape) - val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape?: operationShape) + val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape) if (addHeadersFn != null) { + // notice that we need to borrow the output only for output shapes but not for error shapes + val outputOwnedOrBorrow = if (errorShape == null) "&output" else "output" rust( """ - builder = #{T}(output, builder)?; + builder = #{T}($outputOwnedOrBorrow, builder)?; """.trimIndent(), addHeadersFn ) @@ -528,12 +525,11 @@ private class ServerHttpProtocolImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val member = binding.member return when (binding.location) { - HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.DOCUMENT -> { - // All of these are handled separately. - null - } + HttpLocation.HEADER, + HttpLocation.PREFIX_HEADERS, + HttpLocation.DOCUMENT, HttpLocation.PAYLOAD -> { - logger.warning("[rust-server-codegen] $operationName: response serialization does not currently support ${binding.location} bindings") + // All of these are handled separately. null } HttpLocation.RESPONSE_CODE -> writable { @@ -562,6 +558,7 @@ private class ServerHttpProtocolImplGenerator( ) { val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape) val structuredDataParser = protocol.structuredDataParser(operationShape) + val streamingMember = inputShape.findStreamingMember(model) Attribute.AllowUnusedMut.render(this) rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) val parser = structuredDataParser.serverInputParser(operationShape) @@ -579,13 +576,23 @@ private class ServerHttpProtocolImplGenerator( *codegenScope, "parser" to parser, ) + } else if (streamingMember != null) { + rustTemplate( + """ + let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?; + input = input.${streamingMember.setterName()}(Some(body.into())); + """.trimIndent(), + *codegenScope + ) } - for (binding in bindings) { - val member = binding.member - val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) - if (parsedValue != null) { - withBlock("input = input.${member.setterName()}(", ");") { - parsedValue(this) + if (streamingMember == null) { + for (binding in bindings) { + val member = binding.member + val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser) + if (parsedValue != null) { + withBlock("input = input.${member.setterName()}(", ");") { + parsedValue(this) + } } } } @@ -1047,4 +1054,12 @@ private class ServerHttpProtocolImplGenerator( } } } + + private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String { + if (operationShape.inputShape(model).hasStreamingMember(model)) { + return "\n B: Into<#{SmithyHttpServer}::ByteStream>," + } else { + return "" + } + } } diff --git a/rust-runtime/aws-smithy-http-server/Cargo.toml b/rust-runtime/aws-smithy-http-server/Cargo.toml index 9606e1787e5..5afeaa1e4a5 100644 --- a/rust-runtime/aws-smithy-http-server/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server/Cargo.toml @@ -26,7 +26,7 @@ bytes = "1.1" futures-util = { version = "0.3", default-features = false } http = "0.2" http-body = "0.4" -hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] } +hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] } mime = "0.3" nom = "7" pin-project-lite = "0.2" diff --git a/rust-runtime/aws-smithy-http-server/src/lib.rs b/rust-runtime/aws-smithy-http-server/src/lib.rs index 2e604645c4d..6e2d9b031b6 100644 --- a/rust-runtime/aws-smithy-http-server/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server/src/lib.rs @@ -30,6 +30,9 @@ pub use self::routing::Router; #[doc(inline)] pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; +#[doc(inline)] +pub use aws_smithy_http::byte_stream::ByteStream; + /// Alias for a type-erased error type. pub use axum_core::BoxError; diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index 6af2ff17129..2787bcab2cf 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -178,6 +178,12 @@ impl From for SmithyRejection { } } +impl From for SmithyRejection { + fn from(err: aws_smithy_types::date_time::DateTimeFormatError) -> Self { + SmithyRejection::Serialize(Serialize::from_err(err)) + } +} + impl From for SmithyRejection { fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self { SmithyRejection::Deserialize(Deserialize::from_err(err)) diff --git a/rust-runtime/aws-smithy-http/src/byte_stream.rs b/rust-runtime/aws-smithy-http/src/byte_stream.rs index 00b8cecb00c..fdd7ec91243 100644 --- a/rust-runtime/aws-smithy-http/src/byte_stream.rs +++ b/rust-runtime/aws-smithy-http/src/byte_stream.rs @@ -326,6 +326,14 @@ impl From> for ByteStream { } } +impl From for ByteStream { + fn from(input: hyper::Body) -> Self { + ByteStream::new(SdkBody::from_dyn( + input.map_err(|e| e.into_cause().unwrap()).boxed(), + )) + } +} + #[derive(Debug)] pub struct Error(Box);