Skip to content

Commit

Permalink
Introduce ServerBuilderGenerator
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Oct 20, 2022
1 parent 6e45344 commit 5aa8b94
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,9 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.NullNode
import software.amazon.smithy.model.node.NumberNode
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
Expand All @@ -49,7 +35,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase

Expand Down Expand Up @@ -77,7 +62,7 @@ class OperationBuildError(private val runtimeConfig: RuntimeConfig) {
// Setter names will never hit a reserved word and therefore never need escaping.
fun MemberShape.setterName() = "set_${this.memberName.toSnakeCase()}"

class BuilderGenerator(
open class BuilderGenerator(
private val model: Model,
private val symbolProvider: RustSymbolProvider,
private val shape: StructureShape,
Expand Down Expand Up @@ -275,42 +260,7 @@ class BuilderGenerator(
withBlock("$memberName: self.$memberName", ",") {
// Write the modifier
when {
!memberSymbol.isOptional() && member.hasTrait<DefaultTrait>() -> {
val node = member.expectTrait<DefaultTrait>().toNode()!!
when (val target = model.expectShape(member.target)) {
is EnumShape, is IntEnumShape -> {
val value = when (target) {
is IntEnumShape -> node.expectNumberNode().value
else -> node.expectStringNode().value
}
val enumValues = when (target) {
is IntEnumShape -> target.enumValues
is EnumShape -> target.enumValues
else -> mapOf<String, String>()
}
val variant = enumValues
.entries
.filter { entry -> entry.value == value }
.map { entry -> symbolProvider.toEnumVariantName(EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build())!! }
.first()
val symbol = symbolProvider.toSymbol(target)
rust(".unwrap_or($symbol::${variant.name})")
}
is IntegerShape, is FloatShape -> rust(".unwrap_or(${node.expectNumberNode().value})")
is BooleanShape -> rust(".unwrap_or(${node.expectBooleanNode().value})")
is StringShape -> rust(".unwrap_or(String::from(${node.expectStringNode().value.dq()}))")
is ListShape, is MapShape -> rust(".unwrap_or_default()")
is DocumentShape -> {
when (node) {
is NullNode -> rust(".unwrap_or_default()")
is BooleanNode -> rust(".unwrap_or(${node.value})")
is StringNode -> rust(".unwrap_or(String::from(${node.value.dq()}))")
is NumberNode -> rust(".unwrap_or(${node.value})")
else -> rust(".unwrap_or_default()")
}
}
}
}
!memberSymbol.isOptional() && member.hasTrait<DefaultTrait>() -> renderDefaultBuilder(writer, member)
!memberSymbol.isOptional() && default == Default.RustDefault -> rust(".unwrap_or_default()")
!memberSymbol.isOptional() -> withBlock(
".ok_or(",
Expand All @@ -321,4 +271,8 @@ class BuilderGenerator(
}
}
}

open fun renderDefaultBuilder(writer: RustWriter, member: MemberShape) {
writer.rust(".unwrap_or_default()")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ fun StructureShape.renderWithModelBuilder(
model: Model,
symbolProvider: RustSymbolProvider,
writer: RustWriter,
forWhom: CodegenTarget = CodegenTarget.CLIENT,
) {
StructureGenerator(model, symbolProvider, writer, this).render(forWhom)
StructureGenerator(model, symbolProvider, writer, this).render(CodegenTarget.CLIENT)
val modelBuilder = BuilderGenerator(model, symbolProvider, this)
modelBuilder.render(writer)
writer.implBlock(this, symbolProvider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.TestRuntimeConfig
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder
import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.dq
Expand All @@ -31,7 +30,7 @@ import software.amazon.smithy.rust.codegen.core.util.lookup
class InstantiatorTest {
private val model = """
namespace com.test
@documentation("this documents the shape")
structure MyStruct {
foo: String,
Expand Down Expand Up @@ -122,8 +121,8 @@ class InstantiatorTest {
}
rust(
"""
assert_eq!(result.bar, 10);
assert_eq!(result.foo.unwrap(), "hello");
assert_eq!(result.bar, 10);
assert_eq!(result.foo.unwrap(), "hello");
""",
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDe
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerEnumGenerator
Expand All @@ -27,6 +26,7 @@ import software.amazon.smithy.rust.codegen.server.python.smithy.generators.Pytho
import software.amazon.smithy.rust.codegen.server.smithy.DefaultServerPublicModules
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenVisitor
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
Expand Down Expand Up @@ -90,7 +90,7 @@ class PythonServerCodegenVisitor(
// and #[pymethods] implementation.
PythonServerStructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER)
val builderGenerator =
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
ServerBuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
builderGenerator.render(this)
implBlock(shape, symbolProvider) {
builderGenerator.renderConvenienceMethod(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.CoreRustSettings
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
Expand All @@ -35,6 +34,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveSha
import software.amazon.smithy.rust.codegen.core.util.CommandFailed
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.runCommand
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerServiceGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
Expand Down Expand Up @@ -182,7 +182,7 @@ open class ServerCodegenVisitor(
rustCrate.useShapeWriter(shape) {
StructureGenerator(model, symbolProvider, this, shape).render(CodegenTarget.SERVER)
val builderGenerator =
BuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
ServerBuilderGenerator(codegenContext.model, codegenContext.symbolProvider, shape)
builderGenerator.render(this)
this.implBlock(shape, symbolProvider) {
builderGenerator.renderConvenienceMethod(this)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.NullNode
import software.amazon.smithy.model.node.NumberNode
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait

class ServerBuilderGenerator(
private val model: Model,
private val symbolProvider: RustSymbolProvider,
shape: StructureShape,
) : BuilderGenerator(model, symbolProvider, shape) {
override fun renderDefaultBuilder(writer: RustWriter, member: MemberShape) {
val node = member.expectTrait<DefaultTrait>().toNode()!!
when (val target = model.expectShape(member.target)) {
is EnumShape, is IntEnumShape -> {
val value = when (target) {
is IntEnumShape -> node.expectNumberNode().value
else -> node.expectStringNode().value
}
val enumValues = when (target) {
is IntEnumShape -> target.enumValues
is EnumShape -> target.enumValues
else -> mapOf<String, String>()
}
val variant = enumValues
.entries
.filter { entry -> entry.value == value }
.map { entry -> symbolProvider.toEnumVariantName(EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build())!! }
.first()
val symbol = symbolProvider.toSymbol(target)
writer.rust(".unwrap_or($symbol::${variant.name})")
}
is IntegerShape, is FloatShape -> writer.rust(".unwrap_or(${node.expectNumberNode().value})")
is BooleanShape -> writer.rust(".unwrap_or(${node.expectBooleanNode().value})")
is StringShape -> writer.rust(".unwrap_or_else(|| String::from(${node.expectStringNode().value.dq()}))")
is ListShape, is MapShape -> writer.rust(".unwrap_or_default()")
is DocumentShape -> {
when (node) {
is NullNode -> writer.rust(".unwrap_or_default()")
is BooleanNode -> writer.rust(".unwrap_or(${node.value})")
is StringNode -> writer.rust(".unwrap_or_else(|| String::from(${node.value.dq()}))")
is NumberNode -> writer.rust(".unwrap_or(${node.value})")
else -> writer.rust(".unwrap_or_default()")
}
}
else -> writer.rust(".unwrap_or_default()")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.server.testutil

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerBuilderGenerator

/**
* In tests, we frequently need to generate a struct, a builder, and an impl block to access said builder.
*/
fun StructureShape.renderWithModelBuilder(
model: Model,
symbolProvider: RustSymbolProvider,
writer: RustWriter,
) {
StructureGenerator(model, symbolProvider, writer, this).render(CodegenTarget.SERVER)
val modelBuilder = ServerBuilderGenerator(model, symbolProvider, this)
modelBuilder.render(writer)
writer.implBlock(this, symbolProvider) {
modelBuilder.renderConvenienceMethod(this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ServerCombinedErrorGenerator
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest
import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import software.amazon.smithy.rust.codegen.server.testutil.renderWithModelBuilder

class ServerCombinedErrorGeneratorTest {
private val baseModel = """
Expand Down Expand Up @@ -55,7 +54,7 @@ class ServerCombinedErrorGeneratorTest {
val project = TestWorkspace.testProject(symbolProvider)
project.withModule(RustModule.public("error")) {
listOf("FooException", "ComplexError", "InvalidGreeting", "Deprecated").forEach {
model.lookup<StructureShape>("error#$it").renderWithModelBuilder(model, symbolProvider, this, CodegenTarget.SERVER)
model.lookup<StructureShape>("error#$it").renderWithModelBuilder(model, symbolProvider, this)
}
val errors = listOf("FooException", "ComplexError", "InvalidGreeting").map { model.lookup<StructureShape>("error#$it") }
val generator = ServerCombinedErrorGenerator(model, symbolProvider, symbolProvider.toSymbol(model.lookup("error#Greeting")), errors)
Expand Down
Loading

0 comments on commit 5aa8b94

Please sign in to comment.