Skip to content

Commit

Permalink
Support default trait in server
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Nov 30, 2022
1 parent e1de6fa commit 605d33a
Show file tree
Hide file tree
Showing 6 changed files with 476 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ object ServerCargoDependency {
val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5"))

fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")
fun SmithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("types")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ class ServerBuilderConstraintViolations(
writer.rustBlock("pub${ if (visibility == Visibility.PUBCRATE) " (crate) " else "" } enum $constraintViolationSymbolName") {
renderConstraintViolations(writer)
}
renderImplDisplayConstraintViolation(writer)
writer.rust("impl #T for ConstraintViolation { }", RuntimeType.StdError)

if (all.isNotEmpty()) {
renderImplDisplayConstraintViolation(writer)
writer.rust("impl #T for ConstraintViolation { }", RuntimeType.StdError)
}

if (shouldRenderAsValidationExceptionFieldList) {
renderAsValidationExceptionFieldList(writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,41 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.NullNode
import software.amazon.smithy.model.node.NumberNode
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.ByteShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.DoubleShape
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.IntegerShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.LongShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ShortShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.TimestampShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock
import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape
import software.amazon.smithy.rust.codegen.core.rustlang.docs
Expand All @@ -28,7 +53,9 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.isRustBoxed
Expand All @@ -39,9 +66,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
Expand Down Expand Up @@ -168,7 +198,7 @@ class ServerBuilderGenerator(
val baseDerives = structureSymbol.expectRustMetadata().derives
val derives = baseDerives.derives.intersect(setOf(RuntimeType.Debug, RuntimeType.Clone)) + RuntimeType.Default
baseDerives.copy(derives = derives).render(writer)
writer.rustBlock("pub${ if (visibility == Visibility.PUBCRATE) " (crate)" else "" } struct Builder") {
writer.rustBlock("pub${if (visibility == Visibility.PUBCRATE) " (crate)" else ""} struct Builder") {
members.forEach { renderBuilderMember(this, it) }
}

Expand Down Expand Up @@ -290,7 +320,8 @@ class ServerBuilderGenerator(
val memberName = symbolProvider.toMemberName(member)

val hasBox = symbol.mapRustType { it.stripOuter<RustType.Option>() }.isRustBoxed()
val wrapInMaybeConstrained = takeInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider)
val wrapInMaybeConstrained =
takeInUnconstrainedTypes && member.targetCanReachConstrainedShape(model, symbolProvider)

writer.documentShape(member, model)
writer.deprecatedShape(member)
Expand All @@ -315,7 +346,11 @@ class ServerBuilderGenerator(
if (!constrainedTypeHoldsFinalType(member)) varExpr = "($varExpr).into()"

if (wrapInMaybeConstrained) {
conditionalBlock("input.map(##[allow(clippy::redundant_closure)] |v| ", ")", conditional = symbol.isOptional()) {
conditionalBlock(
"input.map(##[allow(clippy::redundant_closure)] |v| ",
")",
conditional = symbol.isOptional(),
) {
conditionalBlock("Box::new(", ")", conditional = hasBox) {
rust("$maybeConstrainedVariant($varExpr)")
}
Expand Down Expand Up @@ -474,54 +509,69 @@ class ServerBuilderGenerator(

withBlock("$memberName: self.$memberName", ",") {
// Write the modifier(s).
serverBuilderConstraintViolations.builderConstraintViolationForMember(member)?.also { constraintViolation ->
val hasBox = builderMemberSymbol(member)
.mapRustType { it.stripOuter<RustType.Option>() }
.isRustBoxed()
if (hasBox) {
rustTemplate(
"""
.map(|v| match *v {
#{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)),
#{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)),
})
.map(|res|
res${ if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())" }
.map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err)))
serverBuilderConstraintViolations.builderConstraintViolationForMember(member)
?.also { constraintViolation ->
val hasBox = builderMemberSymbol(member)
.mapRustType { it.stripOuter<RustType.Option>() }
.isRustBoxed()
if (hasBox) {
rustTemplate(
"""
.map(|v| match *v {
#{MaybeConstrained}::Constrained(x) => Ok(Box::new(x)),
#{MaybeConstrained}::Unconstrained(x) => Ok(Box::new(x.try_into()?)),
})
.map(|res|
res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"}
.map_err(|err| ConstraintViolation::${constraintViolation.name()}(Box::new(err)))
)
.transpose()?
""",
*codegenScope,
)
.transpose()?
""",
*codegenScope,
)
} else {
rustTemplate(
"""
.map(|v| match v {
#{MaybeConstrained}::Constrained(x) => Ok(x),
#{MaybeConstrained}::Unconstrained(x) => x.try_into(),
})
.map(|res|
res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"}
.map_err(ConstraintViolation::${constraintViolation.name()})
)
.transpose()?
""",
*codegenScope,
)

// Constrained types are not public and this is a member shape that would have generated a
// public constrained type, were the setting to be enabled.
// We've just checked the constraints hold by going through the non-public
// constrained type, but the user wants to work with the unconstrained type, so we have to
// unwrap it.
if (!publicConstrainedTypes && member.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(model)) {
rust(
".map(|v: #T| v.into())",
constrainedShapeSymbolProvider.toSymbol(model.expectShape(member.target)),
} else {
if (member.hasNonNullDefault()) {
rustTemplate(
"""#{default:W}""",
"default" to renderDefaultBuilder(
model,
runtimeConfig,
symbolProvider,
member,
) { ".or_else(|| Some($it.into()))" },
)
}
rustTemplate(
"""
.map(|v| match v {
#{MaybeConstrained}::Constrained(x) => Ok(x),
#{MaybeConstrained}::Unconstrained(x) => x.try_into(),
})
.map(|res|
res${if (constrainedTypeHoldsFinalType(member)) "" else ".map(|v| v.into())"}
.map_err(ConstraintViolation::${constraintViolation.name()})
)
.transpose()?
""",
*codegenScope,
)

// Constrained types are not public and this is a member shape that would have generated a
// public constrained type, were the setting to be enabled.
// We've just checked the constraints hold by going through the non-public
// constrained type, but the user wants to work with the unconstrained type, so we have to
// unwrap it.
if (!publicConstrainedTypes && member.wouldHaveConstrainedWrapperTupleTypeWerePublicConstrainedTypesEnabled(
model,
)
) {
rust(
".map(|v: #T| v.into())",
constrainedShapeSymbolProvider.toSymbol(model.expectShape(member.target)),
)
}
}
}
}
serverBuilderConstraintViolations.forMember(member)?.also {
rust(".ok_or(ConstraintViolation::${it.name()})?")
}
Expand All @@ -531,6 +581,102 @@ class ServerBuilderGenerator(
}
}

fun renderDefaultBuilder(model: Model, runtimeConfig: RuntimeConfig, symbolProvider: RustSymbolProvider, member: MemberShape, wrap: (s: String) -> String = { it }): Writable {
return writable {
val node = member.expectTrait<DefaultTrait>().toNode()!!
val name = member.memberName
val types = ServerCargoDependency.SmithyTypes(runtimeConfig).toType()
when (val target = model.expectShape(member.target)) {
is EnumShape, is IntEnumShape -> {
val value = when (target) {
is IntEnumShape -> node.expectNumberNode().value
is EnumShape -> node.expectStringNode().value
else -> throw CodegenException("Default value for $name must be of EnumShape or IntEnumShape")
}
val enumValues = when (target) {
is IntEnumShape -> target.enumValues
is EnumShape -> target.enumValues
else -> software.amazon.smithy.rust.codegen.core.util.UNREACHABLE("It must be an [Int]EnumShape, otherwise it'd have failed above")
}
val variant = enumValues
.entries
.filter { entry -> entry.value == value }
.map { entry ->
symbolProvider.toEnumVariantName(
EnumDefinition.builder().name(entry.key).value(entry.value.toString()).build(),
)!!
}
.first()
val symbol = symbolProvider.toSymbol(target)
val result = "$symbol::${variant.name}"
rust(wrap(result))
}

is ByteShape -> rust(wrap(node.expectNumberNode().value.toString() + "i8"))
is ShortShape -> rust(wrap(node.expectNumberNode().value.toString() + "i16"))
is IntegerShape -> rust(wrap(node.expectNumberNode().value.toString() + "i32"))
is LongShape -> rust(wrap(node.expectNumberNode().value.toString() + "i64"))
is FloatShape -> rust(wrap(node.expectNumberNode().value.toFloat().toString() + "f32"))
is DoubleShape -> rust(wrap(node.expectNumberNode().value.toDouble().toString() + "f64"))
is BooleanShape -> rust(wrap(node.expectBooleanNode().value.toString()))
is StringShape -> rust(wrap("String::from(${node.expectStringNode().value.dq()})"))
is TimestampShape -> when (node) {
is NumberNode -> rust(wrap(node.expectNumberNode().value.toString()))
is StringNode -> rustTemplate(
wrap("""#{SmithyTypes}::DateTime::from_str(${wrap(node.expectStringNode().value.dq())}, #{SmithyTypes}::date_time::Format::DateTime).expect("Default value not correct")"""),
"SmithyTypes" to types,
)

else -> throw CodegenException("Default value for $name is unsupported")
}

is ListShape, is MapShape -> {
check((node is ArrayNode && node.isEmpty) || (node is ObjectNode && node.isEmpty))
rust(wrap("std::default::default()"))
}

is DocumentShape -> {
when (node) {
is NullNode -> rustTemplate(
"#{SmithyTypes}::Document::Null",
"SmithyTypes" to types,
)

is BooleanNode -> rust(wrap(node.value.toString()))
is StringNode -> rust(wrap("String::from(${node.value.dq()})"))
is NumberNode -> {
val value = node.value.toString()
val variant = when (node.value) {
is Float, is Double -> "Float"
else -> if (node.value.toLong() >= 0) "PosInt" else "NegInt"
}
rustTemplate(
wrap(
"#{SmithyTypes}::Document::Number(#{SmithyTypes}::Number::$variant($value))",
),
"SmithyTypes" to types,
)
}

is ArrayNode -> {
check(node.isEmpty)
rust(wrap("Vec::new()"))
}

is ObjectNode -> {
check(node.isEmpty)
rust(wrap("std::collections::HashMap::new()"))
}

else -> throw CodegenException("Default value for $name is unsupported or cannot exist")
}
}
is BlobShape -> rust(wrap(RuntimeType.ByteStream(runtimeConfig).toSymbol().fullName + "::default()"))
else -> throw CodegenException("Default value for $name is unsupported or cannot exist")
}
}
}

fun buildFnReturnType(isBuilderFallible: Boolean, structureSymbol: Symbol) = writable {
if (isBuilderFallible) {
rust("Result<#T, ConstraintViolation>", structureSymbol)
Expand Down
Loading

0 comments on commit 605d33a

Please sign in to comment.