Skip to content

Commit

Permalink
Server streaming body (#1023)
Browse files Browse the repository at this point in the history
Add support for server blob streaming requests and responses

Data is streamed over the HTTP body.

Signed-off-by: Guy Margalit <[email protected]>
Co-authored-by: david-perez <[email protected]>
  • Loading branch information
guymguym and david-perez authored Feb 3, 2022
1 parent 907c0f3 commit f76bc15
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape

/**
Expand All @@ -39,6 +41,7 @@ class ServerOperationHandlerGenerator(
"PinProjectLite" to ServerCargoDependency.PinProjectLite.asType(),
"Tower" to ServerCargoDependency.Tower.asType(),
"FuturesUtil" to ServerCargoDependency.FuturesUtil.asType(),
"SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
"SmithyHttpServer" to CargoDependency.SmithyHttpServer(runtimeConfig).asType(),
"SmithyRejection" to ServerHttpProtocolGenerator.smithyRejection(runtimeConfig),
"Phantom" to ServerRuntimeType.Phantom,
Expand Down Expand Up @@ -132,13 +135,18 @@ class ServerOperationHandlerGenerator(
} else {
symbolProvider.toSymbol(operation.outputShape(model)).fullName
}
val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) {
"\n B: Into<#{SmithyHttp}::byte_stream::ByteStream>,"
} else {
""
}
return """
$inputFn
Fut: std::future::Future<Output = $outputType> + Send,
B: $serverCrate::HttpBody + Send + 'static,
B: $serverCrate::HttpBody + Send + 'static, $streamingBodyTraitBounds
B::Data: Send,
B::Error: Into<$serverCrate::BoxError>,
$serverCrate::rejection::SmithyRejection: From<<B as $serverCrate::HttpBody>::Error>
"""
""".trimIndent()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.orNull
import software.amazon.smithy.rust.codegen.util.outputShape
import software.amazon.smithy.rust.codegen.util.toSnakeCase
Expand Down Expand Up @@ -212,16 +213,16 @@ class ServerProtocolTestGenerator(

rustTemplate(
"""
##[allow(unused_mut)] let mut http_request = http::Request::builder()
.uri("${httpRequestTestCase.uri}")
""",
*codegenScope
##[allow(unused_mut)] let mut http_request = http::Request::builder()
.uri("${httpRequestTestCase.uri}")
""",
*codegenScope
)
for (header in httpRequestTestCase.headers) {
rust(".header(${header.key.dq()}, ${header.value.dq()})")
}
rustTemplate(
"""
"""
.body(#{SmithyHttpServer}::Body::from(#{Bytes}::from_static(b${httpRequestTestCase.body.orNull()?.dq()})))
.unwrap();
""",
Expand Down Expand Up @@ -326,15 +327,37 @@ class ServerProtocolTestGenerator(
"""
use #{AxumCore}::extract::FromRequest;
let mut http_request = #{AxumCore}::extract::RequestParts::new(http_request);
let input_wrapper = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request");
let input = input_wrapper.0;
let parsed = super::$operationName::from_request(&mut http_request).await.expect("failed to parse request").0;
""",
*codegenScope,
)
if (operationShape.outputShape(model).hasStreamingMember(model)) {
rustWriter.rust("""todo!("streaming types aren't supported yet");""")

if (inputShape.hasStreamingMember(model)) {
// A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members
// and handle the equality assertion separately.
for (member in inputShape.members()) {
val memberName = codegenContext.symbolProvider.toMemberName(member)
if (member.isStreaming(codegenContext.model)) {
rustWriter.rustTemplate(
"""
#{AssertEq}(
parsed.$memberName.collect().await.unwrap().into_bytes(),
expected.$memberName.collect().await.unwrap().into_bytes()
);
""",
*codegenScope
)
} else {
rustWriter.rustTemplate(
"""
#{AssertEq}(parsed.$memberName, expected.$memberName, "Unexpected value for `$memberName`");
""",
*codegenScope
)
}
}
} else {
rustWriter.rustTemplate("#{AssertEq}(input, expected);", *codegenScope)
rustWriter.rustTemplate("#{AssertEq}(parsed, expected);", *codegenScope)
}
}

Expand All @@ -357,7 +380,7 @@ class ServerProtocolTestGenerator(
assertOk(rustWriter) {
rustWriter.write(
"#T(&body, ${
rustWriter.escape(body).dq()
rustWriter.escape(body).dq()
}, #T::from(${(mediaType ?: "unknown").dq()}))",
RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "validate_body"),
RuntimeType.ProtocolTestHelper(codegenContext.runtimeConfig, "MediaType")
Expand Down Expand Up @@ -386,19 +409,19 @@ class ServerProtocolTestGenerator(
basicCheck(
requireHeaders,
rustWriter,
"required_headers",
actualExpression,
"require_headers"
"required_headers",
actualExpression,
"require_headers"
)
}

private fun checkForbidHeaders(rustWriter: RustWriter, actualExpression: String, forbidHeaders: List<String>) {
basicCheck(
forbidHeaders,
rustWriter,
"forbidden_headers",
"forbidden_headers",
actualExpression,
"forbid_headers"
"forbid_headers"
)
}

Expand Down Expand Up @@ -511,16 +534,7 @@ class ServerProtocolTestGenerator(
FailingTest(RestJson, "RestJsonNoInputAndNoOutput", Action.Response),
FailingTest(RestJson, "RestJsonNoInputAndOutputWithJson", Action.Response),
FailingTest(RestJson, "RestJsonSupportsNaNFloatInputs", Action.Request),
FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Request),
FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Request),
FailingTest(RestJson, "RestJsonStreamingTraitsWithBlob", Action.Response),
FailingTest(RestJson, "RestJsonStreamingTraitsWithNoBlobBody", Action.Response),
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Request),
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Request),
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithBlob", Action.Response),
FailingTest(RestJson, "RestJsonStreamingTraitsRequireLengthWithNoBlobBody", Action.Response),
FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Request),
FailingTest(RestJson, "RestJsonStreamingTraitsWithMediaTypeWithBlob", Action.Response),
FailingTest(RestJson, "RestJsonHttpWithEmptyBlobPayload", Action.Request),
FailingTest(RestJson, "RestJsonHttpWithEmptyStructurePayload", Action.Request),

Expand Down Expand Up @@ -591,56 +605,64 @@ class ServerProtocolTestGenerator(
).asObjectNode().get()
).build()
private fun fixRestJsonAllQueryStringTypes(testCase: HttpRequestTestCase): HttpRequestTestCase =
testCase.toBuilder().params(
Node.parse("""{
"queryString": "Hello there",
"queryStringList": ["a", "b", "c"],
"queryStringSet": ["a", "b", "c"],
"queryByte": 1,
"queryShort": 2,
"queryInteger": 3,
"queryIntegerList": [1, 2, 3],
"queryIntegerSet": [1, 2, 3],
"queryLong": 4,
"queryFloat": 1.1,
"queryDouble": 1.1,
"queryDoubleList": [1.1, 2.1, 3.1],
"queryBoolean": true,
"queryBooleanList": [true, false, true],
"queryTimestamp": 1,
"queryTimestampList": [1, 2, 3],
"queryEnum": "Foo",
"queryEnumList": ["Foo", "Baz", "Bar"],
"queryParamsMapOfStringList": {
"String": ["Hello there"],
"StringList": ["a", "b", "c"],
"StringSet": ["a", "b", "c"],
"Byte": ["1"],
"Short": ["2"],
"Integer": ["3"],
"IntegerList": ["1", "2", "3"],
"IntegerSet": ["1", "2", "3"],
"Long": ["4"],
"Float": ["1.1"],
"Double": ["1.1"],
"DoubleList": ["1.1", "2.1", "3.1"],
"Boolean": ["true"],
"BooleanList": ["true", "false", "true"],
"Timestamp": ["1970-01-01T00:00:01Z"],
"TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"],
"Enum": ["Foo"],
"EnumList": ["Foo", "Baz", "Bar"]
testCase.toBuilder().params(
Node.parse(
"""
{
"queryString": "Hello there",
"queryStringList": ["a", "b", "c"],
"queryStringSet": ["a", "b", "c"],
"queryByte": 1,
"queryShort": 2,
"queryInteger": 3,
"queryIntegerList": [1, 2, 3],
"queryIntegerSet": [1, 2, 3],
"queryLong": 4,
"queryFloat": 1.1,
"queryDouble": 1.1,
"queryDoubleList": [1.1, 2.1, 3.1],
"queryBoolean": true,
"queryBooleanList": [true, false, true],
"queryTimestamp": 1,
"queryTimestampList": [1, 2, 3],
"queryEnum": "Foo",
"queryEnumList": ["Foo", "Baz", "Bar"],
"queryParamsMapOfStringList": {
"String": ["Hello there"],
"StringList": ["a", "b", "c"],
"StringSet": ["a", "b", "c"],
"Byte": ["1"],
"Short": ["2"],
"Integer": ["3"],
"IntegerList": ["1", "2", "3"],
"IntegerSet": ["1", "2", "3"],
"Long": ["4"],
"Float": ["1.1"],
"Double": ["1.1"],
"DoubleList": ["1.1", "2.1", "3.1"],
"Boolean": ["true"],
"BooleanList": ["true", "false", "true"],
"Timestamp": ["1970-01-01T00:00:01Z"],
"TimestampList": ["1970-01-01T00:00:01Z", "1970-01-01T00:00:02Z", "1970-01-01T00:00:03Z"],
"Enum": ["Foo"],
"EnumList": ["Foo", "Baz", "Bar"]
}
}
}""".trimMargin()).asObjectNode().get()
).build()
""".trimMargin()
).asObjectNode().get()
).build()
private fun fixRestJsonQueryStringEscaping(testCase: HttpRequestTestCase): HttpRequestTestCase =
testCase.toBuilder().params(
Node.parse("""{
"queryString": "%:/?#[]@!${'$'}&'()*+,;=😹",
"queryParamsMapOfStringList": {
"String": ["%:/?#[]@!${'$'}&'()*+,;=😹"]
}
}""".trimMargin()).asObjectNode().get()
Node.parse(
"""
{
"queryString": "%:/?#[]@!${'$'}&'()*+,;=😹",
"queryParamsMapOfStringList": {
"String": ["%:/?#[]@!${'$'}&'()*+,;=😹"]
}
}
""".trimMargin()
).asObjectNode().get()
).build()
// This test assumes that errors in responses are identified by an `X-Amzn-Errortype` header with the error shape name.
// However, Smithy specifications for AWS protocols that serialize to JSON recommend that new server implementations
Expand Down
Loading

0 comments on commit f76bc15

Please sign in to comment.