Skip to content

Commit

Permalink
Support server event streams
Browse files Browse the repository at this point in the history
* Server event streams
* Rename EventStreamInput to EventStreamSender
* Make event stream errors optional
* Pokemon service model updated
* Pokemon server event handler
* Pokemon client to test event streams
* EventStreamDecorator to make optional using SigV4 signing

Closes: #1157

Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Jul 14, 2022
1 parent c6193bd commit 7ec3c9c
Show file tree
Hide file tree
Showing 37 changed files with 1,253 additions and 306 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ references = ["smithy-rs#1263"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "Velfi"

[[smithy-rs]]
message = "Rename EventStreamInput to EventStreamSender"
references = ["smithy-rs#1157"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "82marbag"

[[aws-sdk-rust]]
message = "Rename EventStreamInput to EventStreamSender"
references = ["smithy-rs#1157"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "82marbag"

[[aws-sdk-rust]]
message = "Re-export aws_types::SdkConfig in aws_config"
references = ["smithy-rs#1457"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType
import software.amazon.smithy.rust.codegen.rustlang.rust
Expand All @@ -27,7 +26,7 @@ import software.amazon.smithy.rust.codegen.smithy.customize.OperationCustomizati
import software.amazon.smithy.rust.codegen.smithy.customize.OperationSection
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.smithy.generators.config.EventStreamSigningConfig
import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
Expand Down Expand Up @@ -82,51 +81,37 @@ class SigV4SigningConfig(
runtimeConfig: RuntimeConfig,
private val serviceHasEventStream: Boolean,
private val sigV4Trait: SigV4Trait
) : ConfigCustomization() {
) : EventStreamSigningConfig(runtimeConfig) {
private val codegenScope = arrayOf(
"SigV4Signer" to RuntimeType(
"SigV4Signer",
runtimeConfig.awsRuntimeDependency("aws-sig-auth", setOf("sign-eventstream")),
"aws_sig_auth::event_stream"
),
"SharedPropertyBag" to RuntimeType(
"SharedPropertyBag",
CargoDependency.SmithyHttp(runtimeConfig),
"aws_smithy_http::property_bag"
)
)

override fun section(section: ServiceConfig): Writable {
return when (section) {
is ServiceConfig.ConfigImpl -> writable {
override fun configImplSection() = renderEventStreamSignerFn { propertiesName ->
writable {
rustTemplate(
"""
/// The signature version 4 service signing name to use in the credential scope when signing requests.
///
/// The signing service may be overridden by the `Endpoint`, or by specifying a custom
/// [`SigningService`](aws_types::SigningService) during operation construction
pub fn signing_service(&self) -> &'static str {
${sigV4Trait.name.dq()}
}
""",
*codegenScope
)
if (serviceHasEventStream) {
rustTemplate(
"""
/// The signature version 4 service signing name to use in the credential scope when signing requests.
///
/// The signing service may be overridden by the `Endpoint`, or by specifying a custom
/// [`SigningService`](aws_types::SigningService) during operation construction
pub fn signing_service(&self) -> &'static str {
${sigV4Trait.name.dq()}
}
#{SigV4Signer}::new($propertiesName)
""",
*codegenScope
)
if (serviceHasEventStream) {
rustTemplate(
"""
/// Creates a new Event Stream `SignMessage` implementor.
pub fn new_event_stream_signer(
&self,
properties: #{SharedPropertyBag}
) -> #{SigV4Signer} {
#{SigV4Signer}::new(properties)
}
""",
*codegenScope
)
}
}
else -> emptySection
}
}
}
Expand Down
78 changes: 77 additions & 1 deletion codegen-server-test/model/pokemon.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use aws.protocols#restJson1
service PokemonService {
version: "2021-12-01",
resources: [PokemonSpecies],
operations: [GetServerStatistics, EmptyOperation],
operations: [GetServerStatistics, EmptyOperation, CapturePokemonOperation],
}

/// A Pokémon species forms the basis for at least one Pokémon.
Expand All @@ -22,6 +22,82 @@ resource PokemonSpecies {
read: GetPokemonSpecies,
}

/// Capture Pokémons via event streams
@http(uri: "/capture-pokemon-event/{region}", method: "POST")
operation CapturePokemonOperation {
input: CapturePokemonOperationEventsInput,
output: CapturePokemonOperationEventsOutput,
errors: [UnsupportedRegionError, ThrottlingError]
}

@input
structure CapturePokemonOperationEventsInput {
@httpPayload
events: AttemptCapturingPokemonEvent,

@httpLabel
@required
region: String,
}

@output
structure CapturePokemonOperationEventsOutput {
@httpPayload
events: CapturePokemonEvents,
}

@streaming
union AttemptCapturingPokemonEvent {
event: CapturingEvent,
masterball_unsuccessful: MasterBallUnsuccessful,
}

structure CapturingEvent {
@eventPayload
payload: CapturingPayload,
}

structure CapturingPayload {
name: String,
pokeball: String,
}

@streaming
union CapturePokemonEvents {
event: CaptureEvent,
invalid_pokeball: InvalidPokeballError,
throttlingError: ThrottlingError,
}

structure CaptureEvent {
@eventHeader
name: String,
@eventHeader
captured: Boolean,
@eventHeader
shiny: Boolean,
@eventPayload
pokedex_update: Blob,
}

@error("server")
structure UnsupportedRegionError {
@required
region: String,
}
@error("client")
structure InvalidPokeballError {
@required
pokeball: String,
}
@error("server")
structure MasterBallUnsuccessful {
@required
message: String,
}
@error("client")
structure ThrottlingError {}

/// Retrieve information about a Pokémon species.
@readonly
@http(uri: "/pokemon-species/{name}", method: "GET")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class PythonCodegenServerPlugin : SmithyBuildPlugin {
override fun execute(context: PluginContext) {
// Suppress extremely noisy logs about reserved words
Logger.getLogger(ReservedWordSymbolProvider::class.java.name).level = Level.OFF
// Discover [RustCodegenDecorators] on the classpath. [RustCodegenDectorator] return different types of
// Discover [RustCodegenDecorators] on the classpath. [RustCodegenDecorator] return different types of
// customization. A customization is a function of:
// - location (e.g. the mutate section of an operation)
// - context (e.g. the of the operation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import software.amazon.smithy.rust.codegen.smithy.StreamingShapeSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.generators.CodegenTarget
import java.util.logging.Level
import java.util.logging.Logger

Expand Down Expand Up @@ -64,7 +65,7 @@ class RustCodegenServerPlugin : SmithyBuildPlugin {
SymbolVisitor(model, serviceShape = serviceShape, config = symbolVisitorConfig)
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let {
EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model)
EventStreamSymbolProvider(symbolVisitorConfig.runtimeConfig, it, model, CodegenTarget.SERVER)
}
// Generate [ByteStream] instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
.let { StreamingShapeSymbolProvider(it, model) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import software.amazon.smithy.model.traits.RequiredTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.transformers.allErrors

/**
* Add at least one error to all operations in the model.
Expand All @@ -35,7 +36,7 @@ class AddInternalServerErrorToInfallibleOperationsDecorator : RustCodegenDecorat
override val order: Byte = 0

override fun transformModel(service: ServiceShape, model: Model): Model =
addErrorShapeToModelOperations(service, model) { shape -> shape.errors.isEmpty() }
addErrorShapeToModelOperations(service, model) { shape -> shape.allErrors(model).isEmpty() }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.OperationIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.rustlang.Attribute
import software.amazon.smithy.rust.codegen.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
Expand All @@ -19,6 +20,10 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.generators.error.eventStreamErrorSymbol
import software.amazon.smithy.rust.codegen.smithy.transformers.eventStreamErrors
import software.amazon.smithy.rust.codegen.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.util.isEventStream
import software.amazon.smithy.rust.codegen.util.toSnakeCase

/**
Expand All @@ -30,12 +35,35 @@ open class ServerCombinedErrorGenerator(
private val symbolProvider: RustSymbolProvider,
private val operation: OperationShape
) {
private val operationIndex = OperationIndex.of(model)

open fun render(writer: RustWriter) {
val errors = operationIndex.getErrors(operation)
val operationSymbol = symbolProvider.toSymbol(operation)
fun render(writer: RustWriter) {
val errors = operation.operationErrors(model)
val symbol = operation.errorSymbol(symbolProvider)
val operationSymbol = symbolProvider.toSymbol(operation)
if (errors.isNotEmpty()) {
renderErrors(writer, errors.map { it.asStructureShape().get() }, symbol, operationSymbol)
}

if (operation.isEventStream(model)) {
operation.eventStreamErrors(model)
.forEach { (unionShape, unionErrors) ->
if (unionErrors.isNotEmpty()) {
renderErrors(
writer,
unionErrors,
unionShape.eventStreamErrorSymbol(symbolProvider),
symbolProvider.toSymbol(unionShape)
)
}
}
}
}

private fun renderErrors(
writer: RustWriter,
errors: List<StructureShape>,
errorSymbol: RuntimeType,
operationSymbol: Symbol
) {
val meta = RustMetadata(
derives = Attribute.Derives(setOf(RuntimeType.Debug)),
visibility = Visibility.PUBLIC
Expand All @@ -44,52 +72,52 @@ open class ServerCombinedErrorGenerator(
writer.rust("/// Error type for the `${operationSymbol.name}` operation.")
writer.rust("/// Each variant represents an error that can occur for the `${operationSymbol.name}` operation.")
meta.render(writer)
writer.rustBlock("enum ${symbol.name}") {
writer.rustBlock("enum ${errorSymbol.name}") {
errors.forEach { errorVariant ->
documentShape(errorVariant, model)
val errorVariantSymbol = symbolProvider.toSymbol(errorVariant)
write("${errorVariantSymbol.name}(#T),", errorVariantSymbol)
}
}

writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.Display) {
writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.Display) {
rustBlock("fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result") {
delegateToVariants {
delegateToVariants(errors, errorSymbol) {
rust("_inner.fmt(f)")
}
}
}

writer.rustBlock("impl ${symbol.name}") {
writer.rustBlock("impl ${errorSymbol.name}") {
errors.forEach { error ->
val errorSymbol = symbolProvider.toSymbol(error)
val fnName = errorSymbol.name.toSnakeCase()
writer.rust("/// Returns `true` if the error kind is `${symbol.name}::${errorSymbol.name}`.")
val errorVariantSymbol = symbolProvider.toSymbol(error)
val fnName = errorVariantSymbol.name.toSnakeCase()
writer.rust("/// Returns `true` if the error kind is `${errorSymbol.name}::${errorVariantSymbol.name}`.")
writer.rustBlock("pub fn is_$fnName(&self) -> bool") {
rust("matches!(&self, ${symbol.name}::${errorSymbol.name}(_))")
rust("matches!(&self, ${errorSymbol.name}::${errorVariantSymbol.name}(_))")
}
}
writer.rust("/// Returns the error name string by matching the correct variant.")
writer.rustBlock("pub fn name(&self) -> &'static str") {
delegateToVariants {
delegateToVariants(errors, errorSymbol) {
rust("_inner.name()")
}
}
}

writer.rustBlock("impl #T for ${symbol.name}", RuntimeType.StdError) {
writer.rustBlock("impl #T for ${errorSymbol.name}", RuntimeType.StdError) {
rustBlock("fn source(&self) -> Option<&(dyn #T + 'static)>", RuntimeType.StdError) {
delegateToVariants {
delegateToVariants(errors, errorSymbol) {
rust("Some(_inner)")
}
}
}

for (error in errors) {
val errorSymbol = symbolProvider.toSymbol(error)
writer.rustBlock("impl #T<#T> for #T", RuntimeType.From, errorSymbol, symbol) {
rustBlock("fn from(variant: #T) -> #T", errorSymbol, symbol) {
rust("Self::${errorSymbol.name}(variant)")
val errorVariantSymbol = symbolProvider.toSymbol(error)
writer.rustBlock("impl #T<#T> for #T", RuntimeType.From, errorVariantSymbol, errorSymbol) {
rustBlock("fn from(variant: #T) -> #T", errorVariantSymbol, errorSymbol) {
rust("Self::${errorVariantSymbol.name}(variant)")
}
}
}
Expand All @@ -112,10 +140,10 @@ open class ServerCombinedErrorGenerator(
* The field will always be bound as `_inner`.
*/
private fun RustWriter.delegateToVariants(
writable: Writable
errors: List<StructureShape>,
symbol: RuntimeType,
writable: Writable,
) {
val errors = operationIndex.getErrors(operation)
val symbol = operation.errorSymbol(symbolProvider)
rustBlock("match &self") {
errors.forEach {
val errorSymbol = symbolProvider.toSymbol(it)
Expand Down
Loading

0 comments on commit 7ec3c9c

Please sign in to comment.