Skip to content

Commit

Permalink
Extract builderInstantiator interface to prepare for nullability changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Sep 15, 2023
1 parent cf8c834 commit 432642d
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientBuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol

/**
Expand All @@ -36,4 +38,7 @@ data class ClientCodegenContext(
model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT,
) {
val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins
override fun builderInstantiator(): BuilderInstantiator {
return ClientBuilderInstantiator(symbolProvider)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.builderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
Expand Down Expand Up @@ -63,9 +64,9 @@ private class ClientAwsJsonFactory(private val version: AwsJsonVersion) :
ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
override fun protocol(codegenContext: ClientCodegenContext): Protocol =
if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) {
AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version))
AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version, codegenContext.builderInstantiator()))
} else {
AwsJson(codegenContext, version)
AwsJson(codegenContext, version, codegenContext.builderInstantiator())
}

override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator =
Expand All @@ -87,10 +88,10 @@ private class ClientAwsQueryFactory : ProtocolGeneratorFactory<OperationGenerato
}

private class ClientRestJsonFactory : ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
override fun protocol(codegenContext: ClientCodegenContext): Protocol = RestJson(codegenContext)
override fun protocol(codegenContext: ClientCodegenContext): Protocol = RestJson(codegenContext, codegenContext.builderInstantiator())

override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator =
OperationGenerator(codegenContext, RestJson(codegenContext))
OperationGenerator(codegenContext, RestJson(codegenContext, codegenContext.builderInstantiator()))

override fun support(): ProtocolSupport = CLIENT_PROTOCOL_SUPPORT
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator

/**
* [CodegenContext] contains code-generation context that is _common to all_ smithy-rs plugins.
Expand All @@ -17,7 +18,7 @@ import software.amazon.smithy.model.shapes.ShapeId
* If your data is specific to the `rust-client-codegen` client plugin, put it in [ClientCodegenContext] instead.
* If your data is specific to the `rust-server-codegen` server plugin, put it in [ServerCodegenContext] instead.
*/
open class CodegenContext(
abstract class CodegenContext(
/**
* The smithy model.
*
Expand Down Expand Up @@ -89,4 +90,6 @@ open class CodegenContext(
fun expectModuleDocProvider(): ModuleDocProvider = checkNotNull(moduleDocProvider) {
"A ModuleDocProvider must be set on the CodegenContext"
}

abstract fun builderInstantiator(): BuilderInstantiator
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
Expand Down Expand Up @@ -122,6 +123,7 @@ class AwsJsonSerializerGenerator(
open class AwsJson(
val codegenContext: CodegenContext,
val awsJsonVersion: AwsJsonVersion,
val builderInstantiator: BuilderInstantiator,
) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig
private val errorScope = arrayOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.JsonSerializerGenerator
Expand Down Expand Up @@ -59,7 +60,7 @@ class RestJsonHttpBindingResolver(
}
}

open class RestJson(val codegenContext: CodegenContext) : Protocol {
open class RestJson(val codegenContext: CodegenContext, private val builderInstantiator: BuilderInstantiator) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig
private val errorScope = arrayOf(
"Bytes" to RuntimeType.Bytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class EventStreamUnmarshallerGenerator(
private val unionShape: UnionShape,
) {
private val model = codegenContext.model
private val builderInstantiator = codegenContext.builderInstantiator()
private val symbolProvider = codegenContext.symbolProvider
private val codegenTarget = codegenContext.target
private val runtimeConfig = codegenContext.runtimeConfig
Expand Down Expand Up @@ -339,6 +340,7 @@ class EventStreamUnmarshallerGenerator(
// TODO(EventStream): Errors on the operation can be disjoint with errors in the union,
// so we need to generate a new top-level Error type for each event stream union.
when (codegenTarget) {
// TODO(https://github.com/awslabs/smithy-rs/issues/1970) It should be possible to unify these branches now
CodegenTarget.CLIENT -> {
val target = model.expectShape(member.target, StructureShape::class.java)
val parser = protocol.structuredDataParser().errorParser(target)
Expand All @@ -352,9 +354,19 @@ class EventStreamUnmarshallerGenerator(
})?;
builder.set_meta(Some(generic));
return Ok(#{UnmarshalledMessage}::Error(
#{OpError}::${member.target.name}(builder.build())
#{OpError}::${member.target.name}(
#{build}
)
))
""",
"build" to builderInstantiator.finalizeBuilder(
"builder", target,
mapErr = {
rustTemplate(
"""|err|#{Error}::unmarshalling(format!("{}", err))""", *codegenScope,
)
},
),
"parser" to parser,
*codegenScope,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,16 @@ import software.amazon.smithy.utils.StringUtils
* Class describing a JSON parser section that can be used in a customization.
*/
sealed class JsonParserSection(name: String) : Section(name) {
data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember")
data class BeforeBoxingDeserializedMember(val shape: MemberShape) :
JsonParserSection("BeforeBoxingDeserializedMember")

data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember")
data class AfterTimestampDeserializedMember(val shape: MemberShape) :
JsonParserSection("AfterTimestampDeserializedMember")

data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember")

data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember")
data class AfterDocumentDeserializedMember(val shape: MemberShape) :
JsonParserSection("AfterDocumentDeserializedMember")
}

/**
Expand Down Expand Up @@ -100,6 +103,7 @@ class JsonParserGenerator(
private val codegenTarget = codegenContext.target
private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType()
private val protocolFunctions = ProtocolFunctions(codegenContext)
private val builderInstantiator = codegenContext.builderInstantiator()
private val codegenScope = arrayOf(
"Error" to smithyJson.resolve("deserialize::error::DeserializeError"),
"expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"),
Expand Down Expand Up @@ -251,6 +255,7 @@ class JsonParserGenerator(
deserializeMember(member)
}
}

CodegenTarget.SERVER -> {
if (symbolProvider.toSymbol(member).isOptional()) {
withBlock("builder = builder.${member.setterName()}(", ");") {
Expand Down Expand Up @@ -508,12 +513,14 @@ class JsonParserGenerator(
"Builder" to symbolProvider.symbolForBuilder(shape),
)
deserializeStructInner(shape.members())
// Only call `build()` if the builder is not fallible. Otherwise, return the builder.
if (returnSymbolToParse.isUnconstrained) {
rust("Ok(Some(builder))")
} else {
rust("Ok(Some(builder.build()))")
val builder = builderInstantiator.finalizeBuilder(
"builder", shape,
) {
rustTemplate(
"""|err|#{Error}::custom_source("Response was invalid", err)""", *codegenScope,
)
}
rust("Ok(Some(#T))", builder)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.renderUnknownVariant
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
Expand Down Expand Up @@ -101,6 +100,7 @@ class XmlBindingTraitParserGenerator(
private val runtimeConfig = codegenContext.runtimeConfig
private val protocolFunctions = ProtocolFunctions(codegenContext)
private val codegenTarget = codegenContext.target
private val builderInstantiator = codegenContext.builderInstantiator()

// The symbols we want all the time
private val codegenScope = arrayOf(
Expand Down Expand Up @@ -159,6 +159,7 @@ class XmlBindingTraitParserGenerator(
is StructureShape -> {
parseStructure(shape, ctx)
}

is UnionShape -> parseUnion(shape, ctx)
}
}
Expand Down Expand Up @@ -294,7 +295,10 @@ class XmlBindingTraitParserGenerator(
}
rust("$builder = $builder.${member.setterName()}($temp);")
}
rustTemplate("_ => return Err(#{XmlDecodeError}::custom(\"expected ${member.xmlName()} tag\"))", *codegenScope)
rustTemplate(
"_ => return Err(#{XmlDecodeError}::custom(\"expected ${member.xmlName()} tag\"))",
*codegenScope,
)
}
}

Expand Down Expand Up @@ -359,19 +363,23 @@ class XmlBindingTraitParserGenerator(
parsePrimitiveInner(memberShape) {
rustTemplate("#{try_data}(&mut ${ctx.tag})?.as_ref()", *codegenScope)
}

is MapShape -> if (memberShape.isFlattened()) {
parseFlatMap(target, ctx)
} else {
parseMap(target, ctx)
}

is CollectionShape -> if (memberShape.isFlattened()) {
parseFlatList(target, ctx)
} else {
parseList(target, ctx)
}

is StructureShape -> {
parseStructure(target, ctx)
}

is UnionShape -> parseUnion(target, ctx)
else -> PANIC("Unhandled: $target")
}
Expand Down Expand Up @@ -436,10 +444,16 @@ class XmlBindingTraitParserGenerator(
}
when (target.renderUnknownVariant()) {
true -> rust("_unknown => base = Some(#T::${UnionGenerator.UnknownVariantName}),", symbol)
false -> rustTemplate("""variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""", *codegenScope)
false -> rustTemplate(
"""variant => return Err(#{XmlDecodeError}::custom(format!("unexpected union variant: {:?}", variant)))""",
*codegenScope,
)
}
}
rustTemplate("""base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""", *codegenScope)
rustTemplate(
"""base.ok_or_else(||#{XmlDecodeError}::custom("expected union, got nothing"))""",
*codegenScope,
)
}
}
rust("#T(&mut ${ctx.tag})", nestedParser)
Expand Down Expand Up @@ -474,17 +488,17 @@ class XmlBindingTraitParserGenerator(
} else {
rust("let _ = decoder;")
}
withBlock("Ok(builder.build()", ")") {
if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
// NOTE:(rcoh) This branch is unreachable given the current nullability rules.
// Only synthetic inputs can have fallible builders, but synthetic inputs can never be parsed
// (because they're inputs, only outputs will be parsed!)

// I'm leaving this branch here so that the binding trait parser generator would work for a server
// side implementation in the future.
rustTemplate(""".map_err(|_|#{XmlDecodeError}::custom("missing field"))?""", *codegenScope)
}
}
val builder = builderInstantiator.finalizeBuilder(
"builder",
shape,
mapErr = {
rustTemplate(
""".map_err(|_|#{XmlDecodeError}::custom("missing field"))?""",
*codegenScope,
)
},
)
rust("Ok(#T)", builder)
}
}
rust("#T(&mut ${ctx.tag})", nestedParser)
Expand Down Expand Up @@ -622,14 +636,16 @@ class XmlBindingTraitParserGenerator(
)
}
}

is TimestampShape -> {
val timestampFormat =
index.determineTimestampFormat(
member,
HttpBinding.Location.DOCUMENT,
TimestampFormatTrait.Format.DATE_TIME,
)
val timestampFormatType = RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat)
val timestampFormatType =
RuntimeType.parseTimestampFormat(codegenTarget, runtimeConfig, timestampFormat)
withBlock("#T::from_str(", ")", RuntimeType.dateTime(runtimeConfig)) {
provider()
rust(", #T", timestampFormatType)
Expand All @@ -639,6 +655,7 @@ class XmlBindingTraitParserGenerator(
*codegenScope,
)
}

is BlobShape -> {
withBlock("#T(", ")", RuntimeType.base64Decode(runtimeConfig)) {
provider()
Expand All @@ -648,6 +665,7 @@ class XmlBindingTraitParserGenerator(
*codegenScope,
)
}

else -> PANIC("unexpected shape: $shape")
}
}
Expand All @@ -660,7 +678,10 @@ class XmlBindingTraitParserGenerator(
withBlock("#T::try_from(", ")", enumSymbol) {
provider()
}
rustTemplate(""".map_err(|e| #{XmlDecodeError}::custom(format!("unknown variant {}", e)))?""", *codegenScope)
rustTemplate(
""".map_err(|e| #{XmlDecodeError}::custom(format!("unknown variant {}", e)))?""",
*codegenScope,
)
} else {
withBlock("#T::from(", ")", enumSymbol) {
provider()
Expand All @@ -674,7 +695,8 @@ class XmlBindingTraitParserGenerator(
}
}

private fun convertsToEnumInServer(shape: StringShape) = target == CodegenTarget.SERVER && shape.hasTrait<EnumTrait>()
private fun convertsToEnumInServer(shape: StringShape) =
target == CodegenTarget.SERVER && shape.hasTrait<EnumTrait>()

private fun MemberShape.xmlName(): XmlName {
return XmlName(xmlIndex.memberName(this))
Expand Down
Loading

0 comments on commit 432642d

Please sign in to comment.