diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java index 3bff1f18..32a3cd8e 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java @@ -15,25 +15,44 @@ package software.amazon.smithy.go.codegen; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable; +import static software.amazon.smithy.go.codegen.SymbolUtils.sliceOf; +import static software.amazon.smithy.go.codegen.util.ShapeUtil.BOOL_SHAPE; +import static software.amazon.smithy.go.codegen.util.ShapeUtil.INT_SHAPE; import static software.amazon.smithy.go.codegen.util.ShapeUtil.STRING_SHAPE; -import static software.amazon.smithy.go.codegen.util.ShapeUtil.expectMember; import static software.amazon.smithy.go.codegen.util.ShapeUtil.listOf; import static software.amazon.smithy.utils.StringUtils.capitalize; import java.util.List; +import java.util.Map; import software.amazon.smithy.codegen.core.CodegenException; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.go.codegen.util.ShapeUtil; import software.amazon.smithy.jmespath.JmespathExpression; +import software.amazon.smithy.jmespath.ast.AndExpression; +import software.amazon.smithy.jmespath.ast.ComparatorExpression; +import software.amazon.smithy.jmespath.ast.ComparatorType; +import software.amazon.smithy.jmespath.ast.CurrentExpression; import software.amazon.smithy.jmespath.ast.FieldExpression; +import software.amazon.smithy.jmespath.ast.FilterProjectionExpression; +import software.amazon.smithy.jmespath.ast.FlattenExpression; import software.amazon.smithy.jmespath.ast.FunctionExpression; +import software.amazon.smithy.jmespath.ast.LiteralExpression; +import software.amazon.smithy.jmespath.ast.MultiSelectListExpression; +import software.amazon.smithy.jmespath.ast.NotExpression; import software.amazon.smithy.jmespath.ast.ProjectionExpression; import software.amazon.smithy.jmespath.ast.Subexpression; -import software.amazon.smithy.model.shapes.ListShape; +import software.amazon.smithy.model.shapes.CollectionShape; +import software.amazon.smithy.model.shapes.MapShape; +import software.amazon.smithy.model.shapes.NumberShape; import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.StringShape; import software.amazon.smithy.utils.SmithyInternalApi; /** * Traverses a JMESPath expression, producing a series of statements that evaluate the entire expression. The generator - * is shape-aware and the return indicates the underlying shape being referenced in the final result. + * is shape-aware and the return indicates the underlying shape/symbol being referenced in the final result. *
* Note that the use of writer.write() here is deliberate, it's easier to structure the code in that way instead of * trying to recursively compose/organize Writable templates. @@ -42,24 +61,19 @@ public class GoJmespathExpressionGenerator { private final GoCodegenContext ctx; private final GoWriter writer; - private final Shape input; - private final JmespathExpression root; - private int idIndex = 1; + private int idIndex = 0; - public GoJmespathExpressionGenerator(GoCodegenContext ctx, GoWriter writer, Shape input, JmespathExpression expr) { + public GoJmespathExpressionGenerator(GoCodegenContext ctx, GoWriter writer) { this.ctx = ctx; this.writer = writer; - this.input = input; - this.root = expr; } - public Result generate(String ident) { - writer.write("v1 := $L", ident); - return visit(root, input); + public Variable generate(JmespathExpression expr, Variable input) { + return visit(expr, input); } - private Result visit(JmespathExpression expr, Shape current) { + private Variable visit(JmespathExpression expr, Variable current) { if (expr instanceof FunctionExpression tExpr) { return visitFunction(tExpr, current); } else if (expr instanceof FieldExpression tExpr) { @@ -68,61 +82,252 @@ private Result visit(JmespathExpression expr, Shape current) { return visitSub(tExpr, current); } else if (expr instanceof ProjectionExpression tExpr) { return visitProjection(tExpr, current); + } else if (expr instanceof FlattenExpression tExpr) { + return visitFlatten(tExpr, current); + } else if (expr instanceof ComparatorExpression tExpr) { + return visitComparator(tExpr, current); + } else if (expr instanceof LiteralExpression tExpr) { + return visitLiteral(tExpr); + } else if (expr instanceof AndExpression tExpr) { + return visitAnd(tExpr, current); + } else if (expr instanceof NotExpression tExpr) { + return visitNot(tExpr, current); + } else if (expr instanceof FilterProjectionExpression tExpr) { + return visitFilterProjection(tExpr, current); + } else if (expr instanceof MultiSelectListExpression tExpr) { + return visitMultiSelectList(tExpr, current); + } else if (expr instanceof CurrentExpression) { + return current; } else { throw new CodegenException("unhandled jmespath expression " + expr.getClass().getSimpleName()); } } - private Result visitProjection(ProjectionExpression expr, Shape current) { + private Variable visitNot(NotExpression expr, Variable current) { + var inner = visit(expr.getExpression(), current); + var ident = nextIdent(); + writer.write("$L := !$L", ident, inner.ident); + return new Variable(BOOL_SHAPE, ident, GoUniverseTypes.Bool); + } + + private Variable visitMultiSelectList(MultiSelectListExpression expr, Variable current) { + if (expr.getExpressions().isEmpty()) { + throw new CodegenException("multi-select list w/ no expressions"); + } + + var items = expr.getExpressions().stream() + .map(it -> visit(it, current)) + .toList(); + var first = items.get(0); + + var ident = nextIdent(); + writer.write("$L := []$P{$L}", ident, first.type, + String.join(",", items.stream().map(it -> it.ident).toList())); + + return new Variable(listOf(first.shape), ident, sliceOf(first.type)); + } + + private Variable visitFilterProjection(FilterProjectionExpression expr, Variable current) { + var unfiltered = visitProjection(new ProjectionExpression(expr.getLeft(), expr.getRight()), current); + if (!(unfiltered.shape instanceof CollectionShape unfilteredCol)) { + throw new CodegenException("projection did not create a list: " + expr); + } + + var member = expectMember(unfilteredCol); + var type = ctx.symbolProvider().toSymbol(unfiltered.shape); + + var ident = nextIdent(); + writer.write("var $L $T", ident, type); + writer.openBlock("for _, v := range $L {", "}", unfiltered.ident, () -> { + var filterResult = visit(expr.getComparison(), new Variable(member, "v", type)); + writer.write(""" + if $1L { + $2L = append($2L, v) + }""", filterResult.ident, ident); + }); + + return new Variable(unfiltered.shape, ident, type); + } + + private Variable visitAnd(AndExpression expr, Variable current) { var left = visit(expr.getLeft(), current); + var right = visit(expr.getRight(), current); + var ident = nextIdent(); + writer.write("$L := $L && $L", ident, left.ident, right.ident); + return new Variable(BOOL_SHAPE, ident, GoUniverseTypes.Bool); + } + + private Variable visitLiteral(LiteralExpression expr) { + var ident = nextIdent(); + if (expr.isNumberValue()) { + // FUTURE: recognize floating-point, for now we just use int + writer.write("$L := $L", ident, expr.expectNumberValue().intValue()); + return new Variable(INT_SHAPE, ident, GoUniverseTypes.Int); + } else if (expr.isStringValue()) { + writer.write("$L := $S", ident, expr.expectStringValue()); + return new Variable(STRING_SHAPE, ident, GoUniverseTypes.String); + } else if (expr.isBooleanValue()) { + writer.write("$L := $L", ident, expr.expectBooleanValue()); + return new Variable(STRING_SHAPE, ident, GoUniverseTypes.Bool); + } else { + throw new CodegenException("unhandled literal expression " + expr.getValue()); + } + } + + private Variable visitComparator(ComparatorExpression expr, Variable current) { + var left = visit(expr.getLeft(), current); + var right = visit(expr.getRight(), current); + + String cast; + if (left.shape instanceof StringShape) { + cast = "string"; + } else if (left.shape instanceof NumberShape) { + cast = "int64"; + } else { + throw new CodegenException("don't know how to compare shape type" + left.shape.getType()); + } + + var ident = nextIdent(); + writer.write(compareVariables(ident, left, right, expr.getComparator(), cast)); + return new Variable(BOOL_SHAPE, ident, GoUniverseTypes.Bool); + } + + private Variable visitFlatten(FlattenExpression tExpr, Variable current) { + var inner = visit(tExpr.getExpression(), current); + + // inner HAS to be a list by spec, otherwise something is wrong + if (!(inner.shape instanceof CollectionShape innerList)) { + throw new CodegenException("projection did not create a list: " + tExpr); + } + + // inner expression may not be a list-of-list - if so, we're done, the result is passed up as-is + var innerMember = expectMember(innerList); + if (!(innerMember instanceof CollectionShape)) { + return inner; + } + + var innerSymbol = ctx.symbolProvider().toSymbol(innerMember); + var ident = nextIdent(); + writer.write(""" + var $1L $3P + for _, v := range $2L { + $1L = append($1L, v...) + }""", ident, inner.ident, innerSymbol); + return new Variable(innerMember, ident, innerSymbol); + } + + private Variable visitProjection(ProjectionExpression expr, Variable current) { + var left = visit(expr.getLeft(), current); + if (expr.getRight() instanceof CurrentExpression) { // e.g. "Field[]" - the projection is just itself + return left; + } - // left of projection HAS to be an array by spec, otherwise something is wrong - if (!left.shape.isListShape()) { - throw new CodegenException("left side of projection did not create a list"); + Shape leftMember; + if (left.shape instanceof CollectionShape col) { + leftMember = expectMember(col); + } else if (left.shape instanceof MapShape map) { + leftMember = expectMember(map); + } else { + // left of projection HAS to be an array/map by spec, otherwise something is wrong + throw new CodegenException("projection did not create a list: " + expr); } - var leftMember = expectMember(ctx.model(), (ListShape) left.shape); + var leftSymbol = ctx.symbolProvider().toSymbol(leftMember); // We have to know the element type for the list that we're generating, use a dummy writer to "peek" ahead and // get the traversal result - var lookahead = new GoJmespathExpressionGenerator(ctx, new GoWriter(""), leftMember, expr.getRight()) - .generate("v"); + var lookahead = new GoJmespathExpressionGenerator(ctx, new GoWriter("")) + .generate(expr.getRight(), new Variable(leftMember, "v", leftSymbol)); - ++idIndex; + var ident = nextIdent(); writer.write(""" - var v$L []$P - for _, v := range $L {""", idIndex, ctx.symbolProvider().toSymbol(lookahead.shape), left.ident); + var $L []$T + for _, v := range $L {""", ident, ctx.symbolProvider().toSymbol(lookahead.shape), left.ident); - // new scope inside loop, but now we actually want to write the contents + writer.indent(); // projected.shape is the _member_ of the resulting list - var projected = new GoJmespathExpressionGenerator(ctx, writer, leftMember, expr.getRight()) - .generate("v"); - - writer.write("v$1L = append(v$1L, $2L)", idIndex, projected.ident); + var projected = visit(expr.getRight(), new Variable(leftMember, "v", leftSymbol)); + if (isPointable(lookahead.type)) { // projections implicitly filter out nil evaluations of RHS + writer.write(""" + if $2L != nil { + $1L = append($1L, *$2L) + }""", ident, projected.ident); + } else { + writer.write("$1L = append($1L, $2L)", ident, projected.ident); + } + writer.dedent(); writer.write("}"); - return new Result(listOf(projected.shape), "v" + idIndex); + return new Variable(listOf(projected.shape), ident, sliceOf(ctx.symbolProvider().toSymbol(projected.shape))); } - private Result visitSub(Subexpression expr, Shape current) { + private Variable visitSub(Subexpression expr, Variable current) { var left = visit(expr.getLeft(), current); - return visit(expr.getRight(), left.shape); + return visit(expr.getRight(), left); } - private Result visitField(FieldExpression expr, Shape current) { - ++idIndex; - writer.write("v$L := v$L.$L", idIndex, idIndex - 1, capitalize(expr.getName())); - return new Result(expectMember(ctx.model(), current, expr.getName()), "v" + idIndex); + private Variable visitField(FieldExpression expr, Variable current) { + var member = current.shape.getMember(expr.getName()).orElseThrow(() -> + new CodegenException("field expression referenced nonexistent member: " + expr.getName())); + + var target = ctx.model().expectShape(member.getTarget()); + var ident = nextIdent(); + writer.write("$L := $L.$L", ident, current.ident, capitalize(expr.getName())); + return new Variable(target, ident, ctx.symbolProvider().toSymbol(member)); } - private Result visitFunction(FunctionExpression expr, Shape current) { + private Variable visitFunction(FunctionExpression expr, Variable current) { return switch (expr.name) { case "keys" -> visitKeysFunction(expr.arguments, current); + case "length" -> visitLengthFunction(expr.arguments, current); + case "contains" -> visitContainsFunction(expr.arguments, current); default -> throw new CodegenException("unsupported function " + expr.name); }; } - private Result visitKeysFunction(List args, Shape current) { + private Variable visitContainsFunction(List args, Variable current) { + if (args.size() != 2) { + throw new CodegenException("unexpected contains() arg length " + args.size()); + } + + var list = visit(args.get(0), current); + var item = visit(args.get(1), current); + var ident = nextIdent(); + writer.write(""" + var $1L bool + for _, v := range $2L { + if v == $3L { + $1L = true + break + } + }""", ident, list.ident, item.ident); + return new Variable(BOOL_SHAPE, ident, GoUniverseTypes.Bool); + } + + private Variable visitLengthFunction(List args, Variable current) { + if (args.size() != 1) { + throw new CodegenException("unexpected length() arg length " + args.size()); + } + + var arg = visit(args.get(0), current); + var ident = nextIdent(); + + // length() can be used on a string (so also *string) - dereference if required + if (arg.shape instanceof StringShape && isPointable(arg.type)) { + writer.write(""" + var _$1L string + if $1L != nil { + _$1L = *$1L + } + $2L := len(_$1L)""", arg.ident, ident); + } else { + writer.write("$L := len($L)", ident, arg.ident); + } + + return new Variable(INT_SHAPE, ident, GoUniverseTypes.Int); + } + + private Variable visitKeysFunction(List args, Variable current) { if (args.size() != 1) { throw new CodegenException("unexpected keys() arg length " + args.size()); } @@ -135,8 +340,71 @@ private Result visitKeysFunction(List args, Shape current) { v$1L = append(v$1L, k) }""", idIndex, arg.ident); - return new Result(listOf(STRING_SHAPE), "v" + idIndex); + return new Variable(listOf(STRING_SHAPE), "v" + idIndex, sliceOf(GoUniverseTypes.String)); + } + + private String nextIdent() { + ++idIndex; + return "v" + idIndex; + } + + private Shape expectMember(CollectionShape shape) { + return switch (shape.getMember().getTarget().toString()) { + case "smithy.go.synthetic#StringList" -> listOf(STRING_SHAPE); + case "smithy.go.synthetic#IntegerList" -> listOf(INT_SHAPE); + case "smithy.go.synthetic#BooleanList" -> listOf(BOOL_SHAPE); + default -> ShapeUtil.expectMember(ctx.model(), shape); + }; + } + + private Shape expectMember(MapShape shape) { + return switch (shape.getValue().getTarget().toString()) { + case "smithy.go.synthetic#StringList" -> listOf(STRING_SHAPE); + case "smithy.go.synthetic#IntegerList" -> listOf(INT_SHAPE); + case "smithy.go.synthetic#BooleanList" -> listOf(BOOL_SHAPE); + default -> ShapeUtil.expectMember(ctx.model(), shape); + }; } - public record Result(Shape shape, String ident) {} + // helper to generate comparisons from two results, automatically handling any dereferencing in the process + private GoWriter.Writable compareVariables(String ident, Variable left, Variable right, ComparatorType cmp, + String cast) { + var isLPtr = isPointable(left.type); + var isRPtr = isPointable(right.type); + if (!isLPtr && !isRPtr) { + return goTemplate("$1L := $5L($2L) $4L $5L($3L)", ident, left.ident, right.ident, cmp, cast); + } + + return goTemplate(""" + var $ident:L bool + if $lif:L $amp:L $rif:L { + $ident:L = $cast:L($lhs:L) $cmp:L $cast:L($rhs:L) + }""", + Map.of( + "ident", ident, + "lif", isLPtr ? left.ident + " != nil" : "", + "rif", isRPtr ? right.ident + " != nil" : "", + "amp", isLPtr && isRPtr ? "&&" : "", + "cmp", cmp, + "lhs", isLPtr ? "*" + left.ident : left.ident, + "rhs", isRPtr ? "*" + right.ident : right.ident, + "cast", cast + )); + } + + /** + * Represents a variable (input, intermediate, or final output) of a JMESPath traversal. + * @param shape The underlying shape referenced by this variable. For certain jmespath expressions (e.g. + * LiteralExpression) the value here is a synthetic shape and does not necessarily have meaning. + * @param ident The identifier of the variable in the generated traversal. + * @param type The symbol that records the type of the variable. This does NOT necessarily correspond to the result + * of toSymbol(shape) because certain jmespath expressions (such as projections) may affect the type of + * the resulting variable in a way that severs that relationship. The caller MUST use this field to + * determine whether the variable is pointable/nillable. + */ + public record Variable(Shape shape, String ident, Symbol type) { + public Variable(Shape shape, String ident) { + this(shape, ident, null); + } + } } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/endpoints/EndpointParameterOperationBindingsGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/endpoints/EndpointParameterOperationBindingsGenerator.java index faf9980c..5a63aad4 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/endpoints/EndpointParameterOperationBindingsGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/endpoints/EndpointParameterOperationBindingsGenerator.java @@ -107,10 +107,10 @@ private GoWriter.Writable generateOpContextParamBinding(String paramName, Operat var expr = JmespathExpression.parse(def.getPath()); return writer -> { - var generator = new GoJmespathExpressionGenerator(ctx, writer, input, expr); + var generator = new GoJmespathExpressionGenerator(ctx, writer); writer.write("func() {"); // contain the scope for each binding - var result = generator.generate("in"); + var result = generator.generate(expr, new GoJmespathExpressionGenerator.Variable(input, "in")); if (param.getType().equals(ParameterType.STRING_ARRAY)) { // projections can result in either []string OR []*string -- if the latter, we have to unwrap diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Waiters2.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Waiters2.java new file mode 100644 index 00000000..58eb1d36 --- /dev/null +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/Waiters2.java @@ -0,0 +1,763 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.smithy.go.codegen.integration; + +import static java.util.Collections.emptySet; +import static software.amazon.smithy.go.codegen.GoWriter.autoDocTemplate; +import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import software.amazon.smithy.codegen.core.CodegenException; +import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.codegen.core.SymbolProvider; +import software.amazon.smithy.go.codegen.ClientOptions; +import software.amazon.smithy.go.codegen.GoCodegenContext; +import software.amazon.smithy.go.codegen.GoJmespathExpressionGenerator; +import software.amazon.smithy.go.codegen.GoWriter; +import software.amazon.smithy.go.codegen.SmithyGoDependency; +import software.amazon.smithy.go.codegen.SymbolUtils; +import software.amazon.smithy.jmespath.JmespathExpression; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.TopDownIndex; +import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.Shape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.shapes.StructureShape; +import software.amazon.smithy.utils.StringUtils; +import software.amazon.smithy.waiters.Acceptor; +import software.amazon.smithy.waiters.Matcher; +import software.amazon.smithy.waiters.PathComparator; +import software.amazon.smithy.waiters.WaitableTrait; +import software.amazon.smithy.waiters.Waiter; + +/** + * Implements support for WaitableTrait. + */ +public class Waiters2 implements GoIntegration { + private static final String WAITER_INVOKER_FUNCTION_NAME = "Wait"; + private static final String WAITER_INVOKER_WITH_OUTPUT_FUNCTION_NAME = "WaitForOutput"; + + public Set getAdditionalClientOptions() { + return emptySet(); + } + + @Override + public void writeAdditionalFiles(GoCodegenContext ctx) { + var service = ctx.settings().getService(ctx.model()); + + TopDownIndex.of(ctx.model()).getContainedOperations(service).stream() + .forEach(operation -> { + if (!operation.hasTrait(WaitableTrait.ID)) { + return; + } + + Map waiters = operation.expectTrait(WaitableTrait.class).getWaiters(); + generateOperationWaiter(ctx, operation, waiters); + }); + } + + + /** + * Generates all waiter components used for the operation. + */ + private void generateOperationWaiter(GoCodegenContext ctx, OperationShape operation, Map waiters) { + var model = ctx.model(); + var symbolProvider = ctx.symbolProvider(); + ctx.writerDelegator().useShapeWriter(operation, writer -> { + waiters.forEach((name, waiter) -> { + generateWaiterOptions(model, symbolProvider, writer, operation, name, waiter); + generateWaiterClient(model, symbolProvider, writer, operation, name, waiter); + generateWaiterInvoker(model, symbolProvider, writer, operation, name, waiter); + generateWaiterInvokerWithOutput(model, symbolProvider, writer, operation, name, waiter); + generateRetryable(ctx, writer, operation, name, waiter); + }); + }); + } + + /** + * Generates waiter options to configure a waiter client. + */ + private void generateWaiterOptions( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + String optionsName = generateWaiterOptionsName(waiterName); + String waiterClientName = generateWaiterClientName(waiterName); + + StructureShape inputShape = model.expectShape( + operationShape.getInput().get(), StructureShape.class + ); + StructureShape outputShape = model.expectShape( + operationShape.getOutput().get(), StructureShape.class + ); + + Symbol inputSymbol = symbolProvider.toSymbol(inputShape); + Symbol outputSymbol = symbolProvider.toSymbol(outputShape); + + writer.write(""); + writer.writeDocs( + String.format("%s are waiter options for %s", optionsName, waiterClientName) + ); + + writer.openBlock("type $L struct {", "}", + optionsName, () -> { + writer.addUseImports(SmithyGoDependency.TIME); + + writer.write(""); + var apiOptionsDocs = autoDocTemplate(""" + Set of options to modify how an operation is invoked. These apply to all operations invoked + for this client. Use functional options on operation call to modify this list for per + operation behavior. + + Passing options here is functionally equivalent to passing values to this config's + ClientOptions field that extend the inner client's APIOptions directly."""); + Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack", + SmithyGoDependency.SMITHY_MIDDLEWARE) + .build(); + writer.write(goTemplate(""" + $W + APIOptions []func($P) error + """, apiOptionsDocs, stackSymbol)); + + var clientOptionsDocs = autoDocTemplate(""" + Functional options to be passed to all operations invoked by this client. + + Function values that modify the inner APIOptions are applied after the waiter config's own + APIOptions modifiers."""); + writer.write(""); + writer.write(goTemplate(""" + $W + ClientOptions []func(*$L) + """, clientOptionsDocs, ClientOptions.NAME)); + + writer.write(""); + writer.writeDocs( + String.format("MinDelay is the minimum amount of time to delay between retries. " + + "If unset, %s will use default minimum delay of %s seconds. " + + "Note that MinDelay must resolve to a value lesser than or equal " + + "to the MaxDelay.", waiterClientName, waiter.getMinDelay()) + ); + writer.write("MinDelay time.Duration"); + + writer.write(""); + writer.writeDocs( + String.format("MaxDelay is the maximum amount of time to delay between retries. " + + "If unset or set to zero, %s will use default max delay of %s seconds. " + + "Note that MaxDelay must resolve to value greater than or equal " + + "to the MinDelay.", waiterClientName, waiter.getMaxDelay()) + ); + writer.write("MaxDelay time.Duration"); + + writer.write(""); + writer.writeDocs("LogWaitAttempts is used to enable logging for waiter retry attempts"); + writer.write("LogWaitAttempts bool"); + + writer.write(""); + writer.writeDocs( + "Retryable is function that can be used to override the " + + "service defined waiter-behavior based on operation output, or returned error. " + + "This function is used by the waiter to decide if a state is retryable " + + "or a terminal state.\n\nBy default service-modeled logic " + + "will populate this option. This option can thus be used to define a custom " + + "waiter state with fall-back to service-modeled waiter state mutators." + + "The function returns an error in case of a failure state. " + + "In case of retry state, this function returns a bool value of true and " + + "nil error, while in case of success it returns a bool value of false and " + + "nil error." + ); + writer.write( + "Retryable func(context.Context, $P, $P, error) " + + "(bool, error)", inputSymbol, outputSymbol); + } + ); + writer.write(""); + } + + + /** + * Generates waiter client used to invoke waiter function. The waiter client is specific to a modeled waiter. + * Each waiter client is unique within a enclosure of a service. + * This function also generates a waiter client constructor that takes in a API client interface, and waiter options + * to configure a waiter client. + */ + private void generateWaiterClient( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + Symbol operationSymbol = symbolProvider.toSymbol(operationShape); + String clientName = generateWaiterClientName(waiterName); + + writer.write(""); + writer.writeDocs( + String.format("%s defines the waiters for %s", clientName, waiterName) + ); + writer.openBlock("type $L struct {", "}", + clientName, () -> { + writer.write(""); + writer.write("client $L", OperationInterfaceGenerator.getApiClientInterfaceName(operationSymbol)); + + writer.write(""); + writer.write("options $L", generateWaiterOptionsName(waiterName)); + }); + + writer.write(""); + + String constructorName = String.format("New%s", clientName); + + Symbol waiterOptionsSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterOptionsName(waiterName) + ).build(); + + Symbol clientSymbol = SymbolUtils.createPointableSymbolBuilder( + clientName + ).build(); + + writer.writeDocs( + String.format("%s constructs a %s.", constructorName, clientName) + ); + writer.openBlock("func $L(client $L, optFns ...func($P)) $P {", "}", + constructorName, OperationInterfaceGenerator.getApiClientInterfaceName(operationSymbol), + waiterOptionsSymbol, clientSymbol, () -> { + writer.write("options := $T{}", waiterOptionsSymbol); + writer.addUseImports(SmithyGoDependency.TIME); + + // set defaults + writer.write("options.MinDelay = $L * time.Second", waiter.getMinDelay()); + writer.write("options.MaxDelay = $L * time.Second", waiter.getMaxDelay()); + writer.write("options.Retryable = $L", generateRetryableName(waiterName)); + writer.write(""); + + writer.openBlock("for _, fn := range optFns {", + "}", () -> { + writer.write("fn(&options)"); + }); + + writer.openBlock("return &$T {", "}", clientSymbol, () -> { + writer.write("client: client, "); + writer.write("options: options, "); + }); + }); + } + + /** + * Generates waiter invoker functions to call specific operation waiters + * These waiter invoker functions is defined on each modeled waiter client. + * The invoker function takes in a context, along with operation input, and + * optional functional options for the waiter. + */ + private void generateWaiterInvoker( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + StructureShape inputShape = model.expectShape( + operationShape.getInput().get(), StructureShape.class + ); + + Symbol inputSymbol = symbolProvider.toSymbol(inputShape); + + Symbol waiterOptionsSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterOptionsName(waiterName) + ).build(); + + Symbol clientSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterClientName(waiterName) + ).build(); + + writer.write(""); + writer.addUseImports(SmithyGoDependency.CONTEXT); + writer.addUseImports(SmithyGoDependency.TIME); + writer.writeDocs( + String.format( + "%s calls the waiter function for %s waiter. The maxWaitDur is the maximum wait duration " + + "the waiter will wait. The maxWaitDur is required and must be greater than zero.", + WAITER_INVOKER_FUNCTION_NAME, waiterName) + ); + writer.openBlock( + "func (w $P) $L(ctx context.Context, params $P, maxWaitDur time.Duration, optFns ...func($P)) error {", + "}", + clientSymbol, WAITER_INVOKER_FUNCTION_NAME, inputSymbol, waiterOptionsSymbol, + () -> { + writer.write( + "_, err := w.$L(ctx, params, maxWaitDur, optFns...)", + WAITER_INVOKER_WITH_OUTPUT_FUNCTION_NAME + ); + + writer.write("return err"); + }); + } + + /** + * Generates waiter invoker functions to call specific operation waiters + * and return the output of the successful operation. + * These waiter invoker functions is defined on each modeled waiter client. + * The invoker function takes in a context, along with operation input, and + * optional functional options for the waiter. + */ + private void generateWaiterInvokerWithOutput( + Model model, + SymbolProvider symbolProvider, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + StructureShape inputShape = model.expectShape( + operationShape.getInput().get(), StructureShape.class + ); + + StructureShape outputShape = model.expectShape( + operationShape.getOutput().get(), StructureShape.class + ); + + Symbol operationSymbol = symbolProvider.toSymbol(operationShape); + Symbol inputSymbol = symbolProvider.toSymbol(inputShape); + Symbol outputSymbol = symbolProvider.toSymbol(outputShape); + + Symbol waiterOptionsSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterOptionsName(waiterName) + ).build(); + + Symbol clientSymbol = SymbolUtils.createPointableSymbolBuilder( + generateWaiterClientName(waiterName) + ).build(); + + writer.write(""); + writer.addUseImports(SmithyGoDependency.CONTEXT); + writer.addUseImports(SmithyGoDependency.TIME); + writer.writeDocs( + String.format( + "%s calls the waiter function for %s waiter and returns the output of the successful " + + "operation. The maxWaitDur is the maximum wait duration the waiter will wait. The " + + "maxWaitDur is required and must be greater than zero.", + WAITER_INVOKER_WITH_OUTPUT_FUNCTION_NAME, waiterName) + ); + writer.openBlock( + "func (w $P) $L(ctx context.Context, params $P, maxWaitDur time.Duration, optFns ...func($P)) " + + "($P, error) {", + "}", + clientSymbol, WAITER_INVOKER_WITH_OUTPUT_FUNCTION_NAME, inputSymbol, waiterOptionsSymbol, outputSymbol, + () -> { + writer.openBlock("if maxWaitDur <= 0 {", "}", () -> { + writer.addUseImports(SmithyGoDependency.FMT); + writer.write( + "return nil, fmt.Errorf(\"maximum wait time for waiter must be greater than zero\")" + ); + }).write(""); + + writer.write("options := w.options"); + + writer.openBlock("for _, fn := range optFns {", + "}", () -> { + writer.write("fn(&options)"); + }); + writer.write(""); + + // validate values for MaxDelay from options + writer.openBlock("if options.MaxDelay <= 0 {", "}", () -> { + writer.write("options.MaxDelay = $L * time.Second", waiter.getMaxDelay()); + }); + writer.write(""); + + // validate that MinDelay is lesser than or equal to resolved MaxDelay + writer.openBlock("if options.MinDelay > options.MaxDelay {", "}", () -> { + writer.addUseImports(SmithyGoDependency.FMT); + writer.write("return nil, fmt.Errorf(\"minimum waiter delay %v must be lesser than or equal to " + + "maximum waiter delay of %v.\", options.MinDelay, options.MaxDelay)"); + }).write(""); + + writer.addUseImports(SmithyGoDependency.CONTEXT); + writer.write("ctx, cancelFn := context.WithTimeout(ctx, maxWaitDur)"); + writer.write("defer cancelFn()"); + writer.write(""); + + Symbol loggerMiddleware = SymbolUtils.createValueSymbolBuilder( + "Logger", SmithyGoDependency.SMITHY_WAITERS + ).build(); + writer.write("logger := $T{}", loggerMiddleware); + writer.write("remainingTime := maxWaitDur").write(""); + + writer.write("var attempt int64"); + writer.openBlock("for {", "}", () -> { + writer.write(""); + writer.write("attempt++"); + + writer.write("apiOptions := options.APIOptions"); + writer.write("start := time.Now()").write(""); + + // add waiter logger middleware to log an attempt, if LogWaitAttempts is enabled. + writer.openBlock("if options.LogWaitAttempts {", "}", () -> { + writer.write("logger.Attempt = attempt"); + writer.write( + "apiOptions = append([]func(*middleware.Stack) error{}, options.APIOptions...)"); + writer.write("apiOptions = append(apiOptions, logger.AddLogger)"); + }).write(""); + + // make a request + var baseOpts = GoWriter.ChainWritable.of( + getAdditionalClientOptions().stream() + .map(it -> goTemplate("$T,", it)) + .toList() + ).compose(false); + writer.openBlock("out, err := w.client.$T(ctx, params, func (o *Options) { ", "})", + operationSymbol, () -> { + writer.write(""" + baseOpts := []func(*Options) { + $W + }""", baseOpts); + writer.write("o.APIOptions = append(o.APIOptions, apiOptions...)"); + writer.write(""" + for _, opt := range baseOpts { + opt(o) + } + for _, opt := range options.ClientOptions { + opt(o) + }"""); + }); + writer.write(""); + + // handle response and identify waiter state + writer.write("retryable, err := options.Retryable(ctx, params, out, err)"); + writer.write("if err != nil { return nil, err }"); + writer.write("if !retryable { return out, nil }").write(""); + + // update remaining time + writer.write("remainingTime -= time.Since(start)"); + + // check if next iteration is possible + writer.openBlock("if remainingTime < options.MinDelay || remainingTime <= 0 {", "}", () -> { + writer.write("break"); + }); + writer.write(""); + + // handle retry delay computation, sleep. + Symbol computeDelaySymbol = SymbolUtils.createValueSymbolBuilder( + "ComputeDelay", SmithyGoDependency.SMITHY_WAITERS + ).build(); + writer.writeDocs("compute exponential backoff between waiter retries"); + writer.openBlock("delay, err := $T(", ")", computeDelaySymbol, () -> { + writer.write("attempt, options.MinDelay, options.MaxDelay, remainingTime,"); + }); + + writer.addUseImports(SmithyGoDependency.FMT); + writer.write( + "if err != nil { return nil, fmt.Errorf(\"error computing waiter delay, %w\", err)}"); + writer.write(""); + + // update remaining time as per computed delay + writer.write("remainingTime -= delay"); + + // sleep for delay + Symbol sleepWithContextSymbol = SymbolUtils.createValueSymbolBuilder( + "SleepWithContext", SmithyGoDependency.SMITHY_TIME + ).build(); + writer.writeDocs("sleep for the delay amount before invoking a request"); + writer.openBlock("if err := $T(ctx, delay); err != nil {", "}", sleepWithContextSymbol, + () -> { + writer.write( + "return nil, fmt.Errorf(\"request cancelled while waiting, %w\", err)"); + }); + }); + writer.write("return nil, fmt.Errorf(\"exceeded max wait time for $L waiter\")", waiterName); + }); + } + + /** + * Generates a waiter state mutator function which is used by the waiter retrier Middleware to mutate + * waiter state as per the defined logic and returned operation response. + * + * @param ctx the GoCodegenContext + * @param writer the Gowriter + * @param operationShape operation shape on which the waiter is modeled + * @param waiterName the waiter name + * @param waiter the waiter structure that contains info on modeled waiter + */ + private void generateRetryable( + GoCodegenContext ctx, + GoWriter writer, + OperationShape operationShape, + String waiterName, + Waiter waiter + ) { + var model = ctx.model(); + var symbolProvider = ctx.symbolProvider(); + var serviceShape = ctx.settings().getService(model); + StructureShape inputShape = model.expectShape( + operationShape.getInput().get(), StructureShape.class + ); + StructureShape outputShape = model.expectShape( + operationShape.getOutput().get(), StructureShape.class + ); + + Symbol inputSymbol = symbolProvider.toSymbol(inputShape); + Symbol outputSymbol = symbolProvider.toSymbol(outputShape); + + writer.write(""); + writer.openBlock("func $L(ctx context.Context, input $P, output $P, err error) (bool, error) {", + "}", generateRetryableName(waiterName), inputSymbol, outputSymbol, () -> { + waiter.getAcceptors().forEach(acceptor -> { + writer.write(""); + // scope each acceptor to avoid name collisions + Matcher matcher = acceptor.getMatcher(); + switch (matcher.getMemberName()) { + case "output": + writer.addUseImports(SmithyGoDependency.FMT); + + Matcher.OutputMember outputMember = (Matcher.OutputMember) matcher; + String path = outputMember.getValue().getPath(); + String expectedValue = outputMember.getValue().getExpected(); + PathComparator comparator = outputMember.getValue().getComparator(); + writer.openBlock("if err == nil {", "}", () -> { + var pathInput = new GoJmespathExpressionGenerator.Variable(outputShape, "output"); + var searchResult = new GoJmespathExpressionGenerator(ctx, writer) + .generate(JmespathExpression.parse(path), pathInput); + + writer.write("expectedValue := $S", expectedValue); + writeWaiterComparator(writer, acceptor, comparator, searchResult); + }); + break; + + case "inputOutput": + writer.addUseImports(SmithyGoDependency.GO_JMESPATH); + writer.addUseImports(SmithyGoDependency.FMT); + + Matcher.InputOutputMember ioMember = (Matcher.InputOutputMember) matcher; + path = ioMember.getValue().getPath(); + expectedValue = ioMember.getValue().getExpected(); + comparator = ioMember.getValue().getComparator(); + + // inputOutput matchers operate on a synthetic structure with operation input and output + // as top-level fields - we set that up here both in codegen for jmespathing and for + // the actual generated code to work + var inputOutputShape = StructureShape.builder() + .addMember("input", inputShape.toShapeId()) + .addMember("output", outputShape.toShapeId()) + .build(); + writer.write(""" + inputOutput := struct{ + Input $P + Output $P + }{ + Input: input, + Output: output, + } + """); + + writer.openBlock("if err == nil {", "}", () -> { + var pathInput = new GoJmespathExpressionGenerator.Variable( + inputOutputShape, "inputOutput"); + var searchResult = new GoJmespathExpressionGenerator(ctx, writer) + .generate(JmespathExpression.parse(path), pathInput); + + writer.write("expectedValue := $S", expectedValue); + writeWaiterComparator(writer, acceptor, comparator, searchResult); + }); + break; + + case "success": + Matcher.SuccessMember successMember = (Matcher.SuccessMember) matcher; + writer.openBlock("if err == nil {", "}", + () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case "errorType": + Matcher.ErrorTypeMember errorTypeMember = (Matcher.ErrorTypeMember) matcher; + String errorType = errorTypeMember.getValue(); + + writer.openBlock("if err != nil {", "}", () -> { + + // identify if this is a modeled error shape + Optional errorShapeId = operationShape.getErrors().stream().filter( + shapeId -> { + return shapeId.getName(serviceShape).equalsIgnoreCase(errorType); + }).findFirst(); + + // if modeled error shape + if (errorShapeId.isPresent()) { + Shape errorShape = model.expectShape(errorShapeId.get()); + Symbol modeledErrorSymbol = symbolProvider.toSymbol(errorShape); + writer.addUseImports(SmithyGoDependency.ERRORS); + writer.write("var errorType *$T", modeledErrorSymbol); + writer.openBlock("if errors.As(err, &errorType) {", "}", () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + } else { + // fall back to un-modeled error shape matching + writer.addUseImports(SmithyGoDependency.SMITHY); + writer.addUseImports(SmithyGoDependency.ERRORS); + + // assert unmodeled error to smithy's API error + writer.write("var apiErr smithy.APIError"); + writer.write("ok := errors.As(err, &apiErr)"); + writer.openBlock("if !ok {", "}", () -> { + writer.write("return false, " + + "fmt.Errorf(\"expected err to be of type smithy.APIError, " + + "got %w\", err)"); + }); + writer.write(""); + + writer.openBlock("if $S == apiErr.ErrorCode() {", "}", + errorType, () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + } + }); + break; + + default: + throw new CodegenException( + String.format("unknown waiter state : %v", matcher.getMemberName()) + ); + } + }); + + writer.write(""); + writer.write("return true, nil"); + }); + } + + private void writeWaiterComparator(GoWriter writer, Acceptor acceptor, PathComparator comparator, + GoJmespathExpressionGenerator.Variable searchResult) { + switch (comparator) { + case STRING_EQUALS: + writer.write("var pathValue string"); + if (!isPointable(searchResult.type())) { + writer.write("pathValue = string($L)", searchResult.ident()); + } else { + writer.write(""" + if $1L != nil { + pathValue = string(*$1L) + }""", searchResult.ident()); + } + writer.openBlock("if pathValue == expectedValue {", "}", () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case BOOLEAN_EQUALS: + writer.addUseImports(SmithyGoDependency.STRCONV); + writer.write("bv, err := strconv.ParseBool($L)", "expectedValue"); + writer.write( + "if err != nil { return false, " + + "fmt.Errorf(\"error parsing boolean from string %w\", err)}"); + + writer.openBlock("if $L == bv {", "}", searchResult.ident(), () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case ALL_STRING_EQUALS: + writer.write("match := len($L) > 0", searchResult.ident()); + writer.openBlock("for _, v := range $L {", "}", searchResult.ident(), () -> { + writer.write(""" + if string(v) != expectedValue { + match = false + break + }"""); + }); + writer.write(""); + + writer.openBlock("if match {", "}", () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + case ANY_STRING_EQUALS: + writer.write("var match bool"); + writer.openBlock("for _, v := range $L {", "}", searchResult.ident(), () -> { + writer.write(""" + if string(v) == expectedValue { + match = true + break + }"""); + }); + writer.write(""); + + writer.openBlock("if match {", "}", () -> { + writeMatchedAcceptorReturn(writer, acceptor); + }); + break; + + default: + throw new CodegenException( + String.format("Found unknown waiter path comparator, %s", comparator.toString())); + } + } + + + /** + * Writes return statement for state where a waiter's acceptor state is a match. + * + * @param writer the Go writer + * @param acceptor the waiter acceptor who's state is used to write an appropriate return statement. + */ + private void writeMatchedAcceptorReturn(GoWriter writer, Acceptor acceptor) { + switch (acceptor.getState()) { + case SUCCESS: + writer.write("return false, nil"); + break; + + case FAILURE: + writer.addUseImports(SmithyGoDependency.FMT); + writer.write("return false, fmt.Errorf(\"waiter state transitioned to Failure\")"); + break; + + case RETRY: + writer.write("return true, nil"); + break; + + default: + throw new CodegenException("unknown acceptor state defined for the waiter"); + } + } + + private String generateWaiterOptionsName( + String waiterName + ) { + waiterName = StringUtils.capitalize(waiterName); + return String.format("%sWaiterOptions", waiterName); + } + + private String generateWaiterClientName( + String waiterName + ) { + waiterName = StringUtils.capitalize(waiterName); + return String.format("%sWaiter", waiterName); + } + + private String generateRetryableName( + String waiterName + ) { + waiterName = StringUtils.uncapitalize(waiterName); + return String.format("%sStateRetryable", waiterName); + } +} diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/util/ShapeUtil.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/util/ShapeUtil.java index 68944603..fc03945a 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/util/ShapeUtil.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/util/ShapeUtil.java @@ -17,8 +17,11 @@ import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.BooleanShape; import software.amazon.smithy.model.shapes.CollectionShape; +import software.amazon.smithy.model.shapes.IntegerShape; import software.amazon.smithy.model.shapes.ListShape; +import software.amazon.smithy.model.shapes.MapShape; import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.shapes.StringShape; @@ -27,6 +30,14 @@ public final class ShapeUtil { .id("smithy.go.synthetic#String") .build(); + public static final IntegerShape INT_SHAPE = IntegerShape.builder() + .id("smithy.api#Integer") + .build(); + + public static final BooleanShape BOOL_SHAPE = BooleanShape.builder() + .id("smithy.api#Boolean") + .build(); + private ShapeUtil() {} public static ListShape listOf(Shape member) { @@ -49,4 +60,8 @@ public static Shape expectMember(Model model, Shape shape, String memberName) { public static Shape expectMember(Model model, CollectionShape shape) { return model.expectShape(shape.getMember().getTarget()); } + + public static Shape expectMember(Model model, MapShape shape) { + return model.expectShape(shape.getValue().getTarget()); + } } diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java index aecd1060..0c77721d 100644 --- a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java @@ -16,6 +16,7 @@ package software.amazon.smithy.go.codegen; import static org.hamcrest.MatcherAssert.assertThat; +import static software.amazon.smithy.go.codegen.util.ShapeUtil.listOf; import org.hamcrest.Matchers; import org.junit.jupiter.api.Test; @@ -33,8 +34,12 @@ public class GoJmespathExpressionGeneratorTest { namespace smithy.go.test + service Test { + } + structure Struct { simpleShape: String + simpleShape2: String objectList: ObjectList objectMap: ObjectMap nested: NestedStruct @@ -42,6 +47,11 @@ public class GoJmespathExpressionGeneratorTest { structure Object { key: String + innerObjectList: InnerObjectList + } + + structure InnerObject { + innerKey: String } structure NestedStruct { @@ -52,6 +62,10 @@ public class GoJmespathExpressionGeneratorTest { member: Object } + list InnerObjectList { + member: InnerObject + } + map ObjectMap { key: String, value: Object @@ -63,10 +77,14 @@ public class GoJmespathExpressionGeneratorTest { .assemble().unwrap(); private static final GoSettings TEST_SETTINGS = GoSettings.from(ObjectNode.fromStringMap(Map.of( - "service", "smithy.go.test#foo", + "service", "smithy.go.test#Test", "module", "github.com/aws/aws-sdk-go-v2/test" ))); + private static GoWriter testWriter() { + return new GoWriter("test").setIndentText(" "); // for ease of string comparison + } + private static GoCodegenContext testContext() { return new GoCodegenContext( TEST_MODEL, TEST_SETTINGS, @@ -79,17 +97,16 @@ private static GoCodegenContext testContext() { public void testFieldExpression() { var expr = "simpleShape"; - var writer = new GoWriter("test"); - var generator = new GoJmespathExpressionGenerator(testContext(), writer, + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), - JmespathExpression.parse(expr) - ); - var actual = generator.generate("input"); + "input" + )); assertThat(actual.shape(), Matchers.equalTo(TEST_MODEL.expectShape(ShapeId.from("smithy.api#String")))); - assertThat(actual.ident(), Matchers.equalTo("v2")); + assertThat(actual.ident(), Matchers.equalTo("v1")); assertThat(writer.toString(), Matchers.containsString(""" - v1 := input - v2 := v1.SimpleShape + v1 := input.SimpleShape """)); } @@ -97,18 +114,17 @@ public void testFieldExpression() { public void testSubexpression() { var expr = "nested.nestedField"; - var writer = new GoWriter("test"); - var generator = new GoJmespathExpressionGenerator(testContext(), writer, + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), - JmespathExpression.parse(expr) - ); - var actual = generator.generate("input"); + "input" + )); assertThat(actual.shape(), Matchers.equalTo(TEST_MODEL.expectShape(ShapeId.from("smithy.api#String")))); - assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(actual.ident(), Matchers.equalTo("v2")); assertThat(writer.toString(), Matchers.containsString(""" - v1 := input - v2 := v1.Nested - v3 := v2.NestedField + v1 := input.Nested + v2 := v1.NestedField """)); } @@ -116,20 +132,19 @@ public void testSubexpression() { public void testKeysFunctionExpression() { var expr = "keys(objectMap)"; - var writer = new GoWriter("test"); - var generator = new GoJmespathExpressionGenerator(testContext(), writer, + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), - JmespathExpression.parse(expr) - ); - var actual = generator.generate("input"); - assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.listOf(ShapeUtil.STRING_SHAPE))); - assertThat(actual.ident(), Matchers.equalTo("v3")); + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(listOf(ShapeUtil.STRING_SHAPE))); + assertThat(actual.ident(), Matchers.equalTo("v2")); assertThat(writer.toString(), Matchers.containsString(""" - v1 := input - v2 := v1.ObjectMap - var v3 []string - for k := range v2 { - v3 = append(v3, k) + v1 := input.ObjectMap + var v2 []string + for k := range v1 { + v2 = append(v2, k) } """)); } @@ -138,24 +153,343 @@ public void testKeysFunctionExpression() { public void testProjectionExpression() { var expr = "objectList[*].key"; - var writer = new GoWriter("test"); - var generator = new GoJmespathExpressionGenerator(testContext(), writer, + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), - JmespathExpression.parse(expr) - ); - var actual = generator.generate("input"); + "input" + )); assertThat(actual.shape(), Matchers.equalTo( - ShapeUtil.listOf(TEST_MODEL.expectShape(ShapeId.from("smithy.api#String"))))); - assertThat(actual.ident(), Matchers.equalTo("v3")); + listOf(TEST_MODEL.expectShape(ShapeId.from("smithy.api#String"))))); + assertThat(actual.ident(), Matchers.equalTo("v2")); assertThat(writer.toString(), Matchers.containsString(""" - v1 := input - v2 := v1.ObjectList - var v3 []*string + v1 := input.ObjectList + var v2 []string + for _, v := range v1 { + v3 := v.Key + if v3 != nil { + v2 = append(v2, *v3) + } + } + """)); + } + + @Test + public void testNopFlattenExpression() { + var expr = "objectList[].key"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo( + listOf(TEST_MODEL.expectShape(ShapeId.from("smithy.api#String"))))); + assertThat(actual.ident(), Matchers.equalTo("v2")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + var v2 []string + for _, v := range v1 { + v3 := v.Key + if v3 != nil { + v2 = append(v2, *v3) + } + } + """)); + } + + @Test + public void testActualFlattenExpression() { + var expr = "objectList[].innerObjectList[].innerKey"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo( + listOf(TEST_MODEL.expectShape(ShapeId.from("smithy.api#String"))))); + assertThat(actual.ident(), Matchers.equalTo("v5")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + var v2 [][]types.InnerObject + for _, v := range v1 { + v3 := v.InnerObjectList + v2 = append(v2, v3) + } + var v4 []types.InnerObject for _, v := range v2 { - v1 := v - v2 := v1.Key - v3 = append(v3, v2) + v4 = append(v4, v...) + } + var v5 []string + for _, v := range v4 { + v6 := v.InnerKey + if v6 != nil { + v5 = append(v5, *v6) + } } """)); } + + @Test + public void testLengthFunctionExpression() { + var expr = "length(objectList)"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.INT_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v2")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + v2 := len(v1) + """)); + } + + @Test + public void testLengthFunctionStringPtr() { + var expr = "length(simpleShape)"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.INT_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v2")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.SimpleShape + var _v1 string + if v1 != nil { + _v1 = *v1 + } + v2 := len(_v1) + """)); + } + + @Test + public void testComparatorInt() { + var expr = "length(objectList) > `99`"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v4")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + v2 := len(v1) + v3 := 99 + v4 := int64(v2) > int64(v3) + """)); + } + + @Test + public void testComparatorStringLHSNil() { + var expr = "nested.nestedField == 'foo'"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v4")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.Nested + v2 := v1.NestedField + v3 := "foo" + var v4 bool + if v2 != nil { + v4 = string(*v2) == string(v3) + } + """)); + } + + @Test + public void testComparatorStringRHSNil() { + var expr = "'foo' == nested.nestedField"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v4")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := "foo" + v2 := input.Nested + v3 := v2.NestedField + var v4 bool + if v3 != nil { + v4 = string(v1) == string(*v3) + } + """)); + } + + @Test + public void testComparatorStringBothNil() { + var expr = "nested.nestedField == simpleShape"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v4")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.Nested + v2 := v1.NestedField + v3 := input.SimpleShape + var v4 bool + if v2 != nil && v3 != nil { + v4 = string(*v2) == string(*v3) + } + """)); + } + + @Test + public void testContainsFunctionExpression() { + var expr = "contains(objectList[].key, 'foo')"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v5")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + var v2 []string + for _, v := range v1 { + v3 := v.Key + if v3 != nil { + v2 = append(v2, *v3) + } + } + v4 := "foo" + var v5 bool + for _, v := range v2 { + if v == v4 { + v5 = true + break + } + } + """)); + } + + @Test + public void testAndExpression() { + var expr = "length(objectList) > `0` && length(objectList) <= `10`"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE)); + assertThat(actual.ident(), Matchers.equalTo("v9")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + v2 := len(v1) + v3 := 0 + v4 := int64(v2) > int64(v3) + v5 := input.ObjectList + v6 := len(v5) + v7 := 10 + v8 := int64(v6) <= int64(v7) + v9 := v4 && v8 + """)); + } + + @Test + public void testFilterExpression() { + var expr = "objectList[?length(innerObjectList) > `0`]"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#ObjectList")))); + assertThat(actual.ident(), Matchers.equalTo("v2")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + var v2 []types.Object + for _, v := range v1 { + v3 := v.InnerObjectList + v4 := len(v3) + v5 := 0 + v6 := int64(v4) > int64(v5) + if v6 { + v2 = append(v2, v) + } + } + """)); + } + + @Test + public void testNot() { + var expr = "objectList[?!(length(innerObjectList) > `0`)]"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape(), Matchers.equalTo(TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#ObjectList")))); + assertThat(actual.ident(), Matchers.equalTo("v2")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.ObjectList + var v2 []types.Object + for _, v := range v1 { + v3 := v.InnerObjectList + v4 := len(v3) + v5 := 0 + v6 := int64(v4) > int64(v5) + v7 := !v6 + if v7 { + v2 = append(v2, v) + } + } + """)); + } + + @Test + public void testMultiSelect() { + var expr = "[simpleShape, simpleShape2]"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.go.synthetic#StringList")); + assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.SimpleShape + v2 := input.SimpleShape2 + v3 := []*string{v1,v2} + """)); + } }