Skip to content

Commit

Permalink
Server streaming body
Browse files Browse the repository at this point in the history
Signed-off-by: Guy Margalit <[email protected]>
  • Loading branch information
guymguym committed Jan 27, 2022
1 parent 8676219 commit a4ff63f
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerHttpPro
import software.amazon.smithy.rust.codegen.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape

/**
Expand Down Expand Up @@ -132,13 +134,18 @@ class ServerOperationHandlerGenerator(
} else {
symbolProvider.toSymbol(operation.outputShape(model)).fullName
}
val streamingBodyTraitBounds = if (operation.inputShape(model).hasStreamingMember(model)) {
"\n B: Into<#{SmithyHttpServer}::ByteStream>,"
} else {
""
}
return """
$inputFn
Fut: std::future::Future<Output = $outputType> + Send,
B: $serverCrate::HttpBody + Send + 'static,
B: $serverCrate::HttpBody + Send + 'static, $streamingBodyTraitBounds
B::Data: Send,
B::Error: Into<$serverCrate::BoxError>,
$serverCrate::rejection::SmithyRejection: From<<B as $serverCrate::HttpBody>::Error>
"""
""".trimIndent()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.node.ExpectationNotMetException
import software.amazon.smithy.model.shapes.CollectionShape
Expand Down Expand Up @@ -55,6 +54,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredData
import software.amazon.smithy.rust.codegen.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.findStreamingMember
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.inputShape
Expand Down Expand Up @@ -132,13 +132,12 @@ private class ServerHttpProtocolImplGenerator(
}

/*
* Generation of `FromRequest` and `IntoResponse`. They are currently only implemented for non-streaming request
* and response bodies, that is, models without streaming traits
* (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html).
* For non-streaming request bodies, we require the HTTP body to be fully read in memory before parsing or
* deserialization. From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
* Generation of `FromRequest` and `IntoResponse`.
* For non-streaming request bodies, that is, models without streaming traits
* (https://awslabs.github.io/smithy/1.0/spec/core/stream-traits.html)
* we require the HTTP body to be fully read in memory before parsing or deserialization.
* From a server perspective we need a way to parse an HTTP request from `Bytes` and serialize
* an HTTP response to `Bytes`.
* TODO Add support for streaming.
* These traits are the public entrypoint of the ser/de logic of the `aws-smithy-http-server` server.
*/
private fun RustWriter.renderTraits(
Expand All @@ -147,38 +146,24 @@ private class ServerHttpProtocolImplGenerator(
operationShape: OperationShape
) {
val operationName = symbolProvider.toSymbol(operationShape).name
// Implement Axum `FromRequest` trait for input types.
val inputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

val fromRequest = if (operationShape.inputShape(model).hasStreamingMember(model)) {
// For streaming request bodies, we need to generate a different implementation of the `FromRequest` trait.
// It will first offer the streaming input to the parser and potentially read the body into memory
// if an error occurred or if the streaming parser indicates that it needs the full data to proceed.
"""
async fn from_request(_req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
todo!("Streaming support for input shapes is not yet supported in `smithy-rs`")
}
""".trimIndent()
} else {
"""
async fn from_request(req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
Ok($inputName(#{parse_request}(req).await?))
}
""".trimIndent()
}
// Implement Axum `FromRequest` trait for input types.
rustTemplate(
"""
pub struct $inputName(pub #{I});
##[#{AsyncTrait}::async_trait]
impl<B> #{AxumCore}::extract::FromRequest<B> for $inputName
where
B: #{SmithyHttpServer}::HttpBody + Send,
B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
B::Data: Send,
B::Error: Into<#{SmithyHttpServer}::BoxError>,
#{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
{
type Rejection = #{SmithyRejection};
$fromRequest
async fn from_request(req: &mut #{AxumCore}::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
Ok($inputName(#{parse_request}(req).await?))
}
}
""".trimIndent(),
*codegenScope,
Expand All @@ -187,21 +172,19 @@ private class ServerHttpProtocolImplGenerator(
)

// Implement Axum `IntoResponse` for output types.

val outputName = "${operationName}${ServerHttpProtocolGenerator.OPERATION_OUTPUT_WRAPPER_SUFFIX}"
val errorSymbol = operationShape.errorSymbol(symbolProvider)

val httpExtensions = setHttpExtensions(operationShape)
// For streaming response bodies, we need to generate a different implementation of the `IntoResponse` trait.
// The body type will have to be a `StreamBody`. The service implementer will return a `Stream` from their handler.
val intoResponseStreaming = "todo!(\"Streaming support for output shapes is not yet supported in `smithy-rs`\")"

if (operationShape.errors.isNotEmpty()) {
val intoResponseImpl = if (operationShape.outputShape(model).hasStreamingMember(model)) {
intoResponseStreaming
} else {
// The output of fallible operations is a `Result` which we convert into an
// isomorphic `enum` type we control that can in turn be converted into a response.
val intoResponseImpl =
"""
let mut response = match self {
Self::Output(o) => {
match #{serialize_response}(&o) {
match #{serialize_response}(o) {
Ok(response) => response,
Err(e) => {
e.into_response()
Expand All @@ -223,9 +206,7 @@ private class ServerHttpProtocolImplGenerator(
$httpExtensions
response
""".trimIndent()
}
// The output of fallible operations is a `Result` which we convert into an isomorphic `enum` type we control
// that can in turn be converted into a response.

rustTemplate(
"""
pub enum $outputName {
Expand All @@ -246,25 +227,18 @@ private class ServerHttpProtocolImplGenerator(
"serialize_error" to serverSerializeError(operationShape)
)
} else {
val handleSerializeOutput = if (operationShape.outputShape(model).hasStreamingMember(model)) {
intoResponseStreaming
} else {
"""
match #{serialize_response}(&self.0) {
Ok(response) => response,
Err(e) => e.into_response()
}
""".trimIndent()
}
// The output of non-fallible operations is a model type which we convert into a "wrapper" unit `struct` type
// we control that can in turn be converted into a response.
// The output of non-fallible operations is a model type which we convert into
// a "wrapper" unit `struct` type we control that can in turn be converted into a response.
rustTemplate(
"""
pub struct $outputName(pub #{O});
##[#{AsyncTrait}::async_trait]
impl #{AxumCore}::response::IntoResponse for $outputName {
fn into_response(self) -> #{AxumCore}::response::Response {
$handleSerializeOutput
match #{serialize_response}(self.0) {
Ok(response) => response,
Err(e) => e.into_response()
}
}
}
""".trimIndent(),
Expand Down Expand Up @@ -333,6 +307,7 @@ private class ServerHttpProtocolImplGenerator(
val inputSymbol = symbolProvider.toSymbol(inputShape)
val includedMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT)
val unusedVars = if (includedMembers.isEmpty()) "##[allow(unused_variables)] " else ""

return RuntimeType.forInlineFun(fnName, operationDeserModule) {
Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)
it.rustBlockTemplate(
Expand All @@ -344,11 +319,11 @@ private class ServerHttpProtocolImplGenerator(
#{SmithyRejection}
>
where
B: #{SmithyHttpServer}::HttpBody + Send,
B: #{SmithyHttpServer}::HttpBody + Send, ${getStreamingBodyTraitBounds(operationShape)}
B::Data: Send,
B::Error: Into<#{SmithyHttpServer}::BoxError>,
#{SmithyRejection}: From<<B as #{SmithyHttpServer}::HttpBody>::Error>
""",
""".trimIndent(),
*codegenScope,
"I" to inputSymbol,
) {
Expand All @@ -369,8 +344,12 @@ private class ServerHttpProtocolImplGenerator(
val outputSymbol = symbolProvider.toSymbol(outputShape)
return RuntimeType.forInlineFun(fnName, operationSerModule) {
Attribute.Custom("allow(clippy::unnecessary_wraps)").render(it)

// Note we only need to take ownership of the output in the case that it contains streaming members.
// However we currently always take ownership here, but worth noting in case in the future we want
// to generate different signatures for streaming vs non-streaming for some reason.
it.rustBlockTemplate(
"pub fn $fnName(output: &#{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>",
"pub fn $fnName(output: #{O}) -> std::result::Result<#{AxumCore}::response::Response, #{SmithyRejection}>",
*codegenScope,
"O" to outputSymbol,
) {
Expand Down Expand Up @@ -457,13 +436,6 @@ private class ServerHttpProtocolImplGenerator(
operationShape: OperationShape,
bindings: List<HttpBindingDescriptor>,
) {
val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer ->
rust(
"let payload = #T(output)?;",
serializer
)
} ?: rust("""let payload = "";""")
// avoid non-usage warnings for response
Attribute.AllowUnusedMut.render(this)
rustTemplate("let mut builder = #{http}::Response::builder();", *codegenScope)
Expand All @@ -475,6 +447,24 @@ private class ServerHttpProtocolImplGenerator(
serializedValue(this)
}
}
val streamingMember = operationShape.outputShape(model).findStreamingMember(model)
if (streamingMember != null) {
val memberName = symbolProvider.toMemberName(streamingMember)
rustTemplate(
"""
let payload = #{SmithyHttpServer}::body::Body::wrap_stream(output.$memberName);
""",
*codegenScope,
)
} else {
val structuredDataSerializer = protocol.structuredDataSerializer(operationShape)
structuredDataSerializer.serverOutputSerializer(operationShape)?.let { serializer ->
rust(
"let payload = #T(&output)?;",
serializer
)
} ?: rust("""let payload = "";""")
}
rustTemplate(
"""
builder.body(#{SmithyHttpServer}::body::to_boxed(payload))?
Expand Down Expand Up @@ -508,7 +498,7 @@ private class ServerHttpProtocolImplGenerator(
}

val bindingGenerator = ServerResponseBindingGenerator(protocol, codegenContext, operationShape)
val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape?: operationShape)
val addHeadersFn = bindingGenerator.generateAddHeadersFn(errorShape ?: operationShape)
if (addHeadersFn != null) {
rust(
"""
Expand All @@ -526,12 +516,11 @@ private class ServerHttpProtocolImplGenerator(
val operationName = symbolProvider.toSymbol(operationShape).name
val member = binding.member
return when (binding.location) {
HttpLocation.HEADER, HttpLocation.PREFIX_HEADERS, HttpLocation.DOCUMENT -> {
// All of these are handled separately.
null
}
HttpLocation.HEADER,
HttpLocation.PREFIX_HEADERS,
HttpLocation.DOCUMENT,
HttpLocation.PAYLOAD -> {
logger.warning("[rust-server-codegen] $operationName: response serialization does not currently support ${binding.location} bindings")
// All of these are handled separately.
null
}
HttpLocation.RESPONSE_CODE -> writable {
Expand Down Expand Up @@ -560,6 +549,7 @@ private class ServerHttpProtocolImplGenerator(
) {
val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
val structuredDataParser = protocol.structuredDataParser(operationShape)
val streamingMember = inputShape.findStreamingMember(model)
Attribute.AllowUnusedMut.render(this)
rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider))
val parser = structuredDataParser.serverInputParser(operationShape)
Expand All @@ -577,13 +567,23 @@ private class ServerHttpProtocolImplGenerator(
*codegenScope,
"parser" to parser,
)
} else if (streamingMember != null) {
rustTemplate(
"""
let body = request.take_body().ok_or(#{SmithyHttpServer}::rejection::BodyAlreadyExtracted)?;
input = input.${streamingMember.setterName()}(Some(body.into()));
""".trimIndent(),
*codegenScope
)
}
for (binding in bindings) {
val member = binding.member
val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
if (parsedValue != null) {
withBlock("input = input.${member.setterName()}(", ");") {
parsedValue(this)
if (streamingMember == null) {
for (binding in bindings) {
val member = binding.member
val parsedValue = serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
if (parsedValue != null) {
withBlock("input = input.${member.setterName()}(", ");") {
parsedValue(this)
}
}
}
}
Expand Down Expand Up @@ -1052,4 +1052,12 @@ private class ServerHttpProtocolImplGenerator(
}
}
}

private fun getStreamingBodyTraitBounds(operationShape: OperationShape): String {
if (operationShape.inputShape(model).hasStreamingMember(model)) {
return "\n B: Into<#{SmithyHttpServer}::ByteStream>,"
} else {
return ""
}
}
}
2 changes: 1 addition & 1 deletion rust-runtime/aws-smithy-http-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ bytes = "1.1"
futures-util = { version = "0.3", default-features = false }
http = "0.2"
http-body = "0.4"
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp"] }
hyper = { version = "0.14", features = ["server", "http1", "http2", "tcp", "stream"] }
mime = "0.3"
nom = "7"
pin-project-lite = "0.2"
Expand Down
3 changes: 3 additions & 0 deletions rust-runtime/aws-smithy-http-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ pub use self::routing::Router;
#[doc(inline)]
pub use tower_http::add_extension::{AddExtension, AddExtensionLayer};

#[doc(inline)]
pub use aws_smithy_http::byte_stream::ByteStream;

/// Alias for a type-erased error type.
pub use axum_core::BoxError;

Expand Down
6 changes: 6 additions & 0 deletions rust-runtime/aws-smithy-http-server/src/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ impl From<aws_smithy_types::date_time::DateTimeParseError> for SmithyRejection {
}
}

impl From<aws_smithy_types::date_time::DateTimeFormatError> for SmithyRejection {
fn from(err: aws_smithy_types::date_time::DateTimeFormatError) -> Self {
SmithyRejection::Serialize(Serialize::from_err(err))
}
}

impl From<aws_smithy_types::primitive::PrimitiveParseError> for SmithyRejection {
fn from(err: aws_smithy_types::primitive::PrimitiveParseError) -> Self {
SmithyRejection::Deserialize(Deserialize::from_err(err))
Expand Down
8 changes: 8 additions & 0 deletions rust-runtime/aws-smithy-http/src/byte_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,14 @@ impl From<Vec<u8>> for ByteStream {
}
}

impl From<hyper::Body> for ByteStream {
fn from(input: hyper::Body) -> Self {
ByteStream::new(SdkBody::from_dyn(
input.map_err(|e| e.into_cause().unwrap()).boxed(),
))
}
}

#[derive(Debug)]
pub struct Error(Box<dyn StdError + Send + Sync + 'static>);

Expand Down

0 comments on commit a4ff63f

Please sign in to comment.