From 705546a3fd4975d917e972b33596796ef2e739c0 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 30 Aug 2023 16:16:47 -0700 Subject: [PATCH] Static optimizer for constant folding (#804) * Optimizer API with Constant Folding implementatiton * Better logical folds and additional tests * Add a configurable limit to constant folding --- cel/BUILD.bazel | 3 + cel/env.go | 8 + cel/folding.go | 553 ++++++++++++++++++++++++++++++++++ cel/folding_test.go | 652 ++++++++++++++++++++++++++++++++++++++++ cel/optimizer.go | 317 +++++++++++++++++++ common/ast/ast.go | 13 +- common/ast/expr.go | 22 +- common/ast/factory.go | 11 +- common/ast/navigable.go | 4 + 9 files changed, 1575 insertions(+), 8 deletions(-) create mode 100644 cel/folding.go create mode 100644 cel/folding_test.go create mode 100644 cel/optimizer.go diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index aa978e06..62b903c8 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -10,9 +10,11 @@ go_library( "cel.go", "decls.go", "env.go", + "folding.go", "io.go", "library.go", "macro.go", + "optimizer.go", "options.go", "program.go", "validator.go", @@ -56,6 +58,7 @@ go_test( "cel_test.go", "decls_test.go", "env_test.go", + "folding_test.go", "io_test.go", "validator_test.go", ], diff --git a/cel/env.go b/cel/env.go index 113b89b5..786a13c4 100644 --- a/cel/env.go +++ b/cel/env.go @@ -43,6 +43,9 @@ type Ast struct { } // Expr returns the proto serializable instance of the parsed/checked expression. +// +// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr() +// the result instead. func (ast *Ast) Expr() *exprpb.Expr { if ast == nil { return nil @@ -221,6 +224,11 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { source: ast.Source(), impl: checked} + // Avoid creating a validator config if it's not needed. + if len(e.validators) == 0 { + return ast, nil + } + // Generate a validator configuration from the set of configured validators. vConfig := newValidatorConfig() for _, v := range e.validators { diff --git a/cel/folding.go b/cel/folding.go new file mode 100644 index 00000000..5903d732 --- /dev/null +++ b/cel/folding.go @@ -0,0 +1,553 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +// ConstantFoldingOption defines a functional option for configuring constant folding. +type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) + +// MaxConstantFoldIterations limits the number of times literals may be folding during optimization. +// +// Defaults to 100 if not set. 
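//
// A minimal usage sketch (the limit value and the surrounding env/checkedAST variables are
// illustrative, not part of this change):
//
//	folder, err := cel.NewConstantFoldingOptimizer(cel.MaxConstantFoldIterations(10))
//	if err != nil {
//		// handle the configuration error
//	}
//	optimized, iss := cel.NewStaticOptimizer(folder).Optimize(env, checkedAST)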
+func MaxConstantFoldIterations(limit int) ConstantFoldingOption { + return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) { + opt.maxFoldIterations = limit + return opt, nil + } +} + +// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate +// literal values within function calls and select statements with their evaluated result. +func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) { + folder := &constantFoldingOptimizer{ + maxFoldIterations: defaultMaxConstantFoldIterations, + } + var err error + for _, o := range opts { + folder, err = o(folder) + if err != nil { + return nil, err + } + } + return folder, nil +} + +type constantFoldingOptimizer struct { + maxFoldIterations int +} + +// Optimize queries the expression graph for scalar and aggregate literal expressions within call and +// select statements and then evaluates them and replaces the call site with the literal result. +// +// Note: only values which can be represented as literals in CEL syntax are supported. +func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + + // Walk the list of foldable expression and continue to fold until there are no more folds left. + // All of the fold candidates returned by the constantExprMatcher should succeed unless there's + // a logic bug with the selection of expressions. + foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + foldCount := 0 + for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations { + for _, fold := range foldableExprs { + // If the expression could be folded because it's a non-strict call, and the + // branches are pruned, continue to the next fold. + if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) { + continue + } + // Otherwise, assume all context is needed to evaluate the expression. + err := tryFold(ctx, a, fold) + if err != nil { + ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error()) + return a + } + } + foldCount++ + foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + } + // Once all of the constants have been folded, try to run through the remaining comprehensions + // one last time. In this case, there's no guarantee they'll run, so we only update the + // target comprehension node with the literal value if the evaluation succeeds. + for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) { + tryFold(ctx, a, compre) + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + pruneOptionalElements(ctx, root) + + // Ensure that all intermediate values in the folded expression can be represented as valid + // CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents` + // to avoid extra allocations during this final pass through the AST. + ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() != ast.LiteralKind { + return + } + val := e.AsLiteral() + adapted, err := adaptLiteral(ctx, val) + if err != nil { + ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error()) + return + } + e.SetKindCase(adapted) + })) + + return a +} + +// tryFold attempts to evaluate a sub-expression to a literal. 
+// +// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise +// the method will return an error. +func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { + // Assume all context is needed to evaluate the expression. + subAST := &Ast{ + impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()), + } + prg, err := ctx.Program(subAST) + if err != nil { + return err + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + return err + } + // Clear any macro metadata associated with the fold. + a.SourceInfo().ClearMacroCall(expr.ID()) + // Update the fold expression to be a literal. + expr.SetKindCase(ctx.NewLiteral(out)) + return nil +} + +// maybePruneBranches inspects the non-strict call expression to determine whether +// a branch can be removed. Evaluation will naturally prune logical and / or calls, +// but conditional will not be pruned cleanly, so this is one small area where the +// constant folding step reimplements a portion of the evaluator. +func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool { + call := expr.AsCall() + args := call.Args() + switch call.FunctionName() { + case operators.LogicalAnd, operators.LogicalOr: + return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr) + case operators.Conditional: + cond := args[0] + truthy := args[1] + falsy := args[2] + if cond.Kind() != ast.LiteralKind { + return false + } + if cond.AsLiteral() == types.True { + expr.SetKindCase(truthy) + } else { + expr.SetKindCase(falsy) + } + return true + case operators.In: + haystack := args[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + expr.SetKindCase(ctx.NewLiteral(types.False)) + return true + } + needle := args[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + expr.SetKindCase(ctx.NewLiteral(types.True)) + return true + } + } + } + } + return false +} + +func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool { + shortcircuit := types.False + skip := types.True + if function == operators.LogicalOr { + shortcircuit = types.True + skip = types.False + } + newArgs := []ast.Expr{} + for _, arg := range args { + if arg.Kind() != ast.LiteralKind { + newArgs = append(newArgs, arg) + continue + } + if arg.AsLiteral() == skip { + continue + } + if arg.AsLiteral() == shortcircuit { + expr.SetKindCase(arg) + return true + } + } + if len(newArgs) == 1 { + expr.SetKindCase(newArgs[0]) + return true + } + expr.SetKindCase(ctx.NewCall(function, newArgs...)) + return true +} + +// pruneOptionalElements works from the bottom up to resolve optional elements within +// aggregate literals. +// +// Note, many aggregate literals will be resolved as arguments to functions or select +// statements, so this method exists to handle the case where the literal could not be +// fully resolved or exists outside of a call, select, or comprehension context. 
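//
// For example (mirroring cases in the folding tests), once their optional values resolve:
//
//	[1, ?optional.ofNonZeroValue(0)]                          // prunes to [1]
//	{?'a': optional.of('hello')}                              // prunes to {"a": "hello"}
//	TestAllTypes{?single_int32: optional.ofNonZeroValue(0)}   // prunes to TestAllTypes{}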
+func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { + aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher) + for _, lit := range aggregateLiterals { + switch lit.Kind() { + case ast.ListKind: + pruneOptionalListElements(ctx, lit) + case ast.MapKind: + pruneOptionalMapEntries(ctx, lit) + case ast.StructKind: + pruneOptionalStructFields(ctx, lit) + } + } +} + +func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) { + l := e.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + if len(optIndices) == 0 { + return + } + updatedElems := []ast.Expr{} + updatedIndices := []int32{} + for i, e := range elems { + if !l.IsOptional(int32(i)) { + updatedElems = append(updatedElems, e) + continue + } + if e.Kind() != ast.LiteralKind { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + optElemVal, ok := e.AsLiteral().(*types.Optional) + if !ok { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + if !optElemVal.HasValue() { + continue + } + e.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedElems = append(updatedElems, e) + } + e.SetKindCase(ctx.NewList(updatedElems, updatedIndices)) +} + +func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { + m := e.AsMap() + entries := m.Entries() + updatedEntries := []ast.EntryExpr{} + modified := false + for _, e := range entries { + entry := e.AsMapEntry() + key := entry.Key() + val := entry.Value() + // If the entry is not optional, or the value-side of the optional hasn't + // been resolved to a literal, then preserve the entry as-is. + if !entry.IsOptional() || val.Kind() != ast.LiteralKind { + updatedEntries = append(updatedEntries, e) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedEntries = append(updatedEntries, e) + continue + } + // When the key is not a literal, but the value is, then it needs to be + // restored to an optional value. + if key.Kind() != ast.LiteralKind { + undoOptVal, err := adaptLiteral(ctx, optElemVal) + if err != nil { + ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err) + } + val.SetKindCase(undoOptVal) + updatedEntries = append(updatedEntries, e) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedEntry := ctx.NewMapEntry(key, val, false) + updatedEntries = append(updatedEntries, updatedEntry) + } + if modified { + e.SetKindCase(ctx.NewMap(updatedEntries)) + } +} + +func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) { + s := e.AsStruct() + fields := s.Fields() + updatedFields := []ast.EntryExpr{} + modified := false + for _, f := range fields { + field := f.AsStructField() + val := field.Value() + if !field.IsOptional() || val.Kind() != ast.LiteralKind { + updatedFields = append(updatedFields, f) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedFields = append(updatedFields, f) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedField := ctx.NewStructField(field.Name(), val, false) + updatedFields = append(updatedFields, updatedField) + } + if modified { + e.SetKindCase(ctx.NewStruct(s.TypeName(), updatedFields)) + } +} + +// adaptLiteral converts a runtime CEL value to its equivalent literal expression. 
+// +// For strongly typed values, the type-provider will be used to reconstruct the fields +// which are present in the literal and their equivalent initialization values. +func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { + switch t := val.Type().(type) { + case *types.Type: + switch t { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + return ctx.NewLiteral(val), nil + case types.DurationType: + return ctx.NewCall( + overloads.TypeConvertDuration, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.TimestampType: + return ctx.NewCall( + overloads.TypeConvertTimestamp, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.OptionalType: + opt := val.(*types.Optional) + if !opt.HasValue() { + return ctx.NewCall("optional.none"), nil + } + target, err := adaptLiteral(ctx, opt.GetValue()) + if err != nil { + return nil, err + } + return ctx.NewCall("optional.of", target), nil + case types.TypeType: + return ctx.NewIdent(val.(*types.Type).TypeName()), nil + case types.ListType: + l, ok := val.(traits.Lister) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + elems := make([]ast.Expr, l.Size().(types.Int)) + idx := 0 + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + elemExpr, err := adaptLiteral(ctx, elemVal) + if err != nil { + return nil, err + } + elems[idx] = elemExpr + idx++ + } + return ctx.NewList(elems, []int32{}), nil + case types.MapType: + m, ok := val.(traits.Mapper) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + entries := make([]ast.EntryExpr, m.Size().(types.Int)) + idx := 0 + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + keyExpr, err := adaptLiteral(ctx, keyVal) + if err != nil { + return nil, err + } + valVal := m.Get(keyVal) + valExpr, err := adaptLiteral(ctx, valVal) + if err != nil { + return nil, err + } + entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false) + idx++ + } + return ctx.NewMap(entries), nil + default: + provider := ctx.CELTypeProvider() + fields, found := provider.FindStructFieldNames(t.TypeName()) + if !found { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + tester := val.(traits.FieldTester) + indexer := val.(traits.Indexer) + fieldInits := []ast.EntryExpr{} + for _, f := range fields { + field := types.String(f) + if tester.IsSet(field) != types.True { + continue + } + fieldVal := indexer.Get(field) + fieldExpr, err := adaptLiteral(ctx, fieldVal) + if err != nil { + return nil, err + } + fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false)) + } + return ctx.NewStruct(t.TypeName(), fieldInits), nil + } + } + return nil, fmt.Errorf("failed to adapt %v to literal", val) +} + +// constantExprMatcher matches calls, select statements, and comprehensions whose arguments +// are all constant scalar or aggregate literal values. +// +// Only comprehensions which are not nested are included as possible constant folds, and only +// if all variables referenced in the comprehension stack exist are only iteration or +// accumulation variables. 
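//
// For example, `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 10)` is a fold candidate because `i`
// is the comprehension's own iteration variable, whereas a comprehension that references a
// free variable such as `x` (e.g. `[1, 2, 3].map(i, i * x)`) is not, since the value of `x`
// cannot be known at optimization time.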
+func constantExprMatcher(e ast.NavigableExpr) bool { + switch e.Kind() { + case ast.CallKind: + return constantCallMatcher(e) + case ast.SelectKind: + sel := e.AsSelect() // guaranteed to be a navigable value + return constantMatcher(sel.Operand().(ast.NavigableExpr)) + case ast.ComprehensionKind: + if isNestedComprehension(e) { + return false + } + vars := map[string]bool{} + constantExprs := true + visitor := ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() == ast.ComprehensionKind { + nested := e.AsComprehension() + vars[nested.AccuVar()] = true + vars[nested.IterVar()] = true + } + if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { + constantExprs = false + } + }) + ast.PreOrderVisit(e, visitor) + return constantExprs + default: + return false + } +} + +// constantCallMatcher identifies strict and non-strict calls which can be folded. +func constantCallMatcher(e ast.NavigableExpr) bool { + call := e.AsCall() + children := e.Children() + fnName := call.FunctionName() + if fnName == operators.LogicalAnd { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.LogicalOr { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.Conditional { + cond := children[0] + if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType { + return true + } + } + if fnName == operators.In { + haystack := children[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + return true + } + needle := children[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + return true + } + } + } + } + // convert all other calls with constant arguments + for _, child := range children { + if !constantMatcher(child) { + return false + } + } + return true +} + +func isNestedComprehension(e ast.NavigableExpr) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.ComprehensionKind { + return true + } + parent, found = parent.Parent() + } + return false +} + +func aggregateLiteralMatcher(e ast.NavigableExpr) bool { + return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind +} + +var ( + constantMatcher = ast.ConstantValueMatcher() +) + +const ( + defaultMaxConstantFoldIterations = 100 +) diff --git a/cel/folding_test.go b/cel/folding_test.go new file mode 100644 index 00000000..18fe4d58 --- /dev/null +++ b/cel/folding_test.go @@ -0,0 +1,652 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cel + +import ( + "reflect" + "sort" + "testing" + + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + + "github.com/google/cel-go/common/ast" + + proto3pb "github.com/google/cel-go/test/proto3pb" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +func TestConstantFoldingOptimizer(t *testing.T) { + tests := []struct { + expr string + folded string + }{ + { + expr: `[1, 1 + 2, 1 + (2 + 3)]`, + folded: `[1, 3, 6]`, + }, + { + expr: `6 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `true`, + }, + { + expr: `5 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `false`, + }, + { + expr: `x in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `x in [1, 3, 6]`, + }, + { + expr: `1 in [1, x + 2, 1 + (2 + 3)]`, + folded: `true`, + }, + { + expr: `1 in [x, x + 2, 1 + (2 + 3)]`, + folded: `1 in [x, x + 2, 6]`, + }, + { + expr: `x in []`, + folded: `false`, + }, + { + expr: `{'hello': 'world'}.hello == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}.?hello.orValue('default') == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `optional.of("hello")`, + folded: `optional.of("hello")`, + }, + { + expr: `optional.ofNonZeroValue("")`, + folded: `optional.none()`, + }, + { + expr: `{?'hello': optional.of('world')}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `duration(string(7 * 24) + 'h')`, + folded: `duration("604800s")`, + }, + { + expr: `timestamp("1970-01-01T00:00:00Z")`, + folded: `timestamp("1970-01-01T00:00:00Z")`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 10)`, + folded: `true`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 1 % 2)`, + folded: `false`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, + folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == 0))`, + folded: `[[2], [2, 4, 6], [6]]`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))`, + folded: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has(m.a))`, + folded: `[{"a": 1}]`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has({'a': true}.a))`, + folded: `[{}, {"a": 1}, {"b": 2}]`, + }, + { + expr: `type(1)`, + folded: `int`, + }, + { + expr: `[google.expr.proto3.test.TestAllTypes{single_int32: 2 + 3}].map(i, i)[0]`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 5}`, + }, + { + expr: `[1, ?optional.ofNonZeroValue(0)]`, + folded: `[1]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y]`, + folded: `[1, x, 3, ?x.?y]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y].size() > 3`, + folded: `[1, x, 3, ?x.?y].size() > 3`, + }, + { + expr: `{?'a': optional.of('hello'), ?x : optional.of(1), ?'b': optional.none()}`, + folded: `{"a": "hello", ?x: optional.of(1)}`, + }, + { + expr: `true ? x + 1 : x + 2`, + folded: `x + 1`, + }, + { + expr: `false ? x + 1 : x + 2`, + folded: `x + 2`, + }, + { + expr: `false ? 
x + 'world' : 'hello' + 'world'`, + folded: `"helloworld"`, + }, + { + expr: `true && x`, + folded: `x`, + }, + { + expr: `x && true`, + folded: `x`, + }, + { + expr: `false && x`, + folded: `false`, + }, + { + expr: `x && false`, + folded: `false`, + }, + { + expr: `true || x`, + folded: `true`, + }, + { + expr: `x || true`, + folded: `true`, + }, + { + expr: `false || x`, + folded: `x`, + }, + { + expr: `x || false`, + folded: `x`, + }, + { + expr: `true && x && true && x`, + folded: `x && x`, + }, + { + expr: `false || x || false || x`, + folded: `x || x`, + }, + { + expr: `null`, + folded: `null`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(1)}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(0)}`, + folded: `google.expr.proto3.test.TestAllTypes{}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + }, + { + expr: `x + dyn([1, 2] + [3, 4])`, + folded: `x + [1, 2, 3, 4]`, + }, + { + expr: `dyn([1, 2]) + [3.0, 4.0]`, + folded: `[1, 2, 3.0, 4.0]`, + }, + { + expr: `{'a': dyn([1, 2]), 'b': x}`, + folded: `{"a": [1, 2], "b": x}`, + }, + { + expr: `1 + x + 2 == 2 + x + 1`, + folded: `1 + x + 2 == 2 + x + 1`, + }, + { + // The order of operations makes it such that the appearance of x in the first means that + // none of the values provided into the addition call will be folded with the current + // implementation. Ideally, the result would be 3 + x == x + 3 (which could be trivially true + // and more easily observed as a result of common subexpression eliminiation) + expr: `1 + 2 + x == x + 2 + 1`, + folded: `3 + x == x + 2 + 1`, + }, + } + e, err := NewEnv( + OptionalTypes(), + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + folder, err := NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := AstToString(optimized) + if err != nil { + t.Fatalf("AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + +func TestConstantFoldingOptimizerWithLimit(t *testing.T) { + tests := []struct { + expr string + limit int + folded string + }{ + { + expr: `[1, 1 + 2, 1 + (2 + 3)]`, + limit: 1, + folded: `[1, 3, 1 + 5]`, + }, + { + expr: `5 in [1, 1 + 2, 1 + (2 + 3)]`, + limit: 2, + folded: `5 in [1, 3, 6]`, + }, + { + // though more complex, the final tryFold() at the end of the optimization pass + // results in this computed output. 
+ expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, + limit: 1, + folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, + }, + } + e, err := NewEnv( + OptionalTypes(), + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + folder, err := NewConstantFoldingOptimizer(MaxConstantFoldIterations(tc.limit)) + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := AstToString(optimized) + if err != nil { + t.Fatalf("AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + +func TestConstantFoldingNormalizeIDs(t *testing.T) { + tests := []struct { + expr string + ids []int64 + macros map[int64]string + normalizedIDs []int64 + normalizedMacros map[int64]string + }{ + { + expr: `[1, 2, 3]`, + ids: []int64{1, 2, 3, 4}, + normalizedIDs: []int64{1, 2, 3, 4}, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: 0}`, + ids: []int64{1, 2, 3}, + normalizedIDs: []int64{1, 2, 3}, + }, + { + expr: `has({x: 'value'}.single_int32)`, + ids: []int64{2, 3, 4, 5, 7}, + macros: map[int64]string{7: ` + call_expr: { + function: "has" + args: { + id: 6 + select_expr: { + operand: { + id: 2 + struct_expr: { + entries: { + id: 3 + map_key: { + id: 4 + ident_expr: { + name: "x" + } + } + value: { + id: 5 + const_expr: { + string_value: "value" + } + } + } + } + } + field: "single_int32" + } + } + }`}, + normalizedIDs: []int64{1, 2, 3, 4, 5}, + normalizedMacros: map[int64]string{1: ` + call_expr: { + function: "has" + args: { + id: 6 + select_expr: { + operand: { + id: 2 + struct_expr: { + entries: { + id: 3 + map_key: { + id: 4 + ident_expr: { + name: "x" + } + } + value: { + id: 5 + const_expr: { + string_value: "value" + } + } + } + } + } + field: "single_int32" + } + } + }`, + }, + }, + { + expr: `has(google.expr.proto3.test.TestAllTypes{}.single_int32)`, + ids: []int64{2, 4}, + macros: map[int64]string{ + 4: `call_expr: { + function: "has" + args: { + id: 3 + select_expr: { + operand: { + id: 2 + struct_expr: { + message_name: "google.expr.proto3.test.TestAllTypes" + } + } + field: "single_int32" + } + } + }`, + }, + normalizedIDs: []int64{1}, + }, + { + expr: `[true].exists(i, i)`, + ids: []int64{1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + macros: map[int64]string{ + 13: `call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + const_expr: { + bool_value: true + } + } + } + } + function: "exists" + args: { + id: 4 + ident_expr: { + name: "i" + } + } + args: { + id: 5 + ident_expr: { + name: "i" + } + } + }`, + }, + normalizedIDs: []int64{1}, + }, + { + expr: `[x].exists(i, i)`, + ids: []int64{1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + macros: map[int64]string{ + 13: `call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + ident_expr: { + name: "x" + } + } + } + } + function: "exists" + args: { + id: 4 + ident_expr: { + name: "i" + } + } + args: { + id: 5 + ident_expr: { + name: "i" + } + } + }`, + }, + normalizedIDs: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + normalizedMacros: map[int64]string{ + 1: `call_expr: { 
+ target: { + id: 2 + list_expr: { + elements: { + id: 3 + ident_expr: { + name: "x" + } + } + } + } + function: "exists" + args: { + id: 12 + ident_expr: { + name: "i" + } + } + args: { + id: 10 + ident_expr: { + name: "i" + } + } + }`, + }, + }, + } + e, err := NewEnv( + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + preOpt := newIDCollector() + ast.PostOrderVisit(checked.impl.Expr(), preOpt) + if !reflect.DeepEqual(preOpt.IDs(), tc.ids) { + t.Errorf("Compile() got ids %v, expected %v", preOpt.IDs(), tc.ids) + } + for id, call := range checked.impl.SourceInfo().MacroCalls() { + macroText, found := tc.macros[id] + if !found { + t.Fatalf("Compile() did not find macro %d", id) + } + pbCall, err := ast.ExprToProto(call) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + pbMacro := &exprpb.Expr{} + err = prototext.Unmarshal([]byte(macroText), pbMacro) + if err != nil { + t.Fatalf("prototext.Unmarshal() failed: %v", err) + } + if !proto.Equal(pbCall, pbMacro) { + t.Errorf("Compile() for macro %d got %s, expected %s", id, prototext.Format(pbCall), macroText) + } + } + folder, err := NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + postOpt := newIDCollector() + ast.PostOrderVisit(optimized.impl.Expr(), postOpt) + if !reflect.DeepEqual(postOpt.IDs(), tc.normalizedIDs) { + t.Errorf("Optimize() got ids %v, expected %v", postOpt.IDs(), tc.normalizedIDs) + } + for id, call := range optimized.impl.SourceInfo().MacroCalls() { + macroText, found := tc.normalizedMacros[id] + if !found { + t.Fatalf("Optimize() did not find macro %d", id) + } + pbCall, err := ast.ExprToProto(call) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + pbMacro := &exprpb.Expr{} + err = prototext.Unmarshal([]byte(macroText), pbMacro) + if err != nil { + t.Fatalf("prototext.Unmarshal() failed: %v", err) + } + if !proto.Equal(pbCall, pbMacro) { + t.Errorf("Optimize() for macro %d got %s, expected %s", id, prototext.Format(pbCall), macroText) + } + } + }) + } +} + +func newIDCollector() *idCollector { + return &idCollector{ + ids: int64Slice{}, + } +} + +type idCollector struct { + ids int64Slice +} + +func (c *idCollector) VisitExpr(e ast.Expr) { + if e.ID() == 0 { + return + } + c.ids = append(c.ids, e.ID()) +} + +// VisitEntryExpr updates the max identifier if the incoming entry id is greater than previously observed. +func (c *idCollector) VisitEntryExpr(e ast.EntryExpr) { + if e.ID() == 0 { + return + } + c.ids = append(c.ids, e.ID()) +} + +func (c *idCollector) IDs() []int64 { + sort.Sort(c.ids) + return c.ids +} + +// int64Slice is an implementation of the sort.Interface +type int64Slice []int64 + +// Len returns the number of elements in the slice. +func (x int64Slice) Len() int { return len(x) } + +// Less indicates whether the value at index i is less than the value at index j. +func (x int64Slice) Less(i, j int) bool { return x[i] < x[j] } + +// Swap swaps the values at indices i and j in place. 
+func (x int64Slice) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +// Sort is a convenience method: x.Sort() calls Sort(x). +func (x int64Slice) Sort() { sort.Sort(x) } diff --git a/cel/optimizer.go b/cel/optimizer.go new file mode 100644 index 00000000..a2023ed7 --- /dev/null +++ b/cel/optimizer.go @@ -0,0 +1,317 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types/ref" +) + +// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order. +// +// The static optimizer normalizes expression ids and type-checking run between optimization +// passes to ensure that the final optimized output is a valid expression with metadata consistent +// with what would have been generated from a parsed and checked expression. +// +// Note: source position information is best-effort and likely wrong, but optimized expressions +// should be suitable for calls to parser.Unparse. +type StaticOptimizer struct { + optimizers []ASTOptimizer +} + +// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied +// to a checked expression. +func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { + return &StaticOptimizer{ + optimizers: optimizers, + } +} + +// Optimize applies a sequence of optimizations to an Ast within a given environment. +// +// If issues are encountered, the Issues.Err() return value will be non-nil. +func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { + // Make a copy of the AST to be optimized. + optimized := ast.Copy(a.impl) + + // Create the optimizer context, could be pooled in the future. + issues := NewIssues(common.NewErrors(a.Source())) + ids := newMonotonicIDGen(ast.MaxID(a.impl)) + fac := &optimizerExprFactory{ + nextID: ids.nextID, + renumberID: ids.renumberID, + fac: ast.NewExprFactory(), + sourceInfo: optimized.SourceInfo(), + } + ctx := &OptimizerContext{ + optimizerExprFactory: fac, + Env: env, + Issues: issues, + } + + // Apply the optimizations sequentially. + for _, o := range opt.optimizers { + optimized = o.Optimize(ctx, optimized) + if issues.Err() != nil { + return nil, issues + } + // Normalize expression id metadata including coordination with macro call metadata. + normalizeIDs(env, optimized) + + // Recheck the updated expression for any possible type-agreement or validation errors. + parsed := &Ast{ + source: a.Source(), + impl: ast.NewAST(optimized.Expr(), optimized.SourceInfo())} + checked, iss := ctx.Check(parsed) + if iss.Err() != nil { + return nil, iss + } + optimized = checked.impl + } + + // Return the optimized result. + return &Ast{ + source: a.Source(), + impl: optimized, + }, nil +} + +// normalizeIDs ensures that the metadata present with an AST is reset in a manner such +// that the ids within the expression correspond to the ids within macros. 
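//
// Illustratively, if the nodes surviving a fold carried ids {2, 5, 13}, the stable generator
// renumbers them to {1, 2, 3} in visitation order, and any macro call metadata keyed on the
// old ids is re-keyed to match (see TestConstantFoldingNormalizeIDs for concrete
// before/after id lists).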
+func normalizeIDs(e *Env, optimized *ast.AST) { + ids := newStableIDGen() + optimized.Expr().RenumberIDs(ids.renumberID) + allExprMap := make(map[int64]ast.Expr) + ast.PostOrderVisit(optimized.Expr(), ast.NewExprVisitor(func(e ast.Expr) { + allExprMap[e.ID()] = e + })) + info := optimized.SourceInfo() + + // First, update the macro call ids themselves. + for id, call := range info.MacroCalls() { + info.ClearMacroCall(id) + callID := ids.renumberID(id) + if e, found := allExprMap[callID]; found && e.Kind() == ast.LiteralKind { + continue + } + info.SetMacroCall(callID, call) + } + + // Second, update the macro call id references to ensure that macro pointers are' + // updated consistently across macros. + for id, call := range info.MacroCalls() { + call.RenumberIDs(ids.renumberID) + resetMacroCall(optimized, call, allExprMap) + info.SetMacroCall(id, call) + } +} + +func resetMacroCall(optimized *ast.AST, call ast.Expr, allExprMap map[int64]ast.Expr) { + modified := []ast.Expr{} + ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) { + if _, found := allExprMap[e.ID()]; found { + modified = append(modified, e) + } + })) + for _, m := range modified { + updated := allExprMap[m.ID()] + m.SetKindCase(updated) + } +} + +// newMonotonicIDGen increments numbers from an initial seed value. +func newMonotonicIDGen(seed int64) *monotonicIDGenerator { + return &monotonicIDGenerator{seed: seed} +} + +type monotonicIDGenerator struct { + seed int64 +} + +func (gen *monotonicIDGenerator) nextID() int64 { + gen.seed++ + return gen.seed +} + +func (gen *monotonicIDGenerator) renumberID(int64) int64 { + return gen.nextID() +} + +// newStableIDGen ensures that new ids are only created the first time they are encountered. +func newStableIDGen() *stableIDGenerator { + return &stableIDGenerator{ + idMap: make(map[int64]int64), + } +} + +type stableIDGenerator struct { + idMap map[int64]int64 + nextID int64 +} + +func (gen *stableIDGenerator) renumberID(id int64) int64 { + if id == 0 { + return 0 + } + if newID, found := gen.idMap[id]; found { + return newID + } + gen.nextID++ + gen.idMap[id] = gen.nextID + return gen.nextID +} + +// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate +// subexpressions and report any errors encountered along the way. The context also embeds the +// optimizerExprFactory which can be used to generate new sub-expressions with expression ids +// consistent with the expectations of a parsed expression. +type OptimizerContext struct { + *Env + *optimizerExprFactory + *Issues +} + +// ASTOptimizer applies an optimization over an AST and returns the optimized result. +type ASTOptimizer interface { + // Optimize optimizes a type-checked AST within an Environment and accumulates any issues. + Optimize(*OptimizerContext, *ast.AST) *ast.AST +} + +type optimizerExprFactory struct { + nextID func() int64 + renumberID ast.IDGenerator + fac ast.ExprFactory + sourceInfo *ast.SourceInfo +} + +// NewCall creates a global function call invocation expression. +// +// Example: +// +// countByField(list, fieldName) +// - function: countByField +// - args: [list, fieldName] +func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr { + return opt.fac.NewCall(opt.nextID(), function, args...) +} + +// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call. 
+// +// Example: +// +// list.countByField(fieldName) +// - function: countByField +// - target: list +// - args: [fieldName] +func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return opt.fac.NewMemberCall(opt.nextID(), function, target, args...) +} + +// NewIdent creates a new identifier expression. +// +// Examples: +// +// - simple_var_name +// - qualified.subpackage.var_name +func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr { + return opt.fac.NewIdent(opt.nextID(), name) +} + +// NewLiteral creates a new literal expression value. +// +// The range of valid values for a literal generated during optimization is different than for expressions +// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can +// be converted back to a literal-like form. +func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr { + return opt.fac.NewLiteral(opt.nextID(), value) +} + +// NewList creates a list expression with a set of optional indices. +// +// Examples: +// +// [a, b] +// - elems: [a, b] +// - optIndices: [] +// +// [a, ?b, ?c] +// - elems: [a, b, c] +// - optIndices: [1, 2] +func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr { + return opt.fac.NewList(opt.nextID(), elems, optIndices) +} + +// NewMap creates a map from a set of entry expressions which contain a key and value expression. +func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr { + return opt.fac.NewMap(opt.nextID(), entries) +} + +// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the +// entry is optional. +// +// Examples: +// +// {a: b} +// - key: a +// - value: b +// - optional: false +// +// {?a: ?b} +// - key: a +// - value: b +// - optional: true +func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) +} + +// NewSelect creates a select expression where a field value is selected from an operand. +// +// Example: +// +// msg.field_name +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr { + return opt.fac.NewSelect(opt.nextID(), operand, field) +} + +// NewStruct creates a new typed struct value with an set of field initializations. +// +// Example: +// +// pkg.TypeName{field: value} +// - typeName: pkg.TypeName +// - fields: [{field: value}] +func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr { + return opt.fac.NewStruct(opt.nextID(), typeName, fields) +} + +// NewStructField creates a struct field initialization. +// +// Examples: +// +// {count: 3u} +// - field: count +// - value: 3u +// - optional: false +// +// {?count: x} +// - field: count +// - value: x +// - optional: true +func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewStructField(opt.nextID(), field, value, isOptional) +} diff --git a/common/ast/ast.go b/common/ast/ast.go index 7610b467..c3620eb9 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -249,10 +249,16 @@ func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) { // SetMacroCall records a macro call at a specific location. 
func (s *SourceInfo) SetMacroCall(id int64, e Expr) { - if s == nil { - return + if s != nil { + s.macroCalls[id] = e + } +} + +// ClearMacroCall removes the macro call at the given expression id. +func (s *SourceInfo) ClearMacroCall(id int64) { + if s != nil { + delete(s.macroCalls, id) } - s.macroCalls[id] = e } // OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either: @@ -407,6 +413,7 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool { type maxIDVisitor struct { maxID int64 + *baseVisitor } // VisitExpr updates the max identifier if the incoming expression id is greater than previously observed. diff --git a/common/ast/expr.go b/common/ast/expr.go index 5811e395..c9d88bba 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -184,6 +184,9 @@ type ListExpr interface { // OptionalIndicies returns the list of optional indices in the list literal. OptionalIndices() []int32 + // IsOptional indicates whether the given element index is optional. + IsOptional(int32) bool + // Size returns the number of elements in the list. Size() int @@ -404,9 +407,14 @@ func (e *expr) SetKindCase(other Expr) { e.exprKindCase = baseIdentExpr(other.AsIdent()) case ListKind: l := other.AsList() + optIndexMap := make(map[int32]struct{}, len(l.OptionalIndices())) + for _, idx := range l.OptionalIndices() { + optIndexMap[idx] = struct{}{} + } e.exprKindCase = &baseListExpr{ - elements: l.Elements(), - optIndices: l.OptionalIndices(), + elements: l.Elements(), + optIndices: l.OptionalIndices(), + optIndexMap: optIndexMap, } case LiteralKind: e.exprKindCase = &baseLiteral{Val: other.AsLiteral()} @@ -591,8 +599,9 @@ func (*baseLiteral) isExpr() {} var _ ListExpr = &baseListExpr{} type baseListExpr struct { - elements []Expr - optIndices []int32 + elements []Expr + optIndices []int32 + optIndexMap map[int32]struct{} } func (*baseListExpr) Kind() ExprKind { @@ -606,6 +615,11 @@ func (e *baseListExpr) Elements() []Expr { return e.elements } +func (e *baseListExpr) IsOptional(index int32) bool { + _, found := e.optIndexMap[index] + return found +} + func (e *baseListExpr) OptionalIndices() []int32 { if e == nil { return []int32{} diff --git a/common/ast/factory.go b/common/ast/factory.go index 0111c289..b7f36e72 100644 --- a/common/ast/factory.go +++ b/common/ast/factory.go @@ -137,7 +137,16 @@ func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr { } func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr { - return fac.newExpr(id, &baseListExpr{elements: elems, optIndices: optIndices}) + optIndexMap := make(map[int32]struct{}, len(optIndices)) + for _, idx := range optIndices { + optIndexMap[idx] = struct{}{} + } + return fac.newExpr(id, + &baseListExpr{ + elements: elems, + optIndices: optIndices, + optIndexMap: optIndexMap, + }) } func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr { diff --git a/common/ast/navigable.go b/common/ast/navigable.go index 2836b565..f5ddf6aa 100644 --- a/common/ast/navigable.go +++ b/common/ast/navigable.go @@ -423,6 +423,10 @@ func (l navigableListImpl) Elements() []Expr { return elems } +func (l navigableListImpl) IsOptional(index int32) bool { + return l.Expr.AsList().IsOptional(index) +} + func (l navigableListImpl) OptionalIndices() []int32 { return l.Expr.AsList().OptionalIndices() }
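For context, a minimal end-to-end sketch of how the optimizer API added by this change is intended to be wired together. The variable declaration, expression, and iteration limit are illustrative only; the expected output mirrors the `x in [1, 1 + 2, 1 + (2 + 3)]` case in the folding tests above.

package main

import (
	"fmt"
	"log"

	"github.com/google/cel-go/cel"
)

func main() {
	// Declare 'x' as dyn so partially foldable expressions remain valid after folding.
	env, err := cel.NewEnv(cel.Variable("x", cel.DynType))
	if err != nil {
		log.Fatalf("cel.NewEnv() failed: %v", err)
	}
	checked, iss := env.Compile(`x in [1, 1 + 2, 1 + (2 + 3)]`)
	if iss.Err() != nil {
		log.Fatalf("Compile() failed: %v", iss.Err())
	}
	// The iteration limit is illustrative; the default of 100 applies when the option is omitted.
	folder, err := cel.NewConstantFoldingOptimizer(cel.MaxConstantFoldIterations(50))
	if err != nil {
		log.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
	}
	optimized, iss := cel.NewStaticOptimizer(folder).Optimize(env, checked)
	if iss.Err() != nil {
		log.Fatalf("Optimize() failed: %v", iss.Err())
	}
	out, err := cel.AstToString(optimized)
	if err != nil {
		log.Fatalf("AstToString() failed: %v", err)
	}
	fmt.Println(out) // per the folding tests, prints: x in [1, 3, 6]
}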