Skip to content

Commit

Permalink
Add support for operationContextParams Endpoints trait (#3755)
Browse files Browse the repository at this point in the history
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
We have to support the new [`operationContextParams`
trait](https://smithy.io/2.0/additional-specs/rules-engine/parameters.html#smithy-rules-operationcontextparams-trait)
for endpoint resolution. This trait specifies JMESPath expressions for
selecting parameter data from the operation's input type.

## Description
<!--- Describe your changes in detail -->
* Add codegen support for the [JMESPath
`keys`](https://jmespath.org/specification.html#keys) function (required
by the trait
[spec](https://smithy.io/2.0/additional-specs/rules-engine/parameters.html#smithy-rules-operationcontextparams-trait))
* Add codegen support for the trait itself. This is achieved by
generating `get_param_name` functions for each param specified in
`operationContextParams`. These functions pull the data out of the input
object and it is added to the endpoint params in the
`${operationName}EndpointParamsInterceptor`

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to -->
<!--- see how your change affects other areas of the code, etc. -->
Updated the existing test suite for JMESPath codegen to test the `keys`
function. Updated the existing EndpointsDecoratorTest with an
`operationContextParams` trait specifying one param of each supported
type to test the codegen.

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
landonxjames authored Jul 13, 2024
1 parent 2313eb9 commit e4a58c3
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 25 deletions.
14 changes: 13 additions & 1 deletion CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,16 @@
# message = "Fix typos in module documentation for generated crates"
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
# author = "rcoh"

[[smithy-rs]]
message = "Support `stringArray` type in endpoints params"
references = ["smithy-rs#3742"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client"}
author = "landonxjames"

[[smithy-rs]]
message = "Add support for `operationContextParams` Endpoints trait"
references = ["smithy-rs#3755"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client"}
author = "landonxjames"
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ internal class EndpointParamsGenerator(
fun memberName(parameterName: String) = Identifier.of(parameterName).rustName()

fun setterName(parameterName: String) = "set_${memberName(parameterName)}"

fun getterName(parameterName: String) = "get_${memberName(parameterName)}"
}

fun paramsStruct(): RuntimeType =
Expand Down Expand Up @@ -230,7 +232,9 @@ internal class EndpointParamsGenerator(

private fun generateEndpointParamsBuilder(rustWriter: RustWriter) {
rustWriter.docs("Builder for [`Params`]")
Attribute(derive(RuntimeType.Debug, RuntimeType.Default, RuntimeType.PartialEq, RuntimeType.Clone)).render(rustWriter)
Attribute(derive(RuntimeType.Debug, RuntimeType.Default, RuntimeType.PartialEq, RuntimeType.Clone)).render(
rustWriter,
)
rustWriter.rustBlock("pub struct ParamsBuilder") {
parameters.toList().forEach { parameter ->
val name = parameter.memberName()
Expand All @@ -253,7 +257,8 @@ internal class EndpointParamsGenerator(
rustBlockTemplate("#{Params}", "Params" to paramsStruct()) {
parameters.toList().forEach { parameter ->
rust("${parameter.memberName()}: self.${parameter.memberName()}")
parameter.default.orNull()?.also { default -> rust(".or_else(||Some(${value(default)}))") }
parameter.default.orNull()
?.also { default -> rust(".or_else(||Some(${value(default)}))") }
if (parameter.isRequired) {
rustTemplate(
".ok_or_else(||#{Error}::missing(${parameter.memberName().dq()}))?",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.client.smithy.endpoint.generators

import software.amazon.smithy.jmespath.JmespathExpression
import software.amazon.smithy.model.node.ArrayNode
import software.amazon.smithy.model.node.BooleanNode
import software.amazon.smithy.model.node.Node
Expand All @@ -20,16 +21,23 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.EndpointTraitBindings
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.configParamNewtype
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.loadFromConfigBag
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.RustJmespathShapeTraversalGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversalBinding
import software.amazon.smithy.rust.codegen.client.smithy.generators.waiters.TraversedShape
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.asRef
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.generators.enforceRequired
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -103,10 +111,16 @@ class EndpointParamsInterceptorGenerator(
#{Ok}(())
}
}
// The get_* functions below are generated from JMESPath expressions in the
// operationContextParams trait. They target the operation's input shape.
#{jmespath_getters}
""",
*codegenScope,
"endpoint_prefix" to endpointPrefix(operationShape),
"param_setters" to paramSetters(operationShape, endpointTypesGenerator.params),
"jmespath_getters" to jmesPathGetters(operationShape),
)
}

Expand Down Expand Up @@ -140,6 +154,33 @@ class EndpointParamsInterceptorGenerator(
rust(".$setterName(#W)", value)
}

idx.getOperationContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
val setterName = EndpointParamsGenerator.setterName(name)
val getterName = EndpointParamsGenerator.getterName(name)
val pathValue = param.path
val pathExpression = JmespathExpression.parse(pathValue)
val pathTraversal =
RustJmespathShapeTraversalGenerator(codegenContext).generate(
pathExpression,
listOf(
TraversalBinding.Global(
"input",
TraversedShape.from(model, operationShape.inputShape(model)),
),
),
)

when (pathTraversal.outputType) {
is RustType.Vec -> {
rust(".$setterName($getterName(_input))")
}

else -> {
rust(".$setterName($getterName(_input).cloned())")
}
}
}

// lastly, allow these to be overridden by members
memberParams.forEach { (memberShape, param) ->
val memberName = codegenContext.symbolProvider.toMemberName(memberShape)
Expand All @@ -151,6 +192,39 @@ class EndpointParamsInterceptorGenerator(
}
}

private fun jmesPathGetters(operationShape: OperationShape) =
writable {
val idx = ContextIndex.of(codegenContext.model)
val inputShape = operationShape.inputShape(codegenContext.model)
val input = symbolProvider.toSymbol(inputShape)

idx.getOperationContextParams(operationShape).orNull()?.parameters?.forEach { (name, param) ->
val getterName = EndpointParamsGenerator.getterName(name)
val pathValue = param.path
val pathExpression = JmespathExpression.parse(pathValue)
val pathTraversal =
RustJmespathShapeTraversalGenerator(codegenContext).generate(
pathExpression,
listOf(
TraversalBinding.Global(
"input",
TraversedShape.from(model, operationShape.inputShape(model)),
),
),
)

rust("// Generated from JMESPath Expression: $pathValue")
rustBlockTemplate(
"fn $getterName(input: #{Input}) -> Option<#{Ret}>",
"Input" to input.rustType().asRef(),
"Ret" to pathTraversal.outputType,
) {
pathTraversal.output(this)
rust("Some(${pathTraversal.identifier})")
}
}
}

private fun Node.toWritable(): Writable {
val node = this
return writable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ data class GeneratedExpression(

internal fun isStringOrEnum(): Boolean = isString() || isEnum()

internal fun isObject(): Boolean = outputShape is TraversedShape.Object

/** Dereferences this expression if it is a reference. */
internal fun dereference(namer: SafeNamer): GeneratedExpression =
if (outputType is RustType.Reference) {
Expand Down Expand Up @@ -278,7 +280,7 @@ class JmesPathTraversalCodegenBugException(msg: String?, what: Throwable? = null
* - Object projections
* - Multi-select lists (but only when every item in the list is the exact same type)
* - And/or/not boolean operations
* - Functions `contains` and `length`. The `keys` function may be supported in the future.
* - Functions `contains`, `length`, and `keys`.
*/
class RustJmespathShapeTraversalGenerator(
codegenContext: ClientCodegenContext,
Expand Down Expand Up @@ -429,6 +431,41 @@ class RustJmespathShapeTraversalGenerator(
}
}

"keys" -> {
if (expr.arguments.size != 1) {
throw InvalidJmesPathTraversalException("Keys function takes exactly one argument")
}
val arg = generate(expr.arguments[0], bindings)
if (!arg.isObject()) {
throw InvalidJmesPathTraversalException("Argument to `keys` function must be an object type")
}
GeneratedExpression(
identifier = ident,
outputType = RustType.Vec(RustType.String),
outputShape = TraversedShape.Array(null, TraversedShape.String(null)),
output =
writable {
arg.output(this)
val outputShape = arg.outputShape.shape
when (outputShape) {
is StructureShape -> {
// Can't iterate a struct in Rust so source the keys from smithy
val keys =
outputShape.allMembers.keys.joinToString(",") { "${it.dq()}.to_string()" }
rust("let $ident = vec![$keys];")
}

is MapShape -> {
rust("let $ident = ${arg.identifier}.keys().map(Clone::clone).collect::<Vec<String>>();")
}

else ->
throw UnsupportedJmesPathException("The shape type for an input to the keys function must be a struct or a map, got ${outputShape?.type}")
}
},
)
}

else -> throw UnsupportedJmesPathException("The `${expr.name}` function is not supported by smithy-rs")
}
}
Expand Down
Loading

0 comments on commit e4a58c3

Please sign in to comment.