Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Source defaults from the model instead of implicitly #2985

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,9 @@ message = "Fix regression with redacting sensitive HTTP response bodies."
references = ["smithy-rs#2926", "smithy-rs#2972"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "ysaito1001"

[[smithy-rs]]
message = "Source defaults from the default trait instead of implicitly based on type. This has minimal changes in the generated code."
references = ["smithy-rs#2985"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "rcoh"
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.MemberShape
Expand All @@ -14,18 +13,12 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.Instantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName

private fun enumFromStringFn(enumSymbol: Symbol, data: String): Writable = writable {
rust("#T::from($data)", enumSymbol)
}

class ClientBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior {
override fun hasFallibleBuilder(shape: StructureShape): Boolean =
BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider)
Expand All @@ -40,7 +33,6 @@ class ClientInstantiator(private val codegenContext: ClientCodegenContext) : Ins
codegenContext.model,
codegenContext.runtimeConfig,
ClientBuilderKindBehavior(codegenContext),
::enumFromStringFn,
) {
fun renderFluentCall(
writer: RustWriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ fun Writable.map(f: RustWriter.(Writable) -> Unit): Writable {
return writable { f(self) }
}

/** Returns Some(..arg) */
fun Writable.some(): Writable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice.

return this.map { rust("Some(#T)", it) }
}

fun Writable.isNotEmpty(): Boolean = !this.isEmpty()

operator fun Writable.plus(other: Writable): Writable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package software.amazon.smithy.rust.codegen.core.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
Expand Down Expand Up @@ -102,6 +103,11 @@ sealed class Default {
* This symbol should use the Rust `std::default::Default` when unset
*/
object RustDefault : Default()

/**
* This symbol has a custom default value different from `Default::default`
*/
data class NonZeroDefault(val value: Node) : Default()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs would be nice.

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.NullableIndex
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.BigDecimalShape
import software.amazon.smithy.model.shapes.BigIntegerShape
import software.amazon.smithy.model.shapes.BlobShape
Expand Down Expand Up @@ -37,6 +38,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
Expand All @@ -48,6 +50,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.orNull
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import kotlin.reflect.KClass
Expand Down Expand Up @@ -79,16 +82,18 @@ data class MaybeRenamed(val name: String, val renamedFrom: String?)
/**
* Make the return [value] optional if the [member] symbol is as well optional.
*/
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String = value.letIf(toSymbol(member).isOptional()) {
"Some($value)"
}
fun SymbolProvider.wrapOptional(member: MemberShape, value: String): String =
value.letIf(toSymbol(member).isOptional()) {
"Some($value)"
}

/**
* Make the return [value] optional if the [member] symbol is not optional.
*/
fun SymbolProvider.toOptional(member: MemberShape, value: String): String = value.letIf(!toSymbol(member).isOptional()) {
"Some($value)"
}
fun SymbolProvider.toOptional(member: MemberShape, value: String): String =
value.letIf(!toSymbol(member).isOptional()) {
"Some($value)"
}

/**
* Services can rename their contained shapes. See https://awslabs.github.io/smithy/1.0/spec/core/model.html#service
Expand All @@ -111,7 +116,7 @@ fun Shape.contextName(serviceShape: ServiceShape?): String {
*/
open class SymbolVisitor(
settings: CoreRustSettings,
override val model: Model,
final override val model: Model,
private val serviceShape: ServiceShape?,
override val config: RustSymbolProviderConfig,
) : RustSymbolProvider, ShapeVisitor<Symbol> {
Expand Down Expand Up @@ -170,7 +175,7 @@ open class SymbolVisitor(
}

private fun simpleShape(shape: SimpleShape): Symbol {
return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).setDefault(Default.RustDefault).build()
return symbolBuilder(shape, SimpleShapes.getValue(shape::class)).build()
}

override fun booleanShape(shape: BooleanShape): Symbol = simpleShape(shape)
Expand Down Expand Up @@ -263,13 +268,20 @@ open class SymbolVisitor(

override fun memberShape(shape: MemberShape): Symbol {
val target = model.expectShape(shape.target)
val defaultValue = shape.getMemberTrait(model, DefaultTrait::class.java).orNull()?.let { trait ->
when (val value = trait.toNode()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to distinguish the case where the user passes in a value to @default that happens to coincide with the default value for the type in Rust, from the case where it doesn't (NonZeroDefault)? Why can't we always set the value from the @default trait?

Generated client code would always use unwrap_or_else(default_value) instead of unwrap_or_default().

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • triggers clippy lint
  • causes a huge codegen diff
  • non-zero defaults are only allowed on simple shapes so it enables some simplification of generating defaults

Node.from(""), Node.from(0), Node.from(false), Node.arrayNode(), Node.objectNode() -> Default.RustDefault
Node.nullNode() -> Default.NoDefault
else -> Default.NonZeroDefault(value)
}
} ?: Default.NoDefault
// Handle boxing first, so we end up with Option<Box<_>>, not Box<Option<_>>.
return handleOptionality(
handleRustBoxing(toSymbol(target), shape),
shape,
nullableIndex,
config.nullabilityCheckMode,
)
).toBuilder().setDefault(defaultValue).build()
}

override fun timestampShape(shape: TimestampShape?): Symbol {
Expand Down Expand Up @@ -297,7 +309,12 @@ fun symbolBuilder(shape: Shape?, rustType: RustType): Symbol.Builder =
// If we ever generate a `thisisabug.rs`, there is a bug in our symbol generation
.definitionFile("thisisabug.rs")

fun handleOptionality(symbol: Symbol, member: MemberShape, nullableIndex: NullableIndex, nullabilityCheckMode: CheckMode): Symbol =
fun handleOptionality(
symbol: Symbol,
member: MemberShape,
nullableIndex: NullableIndex,
nullabilityCheckMode: CheckMode,
): Symbol =
symbol.letIf(nullableIndex.isMemberNullable(member, nullabilityCheckMode)) { symbol.makeOptional() }

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
Expand All @@ -41,7 +40,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.canUseDefault
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
Expand Down Expand Up @@ -385,15 +383,22 @@ class BuilderGenerator(
members.forEach { member ->
val memberName = symbolProvider.toMemberName(member)
val memberSymbol = symbolProvider.toSymbol(member)
val default = memberSymbol.defaultValue()
withBlock("$memberName: self.$memberName", ",") {
// Write the modifier
when {
!memberSymbol.isOptional() && default == Default.RustDefault -> rust(".unwrap_or_default()")
!memberSymbol.isOptional() -> withBlock(
".ok_or_else(||",
")?",
) { missingRequiredField(memberName) }
val generator = DefaultValueGenerator(runtimeConfig, symbolProvider, model)
val default = generator.defaultValue(member)
if (!memberSymbol.isOptional()) {
if (default != null) {
if (default.isRustDefault) {
rust(".unwrap_or_default()")
} else {
rust(".unwrap_or_else(#T)", default.expr)
}
} else {
withBlock(
".ok_or_else(||",
")?",
) { missingRequiredField(memberName) }
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.SimpleShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.Default
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.defaultValue

class DefaultValueGenerator(
runtimeConfig: RuntimeConfig,
private val symbolProvider: RustSymbolProvider,
private val model: Model,
) {
private val instantiator = PrimitiveInstantiator(runtimeConfig, symbolProvider)

data class DefaultValue(val isRustDefault: Boolean, val expr: Writable)

/** Returns the default value as set by the defaultValue trait */
fun defaultValue(member: MemberShape): DefaultValue? {
val target = model.expectShape(member.target)
return when (val default = symbolProvider.toSymbol(member).defaultValue()) {
is Default.NoDefault -> null
is Default.RustDefault -> DefaultValue(isRustDefault = true, writable("Default::default"))
is Default.NonZeroDefault -> {
val instantiation = instantiator.instantiate(target as SimpleShape, default.value)
DefaultValue(isRustDefault = false, writable { rust("||#T", instantiation) })
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,28 @@ open class EnumGenerator(
},
)

// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;

fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)

rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}

""",
*preludeScope,
)
Expand Down
Loading