diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt index 8544fc1eb1..a7af29426e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/Constraints.kt @@ -94,7 +94,7 @@ fun Shape.isDirectlyConstrained(symbolProvider: SymbolProvider): Boolean = when } fun MemberShape.hasConstraintTraitOrTargetHasConstraintTrait(model: Model, symbolProvider: SymbolProvider): Boolean = - this.isDirectlyConstrained(symbolProvider) || (model.expectShape(this.target).isDirectlyConstrained(symbolProvider)) + this.isDirectlyConstrained(symbolProvider) || model.expectShape(this.target).isDirectlyConstrained(symbolProvider) fun Shape.isTransitivelyButNotDirectlyConstrained(model: Model, symbolProvider: SymbolProvider): Boolean = !this.isDirectlyConstrained(symbolProvider) && this.canReachConstrainedShape(model, symbolProvider) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 5bfb0f98f2..6b4df3204e 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -29,6 +29,7 @@ object ServerCargoDependency { val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) fun smithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-server") + fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types") } /** diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt index f2c572eedc..4d4efc4c7c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderConstraintViolations.kt @@ -62,13 +62,18 @@ class ServerBuilderConstraintViolations( nonExhaustive: Boolean, shouldRenderAsValidationExceptionFieldList: Boolean, ) { + check(all.isNotEmpty()) { + "Attempted to render constraint violations for the builder for structure shape ${shape.id}, but calculation of the constraint violations resulted in no variants" + } + Attribute.Derives(setOf(RuntimeType.Debug, RuntimeType.PartialEq)).render(writer) writer.docs("Holds one variant for each of the ways the builder can fail.") if (nonExhaustive) Attribute.NonExhaustive.render(writer) val constraintViolationSymbolName = constraintViolationSymbolProvider.toSymbol(shape).name - writer.rustBlock("pub${ if (visibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationSymbolName") { + writer.rustBlock("pub${if (visibility == Visibility.PUBCRATE) " (crate) " else ""} enum $constraintViolationSymbolName") { renderConstraintViolations(writer) } + renderImplDisplayConstraintViolation(writer) writer.rust("impl #T for ConstraintViolation { }", RuntimeType.StdError) @@ -93,7 +98,7 @@ class ServerBuilderConstraintViolations( fun forMember(member: MemberShape): ConstraintViolation? { check(members.contains(member)) // TODO(https://github.com/awslabs/smithy-rs/issues/1302, https://github.com/awslabs/smithy/issues/1179): See above. - return if (symbolProvider.toSymbol(member).isOptional()) { + return if (symbolProvider.toSymbol(member).isOptional() || member.hasNonNullDefault()) { null } else { ConstraintViolation(member, ConstraintViolationKind.MISSING_MEMBER) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt index facb2c5b0a..d426717804 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGenerator.kt @@ -27,7 +27,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.smithy.isOptional @@ -99,15 +98,33 @@ class ServerBuilderGenerator( model: Model, symbolProvider: SymbolProvider, takeInUnconstrainedTypes: Boolean, - ): Boolean = - if (takeInUnconstrainedTypes) { - structureShape.canReachConstrainedShape(model, symbolProvider) + ): Boolean { + val members = structureShape.members() + fun isOptional(member: MemberShape) = symbolProvider.toSymbol(member).isOptional() + fun hasDefault(member: MemberShape) = member.hasNonNullDefault() + fun isNotConstrained(member: MemberShape) = !member.canReachConstrainedShape(model, symbolProvider) + + val notFallible = members.all { + if (structureShape.isReachableFromOperationInput()) { + // When deserializing an input structure, constraints might not be satisfied by the data in the + // incoming request. + // For this builder not to be fallible, no members must be constrained (constraints in input must + // always be checked) and all members must _either_ be optional (no need to set it; not required) + // or have a default value. + isNotConstrained(it) && (isOptional(it) || hasDefault(it)) + } else { + // This structure will be constructed manually by the user. + // Constraints will have to be dealt with before members are set in the builder. + isOptional(it) || hasDefault(it) + } + } + + return if (takeInUnconstrainedTypes) { + !notFallible && structureShape.canReachConstrainedShape(model, symbolProvider) } else { - structureShape - .members() - .map { symbolProvider.toSymbol(it) } - .any { !it.isOptional() } + !notFallible } + } } private val takeInUnconstrainedTypes = shape.isReachableFromOperationInput() @@ -497,67 +514,84 @@ class ServerBuilderGenerator( withBlock("$memberName: self.$memberName", ",") { // Write the modifier(s). + + // 1. Enforce constraint traits of data from incoming requests. serverBuilderConstraintViolations.builderConstraintViolationForMember(member)?.also { constraintViolation -> - val hasBox = builderMemberSymbol(member) - .mapRustType { it.stripOuter() } - .isRustBoxed() - if (hasBox) { - rustTemplate( - """ - .map(|v| match *v { - #{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)), - #{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)), - }) - .map(|res| - res${ if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())" } - .map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err))) - ) - .transpose()? - """, - *codegenScope, - ) - } else { - rustTemplate( - """ - .map(|v| match v { - #{MaybeConstrained}::Constrained(x) => Ok(x), - #{MaybeConstrained}::Unconstrained(x) => x.try_into(), - }) - .map(|res| - res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} - .map_err(ConstraintViolation::${constraintViolation.name()}) - ) - .transpose()? - """, - *codegenScope, - ) - - // Constrained types are not public and this is a member shape that would have generated a - // public constrained type, were the setting to be enabled. - // We've just checked the constraints hold by going through the non-public - // constrained type, but the user wants to work with the unconstrained type, so we have to - // unwrap it. - if (!publicConstrainedTypes && member.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model)) { - rust( - ".map(|v: #T| v.into())", - constrainedShapeSymbolProvider.toSymbol(model.expectShape(member.target)), - ) - } - } + enforceConstraints(this, member, constraintViolation) } - serverBuilderConstraintViolations.forMember(member)?.also { - rust(".ok_or(ConstraintViolation::${it.name()})?") + + if (member.hasNonNullDefault()) { + // 2a. If a `@default` value is modeled and the user did not set a value, fall back to using the + // default value. + generateFallbackCodeToDefaultValue( + this, + member, + model, + runtimeConfig, + symbolProvider, + publicConstrainedTypes, + ) + } else { + // 2b. If the member is `@required` and has no `@default` value, the user must set a value; + // otherwise, we fail with a `ConstraintViolation::Missing*` variant. + serverBuilderConstraintViolations.forMember(member)?.also { + rust(".ok_or(ConstraintViolation::${it.name()})?") + } } } } } } -} -fun buildFnReturnType(isBuilderFallible: Boolean, structureSymbol: Symbol) = writable { - if (isBuilderFallible) { - rust("Result<#T, ConstraintViolation>", structureSymbol) - } else { - rust("#T", structureSymbol) + private fun enforceConstraints(writer: RustWriter, member: MemberShape, constraintViolation: ConstraintViolation) { + // This member is constrained. Enforce the constraint traits on the value set in the builder. + // The code is slightly different in case the member is recursive, since it will be wrapped in + // `std::boxed::Box`. + val hasBox = builderMemberSymbol(member) + .mapRustType { it.stripOuter() } + .isRustBoxed() + if (hasBox) { + writer.rustTemplate( + """ + .map(|v| match *v { + #{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)), + #{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)), + }) + .map(|res| + res${ if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())" } + .map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err))) + ) + .transpose()? + """, + *codegenScope, + ) + } else { + writer.rustTemplate( + """ + .map(|v| match v { + #{MaybeConstrained}::Constrained(x) => Ok(x), + #{MaybeConstrained}::Unconstrained(x) => x.try_into(), + }) + .map(|res| + res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"} + .map_err(ConstraintViolation::${constraintViolation.name()}) + ) + .transpose()? + """, + *codegenScope, + ) + } + + // Constrained types are not public and this is a member shape that would have generated a + // public constrained type, were the setting to be enabled. + // We've just checked the constraints hold by going through the non-public + // constrained type, but the user wants to work with the unconstrained type, so we have to + // unwrap it. + if (!publicConstrainedTypes && member.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model)) { + writer.rust( + ".map(|v: #T| v.into())", + constrainedShapeSymbolProvider.toSymbol(model.expectShape(member.target)), + ) + } } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt new file mode 100644 index 0000000000..56014b6d99 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorCommon.kt @@ -0,0 +1,223 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.node.ArrayNode +import software.amazon.smithy.model.node.BooleanNode +import software.amazon.smithy.model.node.NullNode +import software.amazon.smithy.model.node.NumberNode +import software.amazon.smithy.model.node.ObjectNode +import software.amazon.smithy.model.node.StringNode +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntEnumShape +import software.amazon.smithy.model.shapes.IntegerShape +import software.amazon.smithy.model.shapes.ListShape +import software.amazon.smithy.model.shapes.LongShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.ShortShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.traits.DefaultTrait +import software.amazon.smithy.model.traits.EnumDefinition +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.expectTrait +import software.amazon.smithy.rust.codegen.core.util.isStreaming +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.hasPublicConstrainedWrapperTupleType + +/** + * Some common freestanding functions shared across: + * - [ServerBuilderGenerator]; and + * - [ServerBuilderGeneratorWithoutPublicConstrainedTypes], + * to keep them DRY and consistent. + */ + +/** + * Returns a writable to render the return type of the server builders' `build()` method. + */ +fun buildFnReturnType(isBuilderFallible: Boolean, structureSymbol: Symbol) = writable { + if (isBuilderFallible) { + rust("Result<#T, ConstraintViolation>", structureSymbol) + } else { + rust("#T", structureSymbol) + } +} + +/** + * Renders code to fall back to the modeled `@default` value on a [member] shape. + * The code is expected to be interpolated right after a value of type `Option`, where `T` is the type of the + * default value. + */ +fun generateFallbackCodeToDefaultValue( + writer: RustWriter, + member: MemberShape, + model: Model, + runtimeConfig: RuntimeConfig, + symbolProvider: RustSymbolProvider, + publicConstrainedTypes: Boolean, +) { + val defaultValue = defaultValue(model, runtimeConfig, symbolProvider, member) + val targetShape = model.expectShape(member.target) + + if (member.isStreaming(model)) { + writer.rust(".unwrap_or_default()") + } else if (targetShape.hasPublicConstrainedWrapperTupleType(model, publicConstrainedTypes)) { + // TODO(https://github.com/awslabs/smithy-rs/issues/2134): Instead of panicking here, which will ungracefully + // shut down the service, perform the `try_into()` check _once_ at service startup time, perhaps + // storing the result in a `OnceCell` that could be reused. + writer.rustTemplate( + """ + .unwrap_or_else(|| + #{DefaultValue:W} + .try_into() + .expect("this check should have failed at generation time; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + ) + """, + "DefaultValue" to defaultValue, + ) + } else { + when (targetShape) { + is NumberShape, is EnumShape, is BooleanShape -> { + writer.rustTemplate(".unwrap_or(#{DefaultValue:W})", "DefaultValue" to defaultValue) + } + // Values for the Rust types of the rest of the shapes require heap allocations, so we calculate them + // in a (lazily-executed) closure for slight performance gains. + else -> { + writer.rustTemplate(".unwrap_or_else(|| #{DefaultValue:W})", "DefaultValue" to defaultValue) + } + } + } +} + +/** + * Returns a writable to construct a Rust value of the correct type holding the modeled `@default` value on the + * [member] shape. + */ +fun defaultValue( + model: Model, + runtimeConfig: RuntimeConfig, + symbolProvider: RustSymbolProvider, + member: MemberShape, +) = writable { + val node = member.expectTrait().toNode()!! + val types = ServerCargoDependency.smithyTypes(runtimeConfig).toType() + // Define the exception once for DRYness. + val unsupportedDefaultValueException = + CodegenException("Default value $node for member shape ${member.id} is unsupported or cannot exist; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + when (val target = model.expectShape(member.target)) { + is EnumShape, is IntEnumShape -> { + val value = when (target) { + is IntEnumShape -> node.expectNumberNode().value + is EnumShape -> node.expectStringNode().value + else -> throw CodegenException("Default value for shape ${target.id} must be of EnumShape or IntEnumShape") + } + val enumValues = when (target) { + is IntEnumShape -> target.enumValues + is EnumShape -> target.enumValues + else -> UNREACHABLE( + "Target shape ${target.id} must be an `EnumShape` or an `IntEnumShape` at this point, otherwise it would have failed above", + ) + } + val variant = enumValues + .entries + .filter { entry -> entry.value == value } + .map { entry -> + symbolProvider.toEnumVariantName( + EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build(), + )!! + } + .first() + rust("#T::${variant.name}", symbolProvider.toSymbol(target)) + } + + is ByteShape -> rust(node.expectNumberNode().value.toString() + "i8") + is ShortShape -> rust(node.expectNumberNode().value.toString() + "i16") + is IntegerShape -> rust(node.expectNumberNode().value.toString() + "i32") + is LongShape -> rust(node.expectNumberNode().value.toString() + "i64") + is FloatShape -> rust(node.expectNumberNode().value.toFloat().toString() + "f32") + is DoubleShape -> rust(node.expectNumberNode().value.toDouble().toString() + "f64") + is BooleanShape -> rust(node.expectBooleanNode().value.toString()) + is StringShape -> rust("String::from(${node.expectStringNode().value.dq()})") + is TimestampShape -> when (node) { + is NumberNode -> rust(node.expectNumberNode().value.toString()) + is StringNode -> { + val value = node.expectStringNode().value + rustTemplate( + """ + #{SmithyTypes}::DateTime::from_str("$value", #{SmithyTypes}::date_time::Format::DateTime) + .expect("default value `$value` cannot be parsed into a valid date time; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + """, + "SmithyTypes" to types, + ) + } + else -> throw unsupportedDefaultValueException + } + is ListShape -> { + check(node is ArrayNode && node.isEmpty) + rust("Vec::new()") + } + is MapShape -> { + check(node is ObjectNode && node.isEmpty) + rust("std::collections::HashMap::new()") + } + is DocumentShape -> { + when (node) { + is NullNode -> rustTemplate( + "#{SmithyTypes}::Document::Null", + "SmithyTypes" to types, + ) + + is BooleanNode -> rustTemplate("""#{SmithyTypes}::Document::Bool(${node.value})""", "SmithyTypes" to types) + is StringNode -> rustTemplate("#{SmithyTypes}::Document::String(String::from(${node.value.dq()}))", "SmithyTypes" to types) + is NumberNode -> { + val value = node.value.toString() + val variant = when (node.value) { + is Float, is Double -> "Float" + else -> if (node.value.toLong() >= 0) "PosInt" else "NegInt" + } + rustTemplate( + "#{SmithyTypes}::Document::Number(#{SmithyTypes}::Number::$variant($value))", + "SmithyTypes" to types, + ) + } + + is ArrayNode -> { + check(node.isEmpty) + rustTemplate("""#{SmithyTypes}::Document::Array(Vec::new())""", "SmithyTypes" to types) + } + + is ObjectNode -> { + check(node.isEmpty) + rustTemplate("#{SmithyTypes}::Document::Object(std::collections::HashMap::new())", "SmithyTypes" to types) + } + + else -> throw unsupportedDefaultValueException + } + } + + is BlobShape -> rust("Default::default()") + + else -> throw unsupportedDefaultValueException + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt index d83446dc83..3e03a9d1a9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderGeneratorWithoutPublicConstrainedTypes.kt @@ -42,7 +42,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType * when `publicConstrainedTypes` is false. */ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( - codegenContext: ServerCodegenContext, + private val codegenContext: ServerCodegenContext, shape: StructureShape, ) { companion object { @@ -55,20 +55,26 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( fun hasFallibleBuilder( structureShape: StructureShape, symbolProvider: SymbolProvider, - ): Boolean = - structureShape - .members() - .map { symbolProvider.toSymbol(it) } - .any { !it.isOptional() } + ): Boolean { + val members = structureShape.members() + fun isOptional(member: MemberShape) = symbolProvider.toSymbol(member).isOptional() + fun hasDefault(member: MemberShape) = member.hasNonNullDefault() + + val notFallible = members.all { + isOptional(it) || hasDefault(it) + } + + return !notFallible + } } private val model = codegenContext.model private val symbolProvider = codegenContext.symbolProvider private val members: List = shape.allMembers.values.toList() + private val runtimeConfig = codegenContext.runtimeConfig private val structureSymbol = symbolProvider.toSymbol(shape) private val builderSymbol = shape.serverBuilderSymbol(symbolProvider, false) - private val moduleName = builderSymbol.namespace.split("::").last() private val isBuilderFallible = hasFallibleBuilder(shape, symbolProvider) private val serverBuilderConstraintViolations = ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false) @@ -82,6 +88,9 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( ) fun render(writer: RustWriter) { + check(!codegenContext.settings.codegenConfig.publicConstrainedTypes) { + "ServerBuilderGeneratorWithoutPublicConstrainedTypes should only be used when `publicConstrainedTypes` is false" + } writer.docs("See #D.", structureSymbol) writer.withInlineModule(builderSymbol.module()) { renderBuilder(this) @@ -158,8 +167,23 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes( val memberName = symbolProvider.toMemberName(member) withBlock("$memberName: self.$memberName", ",") { - serverBuilderConstraintViolations.forMember(member)?.also { - rust(".ok_or(ConstraintViolation::${it.name()})?") + if (member.hasNonNullDefault()) { + // 1a. If a `@default` value is modeled and the user did not set a value, fall back to using the + // default value. + generateFallbackCodeToDefaultValue( + this, + member, + model, + runtimeConfig, + symbolProvider, + codegenContext.settings.codegenConfig.publicConstrainedTypes, + ) + } else { + // 1b. If the member is `@required` and has no `@default` value, the user must set a value; + // otherwise, we fail with a `ConstraintViolation::Missing*` variant. + serverBuilderConstraintViolations.forMember(member)?.also { + rust(".ok_or(ConstraintViolation::${it.name()})?") + } } } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt new file mode 100644 index 0000000000..2219c2e653 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt @@ -0,0 +1,371 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.EnumShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.withBlock +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.core.util.toPascalCase +import software.amazon.smithy.rust.codegen.core.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider +import java.util.stream.Stream + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class ServerBuilderDefaultValuesTest { + // When defaults are used, the model will be generated with these in the `@default` trait. + private val defaultValues = mapOf( + "Boolean" to "true", + "String" to "foo".dq(), + "Byte" to "5", + "Short" to "55", + "Integer" to "555", + "Long" to "5555", + "Float" to "0.5", + "Double" to "0.55", + "Timestamp" to "1985-04-12T23:20:50.52Z".dq(), + // "BigInteger" to "55555", "BigDecimal" to "0.555", // TODO(https://github.com/awslabs/smithy-rs/issues/312) + "StringList" to "[]", + "IntegerMap" to "{}", + "Language" to "en".dq(), + "DocumentBoolean" to "true", + "DocumentString" to "foo".dq(), + "DocumentNumberPosInt" to "100", + "DocumentNumberNegInt" to "-100", + "DocumentNumberFloat" to "0.1", + "DocumentList" to "[]", + "DocumentMap" to "{}", + ) + + // When the test applies values to validate we honor custom values, use these (different) values. + private val customValues = mapOf( + "Boolean" to "false", + "String" to "bar".dq(), + "Byte" to "6", + "Short" to "66", + "Integer" to "666", + "Long" to "6666", + "Float" to "0.6", + "Double" to "0.66", + "Timestamp" to "2022-11-25T17:30:50.00Z".dq(), + // "BigInteger" to "55555", "BigDecimal" to "0.555", // TODO(https://github.com/awslabs/smithy-rs/issues/312) + "StringList" to "[]", + "IntegerMap" to "{}", + "Language" to "fr".dq(), + "DocumentBoolean" to "false", + "DocumentString" to "bar".dq(), + "DocumentNumberPosInt" to "1000", + "DocumentNumberNegInt" to "-1000", + "DocumentNumberFloat" to "0.01", + "DocumentList" to "[]", + "DocumentMap" to "{}", + ) + + @ParameterizedTest(name = "(#{index}) Server builders and default values. Params = requiredTrait: {0}, nullDefault: {1}, applyDefaultValues: {2}, builderGeneratorKind: {3}, assertValues: {4}") + @MethodSource("testParameters") + fun `default values are generated and builders respect default and overrides`( + requiredTrait: Boolean, + nullDefault: Boolean, + applyDefaultValues: Boolean, + builderGeneratorKind: BuilderGeneratorKind, + assertValues: Map, + ) { + println("Running test with params = requiredTrait: $requiredTrait, nullDefault: $nullDefault, applyDefaultValues: $applyDefaultValues, builderGeneratorKind: $builderGeneratorKind, assertValues: $assertValues") + val initialSetValues = this.defaultValues.mapValues { if (nullDefault) null else it.value } + val model = generateModel(requiredTrait, applyDefaultValues, nullDefault, initialSetValues) + val symbolProvider = serverTestSymbolProvider(model) + val project = TestWorkspace.testProject(symbolProvider) + + project.withModule(RustModule.public("model")) { + when (builderGeneratorKind) { + BuilderGeneratorKind.SERVER_BUILDER_GENERATOR -> { + writeServerBuilderGenerator(this, model, symbolProvider) + } + BuilderGeneratorKind.SERVER_BUILDER_GENERATOR_WITHOUT_PUBLIC_CONSTRAINED_TYPES -> { + writeServerBuilderGeneratorWithoutPublicConstrainedTypes(this, model, symbolProvider) + } + } + + val rustValues = setupRustValuesForTest(assertValues) + val setters = if (applyDefaultValues) { + structSetters(rustValues, nullDefault && !requiredTrait) + } else { + writable { } + } + val unwrapBuilder = if (nullDefault && requiredTrait && applyDefaultValues) ".unwrap()" else "" + unitTest( + name = "generates_default_required_values", + block = writable { + rustTemplate( + """ + let my_struct = MyStruct::builder() + #{Setters:W} + .build() + $unwrapBuilder; + + #{Assertions:W} + """, + "Assertions" to assertions( + rustValues, + applyDefaultValues, + nullDefault, + requiredTrait, + applyDefaultValues, + ), + "Setters" to setters, + ) + }, + ) + } + + // Run clippy because the builder's code for handling `@default` is prone to upset it. + project.compileAndTest(runClippy = true) + } + + private fun setupRustValuesForTest(valuesMap: Map): Map { + return valuesMap + mapOf( + "Byte" to "${valuesMap["Byte"]}i8", + "Short" to "${valuesMap["Short"]}i16", + "Integer" to "${valuesMap["Integer"]}i32", + "Long" to "${valuesMap["Long"]}i64", + "Float" to "${valuesMap["Float"]}f32", + "Double" to "${valuesMap["Double"]}f64", + "Language" to "crate::model::Language::${valuesMap["Language"]!!.replace(""""""", "").toPascalCase()}", + "Timestamp" to """aws_smithy_types::DateTime::from_str(${valuesMap["Timestamp"]}, aws_smithy_types::date_time::Format::DateTime).unwrap()""", + // These must be empty + "StringList" to "Vec::::new()", + "IntegerMap" to "std::collections::HashMap::::new()", + "DocumentList" to "Vec::::new()", + "DocumentMap" to "std::collections::HashMap::::new()", + ) + valuesMap + .filter { it.value?.startsWith("Document") ?: false } + .map { it.key to "${it.value}.into()" } + } + + private fun writeServerBuilderGeneratorWithoutPublicConstrainedTypes(writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { + val struct = model.lookup("com.test#MyStruct") + val codegenContext = serverTestCodegenContext( + model, + settings = serverTestRustSettings( + codegenConfig = ServerCodegenConfig(publicConstrainedTypes = false), + ), + ) + val builderGenerator = ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, struct) + + writer.implBlock(struct, symbolProvider) { + builderGenerator.renderConvenienceMethod(writer) + } + builderGenerator.render(writer) + + ServerEnumGenerator(codegenContext, writer, model.lookup("com.test#Language")).render() + StructureGenerator(model, symbolProvider, writer, struct).render() + } + + private fun writeServerBuilderGenerator(writer: RustWriter, model: Model, symbolProvider: RustSymbolProvider) { + val struct = model.lookup("com.test#MyStruct") + val codegenContext = serverTestCodegenContext(model) + val builderGenerator = ServerBuilderGenerator(codegenContext, struct) + + writer.implBlock(struct, symbolProvider) { + builderGenerator.renderConvenienceMethod(writer) + } + builderGenerator.render(writer) + + ServerEnumGenerator(codegenContext, writer, model.lookup("com.test#Language")).render() + StructureGenerator(model, symbolProvider, writer, struct).render() + } + + private fun structSetters(values: Map, optional: Boolean) = writable { + for ((key, value) in values) { + withBlock(".${key.toSnakeCase()}(", ")") { + conditionalBlock("Some(", ")", optional) { + when (key) { + "String" -> rust("$value.into()") + "DocumentNull" -> rust("aws_smithy_types::Document::Null") + "DocumentString" -> rust("aws_smithy_types::Document::String(String::from($value))") + + else -> { + if (key.startsWith("DocumentNumber")) { + val type = key.replace("DocumentNumber", "") + rust("aws_smithy_types::Document::Number(aws_smithy_types::Number::$type($value))") + } else { + rust("$value.into()") + } + } + } + } + } + } + } + + private fun assertions( + values: Map, + hasSetValues: Boolean, + hasNullValues: Boolean, + requiredTrait: Boolean, + hasDefaults: Boolean, + ) = writable { + for ((key, value) in values) { + val member = "my_struct.${key.toSnakeCase()}" + + if (!hasSetValues) { + rust("assert!($member.is_none());") + } else { + val actual = writable { + rust(member) + if (!requiredTrait && !(hasDefaults && !hasNullValues)) { + rust(".unwrap()") + } + } + val expected = writable { + val expected = if (key == "DocumentNull") { + "aws_smithy_types::Document::Null" + } else if (key == "DocumentString") { + "String::from($value).into()" + } else if (key.startsWith("DocumentNumber")) { + val type = key.replace("DocumentNumber", "") + "aws_smithy_types::Document::Number(aws_smithy_types::Number::$type($value))" + } else if (key.startsWith("Document")) { + "$value.into()" + } else { + "$value" + } + rust(expected) + } + rustTemplate("assert_eq!(#{Actual:W}, #{Expected:W});", "Actual" to actual, "Expected" to expected) + } + } + } + + private fun generateModel( + requiredTrait: Boolean, + applyDefaultValues: Boolean, + nullDefault: Boolean, + values: Map, + ): Model { + val requiredOrNot = if (requiredTrait) "@required" else "" + + val members = values.entries.joinToString(", ") { + val value = if (applyDefaultValues) { + "= ${it.value}" + } else if (nullDefault) { + "= null" + } else { + "" + } + """ + $requiredOrNot + ${it.key.toPascalCase()}: ${it.key} $value + """ + } + val model = + """ + namespace com.test + + structure MyStruct { + $members + } + + enum Language { + EN = "en", + FR = "fr", + } + + list StringList { + member: String + } + + map IntegerMap { + key: String + value: Integer + } + + document DocumentNull + document DocumentBoolean + document DocumentString + document DocumentDecimal + document DocumentNumberNegInt + document DocumentNumberPosInt + document DocumentNumberFloat + document DocumentList + document DocumentMap + """ + return model.asSmithyModel(smithyVersion = "2") + } + + /** + * The builder generator we should test. + * We use an enum instead of directly passing in the closure so that JUnit can print a helpful string in the test + * report. + */ + enum class BuilderGeneratorKind { + SERVER_BUILDER_GENERATOR, + SERVER_BUILDER_GENERATOR_WITHOUT_PUBLIC_CONSTRAINED_TYPES, + } + + private fun testParameters(): Stream { + val builderGeneratorKindList = listOf( + BuilderGeneratorKind.SERVER_BUILDER_GENERATOR, + BuilderGeneratorKind.SERVER_BUILDER_GENERATOR_WITHOUT_PUBLIC_CONSTRAINED_TYPES, + ) + return Stream.of( + TestConfig(defaultValues, requiredTrait = false, nullDefault = true, applyDefaultValues = true), + TestConfig(defaultValues, requiredTrait = false, nullDefault = true, applyDefaultValues = false), + + TestConfig(customValues, requiredTrait = false, nullDefault = true, applyDefaultValues = true), + TestConfig(customValues, requiredTrait = false, nullDefault = true, applyDefaultValues = false), + + TestConfig(defaultValues, requiredTrait = true, nullDefault = true, applyDefaultValues = true), + TestConfig(customValues, requiredTrait = true, nullDefault = true, applyDefaultValues = true), + + TestConfig(defaultValues, requiredTrait = false, nullDefault = false, applyDefaultValues = true), + TestConfig(defaultValues, requiredTrait = false, nullDefault = false, applyDefaultValues = false), + + TestConfig(customValues, requiredTrait = false, nullDefault = false, applyDefaultValues = true), + TestConfig(customValues, requiredTrait = false, nullDefault = false, applyDefaultValues = false), + + TestConfig(defaultValues, requiredTrait = true, nullDefault = false, applyDefaultValues = true), + TestConfig(customValues, requiredTrait = true, nullDefault = false, applyDefaultValues = true), + ).flatMap { (assertValues, requiredTrait, nullDefault, applyDefaultValues) -> + builderGeneratorKindList.stream().map { builderGeneratorKind -> + Arguments.of(requiredTrait, nullDefault, applyDefaultValues, builderGeneratorKind, assertValues) + } + } + } + + data class TestConfig( + // The values in the `assert!()` calls and for the `@default` trait + val assertValues: Map, + // Whether to apply @required to all members + val requiredTrait: Boolean, + // Whether to set all members to `null` and force them to be optional + val nullDefault: Boolean, + // Whether to set `assertValues` in the builder + val applyDefaultValues: Boolean, + ) +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index db3aa183da..1ce7779482 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -322,6 +322,12 @@ impl ByteStream { } } +impl Default for ByteStream { + fn default() -> Self { + Self::new(aws_smithy_http::body::SdkBody::from("")) + } +} + /// ByteStream Abstractions. #[pymethods] impl ByteStream {