Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for operationContextParams Endpoints trait #3755

Merged
merged 5 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,16 @@
# message = "Fix typos in module documentation for generated crates"
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
# author = "rcoh"

[[smithy-rs]]
message = "Support `stringArray` type in endpoints params"
references = ["smithy-rs#3742"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client"}
author = "landonxjames"

[[smithy-rs]]
message = "Add support for `operationContextParams` Endpoints trait"
references = ["smithy-rs#3755"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client"}
author = "landonxjames"
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ internal class EndpointParamsGenerator(
fun memberName(parameterName: String) = Identifier.of(parameterName).rustName()

fun setterName(parameterName: String) = "set_${memberName(parameterName)}"

fun getterName(parameterName: String) = "get_${memberName(parameterName)}"
}

fun paramsStruct(): RuntimeType =
Expand Down Expand Up @@ -230,7 +232,9 @@ internal class EndpointParamsGenerator(

private fun generateEndpointParamsBuilder(rustWriter: RustWriter) {
rustWriter.docs("Builder for [`Params`]")
Attribute(derive(RuntimeType.Debug, RuntimeType.Default, RuntimeType.PartialEq, RuntimeType.Clone)).render(rustWriter)
Attribute(derive(RuntimeType.Debug, RuntimeType.Default, RuntimeType.PartialEq, RuntimeType.Clone)).render(
rustWriter,
)
rustWriter.rustBlock("pub struct ParamsBuilder") {
parameters.toList().forEach { parameter ->
val name = parameter.memberName()
Expand All @@ -253,7 +257,8 @@ internal class EndpointParamsGenerator(
rustBlockTemplate("#{Params}", "Params" to paramsStruct()) {
parameters.toList().forEach { parameter ->
rust("${parameter.memberName()}: self.${parameter.memberName()}")
parameter.default.orNull()?.also { default -> rust(".or_else(||Some(${value(default)}))") }
parameter.default.orNull()
?.also { default -> rust(".or_else(||Some(${value(default)}))") }
if (parameter.isRequired) {
rustTemplate(
".ok_or_else(||#{Error}::missing(${parameter.memberName().dq()}))?",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators

import software.amazon.smithy.jmespath.JmespathExpression
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.Node
Expand All @@ -20,16 +21,23 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.EndpointTraitBindings
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.configParamNewtype
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.loadFromConfigBag
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.RustJmespathShapeTraversalGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversalBinding
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversedShape
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.asRef
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.enforceRequired
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -103,10 +111,16 @@ class EndpointParamsInterceptorGenerator(
#{Ok}(())
}
}

// The get_* functions below are generated from JMESPath expressions in the
// operationContextParams trait. They target the operation's input shape.

#{jmespath_getters}
""",
*codegenScope,
"endpoint_prefix" to endpointPrefix(operationShape),
"param_setters" to paramSetters(operationShape, endpointTypesGenerator.params),
"jmespath_getters" to jmesPathGetters(operationShape),
)
}

Expand Down Expand Up @@ -140,6 +154,33 @@ class EndpointParamsInterceptorGenerator(
rust(".$setterName(#W)", value)
}

idx.getOperationContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
val setterName = EndpointParamsGenerator.setterName(name)
val getterName = EndpointParamsGenerator.getterName(name)
val pathValue = param.path
val pathExpression = JmespathExpression.parse(pathValue)
val pathTraversal =
RustJmespathShapeTraversalGenerator(codegenContext).generate(
pathExpression,
listOf(
TraversalBinding.Global(
"input",
TraversedShape.from(model, operationShape.inputShape(model)),
),
),
)

when (pathTraversal.outputType) {
is RustType.Vec -> {
rust(".$setterName($getterName(_input))")
}

else -> {
rust(".$setterName($getterName(_input).cloned())")
}
}
}

// lastly, allow these to be overridden by members
memberParams.forEach { (memberShape, param) ->
val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
Expand All @@ -151,6 +192,39 @@ class EndpointParamsInterceptorGenerator(
}
}

private fun jmesPathGetters(operationShape: OperationShape) =
writable {
val idx = ContextIndex.of(codegenContext.model)
val inputShape = operationShape.inputShape(codegenContext.model)
val input = symbolProvider.toSymbol(inputShape)

idx.getOperationContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
val getterName = EndpointParamsGenerator.getterName(name)
val pathValue = param.path
val pathExpression = JmespathExpression.parse(pathValue)
val pathTraversal =
RustJmespathShapeTraversalGenerator(codegenContext).generate(
pathExpression,
listOf(
TraversalBinding.Global(
"input",
TraversedShape.from(model, operationShape.inputShape(model)),
),
),
)

rust("// Generated from JMESPath Expression: $pathValue")
rustBlockTemplate(
"fn $getterName(input: #{Input}) -> Option<#{Ret}>",
"Input" to input.rustType().asRef(),
"Ret" to pathTraversal.outputType,
) {
pathTraversal.output(this)
rust("Some(${pathTraversal.identifier})")
}
}
}

private fun Node.toWritable(): Writable {
val node = this
return writable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ data class GeneratedExpression(

internal fun isStringOrEnum(): Boolean = isString() || isEnum()

internal fun isObject(): Boolean = outputShape is TraversedShape.Object

/** Dereferences this expression if it is a reference. */
internal fun dereference(namer: SafeNamer): GeneratedExpression =
if (outputType is RustType.Reference) {
Expand Down Expand Up @@ -278,7 +280,7 @@ class JmesPathTraversalCodegenBugException(msg: String?, what: Throwable? = null
* - Object projections
* - Multi-select lists (but only when every item in the list is the exact same type)
* - And/or/not boolean operations
* - Functions `contains` and `length`. The `keys` function may be supported in the future.
* - Functions `contains`, `length`, and `keys`.
*/
class RustJmespathShapeTraversalGenerator(
codegenContext: ClientCodegenContext,
Expand Down Expand Up @@ -429,6 +431,41 @@ class RustJmespathShapeTraversalGenerator(
}
}

"keys" -> {
if (expr.arguments.size != 1) {
throw InvalidJmesPathTraversalException("Keys function takes exactly one argument")
}
val arg = generate(expr.arguments[0], bindings)
if (!arg.isObject()) {
throw InvalidJmesPathTraversalException("Argument to `keys` function must be an object type")
}
GeneratedExpression(
identifier = ident,
outputType = RustType.Vec(RustType.String),
outputShape = TraversedShape.Array(null, TraversedShape.String(null)),
output =
writable {
arg.output(this)
val outputShape = arg.outputShape.shape
when (outputShape) {
is StructureShape -> {
// Can't iterate a struct in Rust so source the keys from smithy
val keys =
outputShape.allMembers.keys.joinToString(",") { "${it.dq()}.to_string()" }
rust("let $ident = vec![$keys];")
}

is MapShape -> {
rust("let $ident = ${arg.identifier}.keys().map(Clone::clone).collect::<Vec<String>>();")
}

else ->
throw UnsupportedJmesPathException("The shape type for an input to the keys function must be a struct or a map, got ${outputShape?.type}")
}
},
)
}

else -> throw UnsupportedJmesPathException("The `${expr.name}` function is not supported by smithy-rs")
}
}
Expand Down
Loading