Skip to content

Commit

Permalink
BuilderGenerator with custom default renderer
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Oct 31, 2022
1 parent dbc42e9 commit a1825e1
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.transformers.AddErrorMe
import software.amazon.smithy.rust.codegen.client.smithy.transformers.RemoveEventStreamOperations
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
Expand Down Expand Up @@ -193,7 +195,11 @@ class CodegenVisitor(
rustCrate.useShapeWriter(shape) {
StructureGenerator(model, symbolProvider, this, shape).render()
if (!shape.hasTrait<SyntheticInputTrait>()) {
val builderGenerator = BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
val builderGenerator = BuilderGenerator(
codegenContext.model,
codegenContext.symbolProvider,
shape,
) { writable { rust(".unwrap_or_default()") } }
builderGenerator.render(this)
this.implBlock(shape, symbolProvider) {
builderGenerator.renderConvenienceMethod(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.docLink
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization
Expand Down Expand Up @@ -46,7 +47,11 @@ open class ClientProtocolGenerator(
customizations: List<OperationCustomization>,
) {
val inputShape = operationShape.inputShape(model)
val builderGenerator = BuilderGenerator(model, symbolProvider, operationShape.inputShape(model))
val builderGenerator = BuilderGenerator(
model,
symbolProvider,
operationShape.inputShape(model),
) { writable { rust(".unwrap_or_default()") } }
builderGenerator.render(inputWriter)

// impl OperationInputShape { ... }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,13 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.NullNode
import software.amazon.smithy.model.node.NumberNode
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.asArgument
import software.amazon.smithy.rust.codegen.core.rustlang.asOptional
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
Expand All @@ -37,6 +23,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.documentShape
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.Default
Expand All @@ -49,7 +36,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

fun StructureShape.builderSymbol(symbolProvider: RustSymbolProvider): Symbol {
Expand Down Expand Up @@ -80,6 +66,7 @@ open class BuilderGenerator(
private val model: Model,
private val symbolProvider: RustSymbolProvider,
private val shape: StructureShape,
private val renderDefault: (member: MemberShape) -> Writable,
) {
private val runtimeConfig = symbolProvider.config().runtimeConfig
private val members: List<MemberShape> = shape.allMembers.values.toList()
Expand Down Expand Up @@ -280,7 +267,7 @@ open class BuilderGenerator(
withBlock("$memberName: ${optionPrefix}self.$memberName", "$optionSuffix,") {
// Write the modifier
when {
member.hasNonNullDefault() -> renderDefaultBuilder(writer, member)
member.hasNonNullDefault() -> writer.rustTemplate("#{default:W}", "default" to renderDefault(member))
!memberSymbol.isOptional() && default == Default.RustDefault -> rust(".unwrap_or_default()")
!memberSymbol.isOptional() -> withBlock(
".ok_or(",
Expand All @@ -291,43 +278,4 @@ open class BuilderGenerator(
}
}
}

private fun renderDefaultBuilder(writer: RustWriter, member: MemberShape) {
val node = member.expectTrait<DefaultTrait>().toNode()!!
when (val target = model.expectShape(member.target)) {
is EnumShape, is IntEnumShape -> {
val value = when (target) {
is IntEnumShape -> node.expectNumberNode().value
else -> node.expectStringNode().value
}
val enumValues = when (target) {
is IntEnumShape -> target.enumValues
is EnumShape -> target.enumValues
else -> mapOf<String, String>()
}
val variant = enumValues
.entries
.filter { entry -> entry.value == value }
.map { entry -> symbolProvider.toEnumVariantName(EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build())!! }
.first()
val symbol = symbolProvider.toSymbol(target)
writer.rust(".unwrap_or($symbol::${variant.name})")
}
is IntegerShape -> writer.rust(".unwrap_or(${node.expectNumberNode().value})")
is FloatShape -> writer.rust(".unwrap_or(${node.expectNumberNode().value.toFloat()})")
is BooleanShape -> writer.rust(".unwrap_or(${node.expectBooleanNode().value})")
is StringShape -> writer.rust(".unwrap_or_else(|| String::from(${node.expectStringNode().value.dq()}))")
is ListShape, is MapShape -> writer.rust(".unwrap_or_default()")
is DocumentShape -> {
when (node) {
is NullNode -> writer.rust(".unwrap_or_default()")
is BooleanNode -> writer.rust(".unwrap_or(${node.value})")
is StringNode -> writer.rust(".unwrap_or_else(|| String::from(${node.value.dq()}))")
is NumberNode -> writer.rust(".unwrap_or(${node.value})")
else -> writer.rust(".unwrap_or_default()")
}
}
else -> writer.rust(".unwrap_or_default()")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWordSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.asType
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.BaseSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
Expand Down Expand Up @@ -108,7 +110,7 @@ fun StructureShape.renderWithModelBuilder(
forWhom: CodegenTarget = CodegenTarget.CLIENT,
) {
StructureGenerator(model, symbolProvider, writer, this).render(forWhom)
val modelBuilder = BuilderGenerator(model, symbolProvider, this)
val modelBuilder = BuilderGenerator(model, symbolProvider, this) { writable { rust(".unwrap_or_default()") } }
modelBuilder.render(writer)
writer.implBlock(this, symbolProvider) {
modelBuilder.renderConvenienceMethod(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@ import org.junit.jupiter.api.Test
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.setDefault
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.lookup

internal class BuilderGeneratorTest {
private val model = StructureGeneratorTest.model
Expand All @@ -35,7 +33,7 @@ internal class BuilderGeneratorTest {
writer.rust("##![allow(deprecated)]")
val innerGenerator = StructureGenerator(model, provider, writer, inner)
val generator = StructureGenerator(model, provider, writer, struct)
val builderGenerator = BuilderGenerator(model, provider, struct)
val builderGenerator = BuilderGenerator(model, provider, struct) { writable { rust(".unwrap_or_default()") } }
generator.render()
innerGenerator.render()
builderGenerator.render(writer)
Expand Down Expand Up @@ -84,7 +82,7 @@ internal class BuilderGeneratorTest {
)
generator.render()
innerGenerator.render()
val builderGenerator = BuilderGenerator(model, provider, struct)
val builderGenerator = BuilderGenerator(model, provider, struct) { writable { rust(".unwrap_or_default()") } }
builderGenerator.render(writer)
writer.implBlock(struct, provider) {
builderGenerator.renderConvenienceMethod(this)
Expand All @@ -97,46 +95,4 @@ internal class BuilderGeneratorTest {
""",
)
}

@Test
fun `generate default values`() {
val model =
"""
namespace com.test
structure MyStruct {
@required
foo: String = "foo",
bar: PrimitiveInteger = 42,
baz: Integer = 42,
baw: Float = 42.0,
yes: Boolean = true,
}
@default(42)
integer PrimitiveInteger
""".asSmithyModel(smithyVersion = "2")
val provider = testSymbolProvider(model)
val writer = RustWriter.forModule("model")
writer.rust("##![allow(deprecated)]")
val struct = model.lookup<StructureShape>("com.test#MyStruct")

val generator = StructureGenerator(model, provider, writer, struct)
val builderGenerator = BuilderGenerator(model, provider, struct)

generator.render()
builderGenerator.render(writer)
writer.implBlock(struct, provider) {
builderGenerator.renderConvenienceMethod(this)
}
writer.compileAndTest(
"""
let my_struct = MyStruct::builder().build();
assert_eq!(my_struct.foo.unwrap(), "foo");
assert_eq!(my_struct.bar.unwrap(), 42);
assert_eq!(my_struct.baz.unwrap(), 42);
assert_eq!(my_struct.baw.unwrap(), 42.0);
assert_eq!(my_struct.yes.unwrap(), true);
""",
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class PythonServerCodegenVisitor(
// and #[pymethods] implementation.
PythonServerStructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER)
val builderGenerator =
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, renderDefaultBuilder)
builderGenerator.render(this)
implBlock(shape, symbolProvider) {
builderGenerator.renderConvenienceMethod(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,34 @@ import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.neighbor.Walker
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.NullNode
import software.amazon.smithy.model.node.NumberNode
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeVisitor
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
Expand All @@ -33,6 +51,8 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamN
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.runCommand
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator
Expand Down Expand Up @@ -182,7 +202,7 @@ open class ServerCodegenVisitor(
rustCrate.useShapeWriter(shape) {
StructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER)
val builderGenerator =
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape, ::renderDefaultBuilder)
builderGenerator.render(this)
this.implBlock(shape, symbolProvider) {
builderGenerator.renderConvenienceMethod(this)
Expand Down Expand Up @@ -238,4 +258,49 @@ open class ServerCodegenVisitor(
)
.render()
}

fun renderDefaultBuilder(member: MemberShape): Writable {
return writable {
val node = member.expectTrait<DefaultTrait>().toNode()!!
when (val target = model.expectShape(member.target)) {
is EnumShape, is IntEnumShape -> {
val value = when (target) {
is IntEnumShape -> node.expectNumberNode().value
else -> node.expectStringNode().value
}
val enumValues = when (target) {
is IntEnumShape -> target.enumValues
is EnumShape -> target.enumValues
else -> mapOf<String, String>()
}
val variant = enumValues
.entries
.filter { entry -> entry.value == value }
.map { entry ->
symbolProvider.toEnumVariantName(
EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build(),
)!!
}
.first()
val symbol = symbolProvider.toSymbol(target)
rust(".unwrap_or($symbol::${variant.name})")
}
is IntegerShape -> rust(".unwrap_or(${node.expectNumberNode().value})")
is FloatShape -> rust(".unwrap_or(${node.expectNumberNode().value.toFloat()})")
is BooleanShape -> rust(".unwrap_or(${node.expectBooleanNode().value})")
is StringShape -> rust(".unwrap_or_else(|| String::from(${node.expectStringNode().value.dq()}))")
is ListShape, is MapShape -> rust(".unwrap_or_default()")
is DocumentShape -> {
when (node) {
is NullNode -> rust(".unwrap_or_default()")
is BooleanNode -> rust(".unwrap_or(${node.value})")
is StringNode -> rust(".unwrap_or_else(|| String::from(${node.value.dq()}))")
is NumberNode -> rust(".unwrap_or(${node.value})")
else -> rust(".unwrap_or_default()")
}
}
else -> rust(".unwrap_or_default()")
}
}
}
}
Loading

0 comments on commit a1825e1

Please sign in to comment.