Skip to content

Commit

Permalink
Source defaults from the model instead of implicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Sep 14, 2023
1 parent cf8c834 commit aee6933
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.MemberShape
Expand All @@ -14,18 +13,12 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName

private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable {
rust("#T::from($data)", enumSymbol)
}

class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior {
override fun hasFallibleBuilder(shape: StructureShape): Boolean =
BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider)
Expand All @@ -40,7 +33,6 @@ class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Ins
codegenContext.model,
codegenContext.runtimeConfig,
ClientBuilderKindBehavior(codegenContext),
::enumFromStringFn,
) {
fun renderFluentCall(
writer: RustWriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ fun Writable.map(f: RustWriter.(Writable) -> Unit): Writable {
return writable { f(self) }
}

/** Returns Some(..arg) */
fun Writable.some(): Writable {
return this.map { rust("Some(#T)", it) }
}

fun Writable.isNotEmpty(): Boolean = !this.isEmpty()

operator fun Writable.plus(other: Writable): Writable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
Expand Down Expand Up @@ -102,6 +103,8 @@ sealed class Default {
* This symbol should use the Rust `std::default::Default` when unset
*/
object RustDefault : Default()

data class NonZeroDefault(val value: Node) : Default()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BigDecimalShape
import software.amazon.smithy.model.shapes.BigIntegerShape
import software.amazon.smithy.model.shapes.BlobShape
Expand Down Expand Up @@ -37,6 +38,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
Expand All @@ -48,6 +50,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import kotlin.reflect.KClass
Expand Down Expand Up @@ -79,16 +82,18 @@ data class MaybeRenamed(val name: String, val renamedFrom: String?)
/**
* Make the return [value] optional if the [member] symbol is as well optional.
*/
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = value.letIf(toSymbol(member).isOptional()) {
"Some($value)"
}
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String =
value.letIf(toSymbol(member).isOptional()) {
"Some($value)"
}

/**
* Make the return [value] optional if the [member] symbol is not optional.
*/
fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) {
"Some($value)"
}
fun SymbolProvider.toOptional(member: MemberShape, value: String): String =
value.letIf(!toSymbol(member).isOptional()) {
"Some($value)"
}

/**
* Services can rename their contained shapes. See https://awslabs.github.io/smithy/1.0/spec/core/model.html#service
Expand Down Expand Up @@ -170,7 +175,7 @@ open class SymbolVisitor(
}

private fun simpleShape(shape: SimpleShape): Symbol {
return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build()
return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).build()
}

override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape)
Expand Down Expand Up @@ -263,13 +268,21 @@ open class SymbolVisitor(

override fun memberShape(shape: MemberShape): Symbol {
val target = model.expectShape(shape.target)
val defaultValue = shape.getMemberTrait(model, DefaultTrait::class.java).orNull()?.let { trait ->
when (val value = trait.toNode()) {
Node.from(""), Node.from(0), Node.from(false), Node.arrayNode(), Node.objectNode() -> Default.RustDefault
Node.nullNode() -> Default.NoDefault
else -> { Default.NonZeroDefault(value)
}
}
} ?: Default.NoDefault
// Handle boxing first, so we end up with Option<Box<_>>, not Box<Option<_>>.
return handleOptionality(
handleRustBoxing(toSymbol(target), shape),
shape,
nullableIndex,
config.nullabilityCheckMode,
)
).toBuilder().setDefault(defaultValue).build()
}

override fun timestampShape(shape: TimestampShape?): Symbol {
Expand Down Expand Up @@ -297,7 +310,12 @@ fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder =
// If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation
.definitionFile("thisisabug.rs")

fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol =
fun handleOptionality(
symbol: Symbol,
member: MemberShape,
nullableIndex: NullableIndex,
nullabilityCheckMode: CheckMode,
): Symbol =
symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() }

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
Expand All @@ -41,7 +40,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
Expand Down Expand Up @@ -385,15 +383,22 @@ class BuilderGenerator(
members.forEach { member ->
val memberName = symbolProvider.toMemberName(member)
val memberSymbol = symbolProvider.toSymbol(member)
val default = memberSymbol.defaultValue()
withBlock("$memberName: self.$memberName", ",") {
// Write the modifier
when {
!memberSymbol.isOptional() && default == Default.RustDefault -> rust(".unwrap_or_default()")
!memberSymbol.isOptional() -> withBlock(
".ok_or_else(||",
")?",
) { missingRequiredField(memberName) }
val generator = DefaultValueGenerator(runtimeConfig, symbolProvider, model)
val default = generator.defaultValue(member)
if (!memberSymbol.isOptional()) {
if (default != null) {
if (default.isRustDefault) {
rust(".unwrap_or_default()")
} else {
rust(".unwrap_or_else(#T)", default.expr)
}
} else {
withBlock(
".ok_or_else(||",
")?",
) { missingRequiredField(memberName) }
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.SimpleShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue

class DefaultValueGenerator(
runtimeConfig: RuntimeConfig,
private val symbolProvider: RustSymbolProvider,
private val model: Model,
) {
private val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)

data class DefaultValue(val isRustDefault: Boolean, val expr: Writable)

/** Returns the default value as set by the defaultValue trait */
fun defaultValue(member: MemberShape): DefaultValue? {
val target = model.expectShape(member.target)
return when (val default = symbolProvider.toSymbol(member).defaultValue()) {
is Default.NoDefault -> null
is Default.RustDefault -> DefaultValue(isRustDefault = true, writable("Default::default"))
is Default.NonZeroDefault -> {
val instantiation = instantiator.instantiate(target as SimpleShape, default.value)
DefaultValue(isRustDefault = false, writable { rust("||#T", instantiation) })
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,28 @@ open class EnumGenerator(
},
)

// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)

rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
Expand Down
Loading

0 comments on commit aee6933

Please sign in to comment.