diff --git a/cel/inlining.go b/cel/inlining.go new file mode 100644 index 00000000..9fc3be27 --- /dev/null +++ b/cel/inlining.go @@ -0,0 +1,220 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/traits" +) + +// InlineVariable holds a variable name to be matched and an AST representing +// the expression graph which should be used to replace it. +type InlineVariable struct { + name string + alias string + def *ast.AST +} + +// Name returns the qualified variable or field selection to replace. +func (v *InlineVariable) Name() string { + return v.name +} + +// Alias returns the alias to use when performing cel.bind() calls during inlining. +func (v *InlineVariable) Alias() string { + return v.alias +} + +// Expr returns the inlined expression value. +func (v *InlineVariable) Expr() ast.Expr { + return v.def.Expr() +} + +// Type indicates the inlined expression type. +func (v *InlineVariable) Type() *Type { + return v.def.GetType(v.def.Expr().ID()) +} + +// NewInlineVariable declares a variable name to be replaced by a checked expression. +func NewInlineVariable(name string, definition *Ast) *InlineVariable { + return NewInlineVariableWithAlias(name, name, definition) +} + +// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression. +// If the variable occurs more than once, the provided alias will be used to replace the expressions +// where the variable name occurs. +func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable { + return &InlineVariable{name: name, alias: alias, def: definition.impl} +} + +// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions. +// +// If a variable occurs one time, the variable is replaced by the inline definition. If the +// variable occurs more than once, the variable occurences are replaced by a cel.bind() call. +func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer { + return &inliningOptimizer{variables: inlineVars} +} + +type inliningOptimizer struct { + variables []*InlineVariable +} + +func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + for _, inlineVar := range opt.variables { + matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name())) + // Skip cases where the variable isn't in the expression graph + if len(matches) == 0 { + continue + } + + // For a single match, do a direct replacement of the expression sub-graph. + if len(matches) == 1 { + opt.inlineExpr(ctx, matches[0], ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type()) + continue + } + + if !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) { + for _, match := range matches { + opt.inlineExpr(ctx, match, ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type()) + } + continue + } + // For multiple matches, find the least common ancestor (lca) and insert the + // variable as a cel.bind() macro. + var lca ast.NavigableExpr = nil + ancestors := map[int64]bool{} + for _, match := range matches { + // Update the identifier matches with the provided alias. + aliasExpr := ctx.NewIdent(inlineVar.Alias()) + opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type()) + parent, found := match, true + for found { + _, hasAncestor := ancestors[parent.ID()] + if hasAncestor && (lca == nil || lca.Depth() < parent.Depth()) { + lca = parent + } + ancestors[parent.ID()] = true + parent, found = parent.Parent() + } + } + + // Update the least common ancestor by inserting a cel.bind() call to the alias. + inlined := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), inlineVar.Expr(), lca) + opt.inlineExpr(ctx, lca, inlined, inlineVar.Type()) + } + return a +} + +// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining +// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is +// made to determine whether the inlined value can be presence or existence tested. +func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) { + switch prev.Kind() { + case ast.SelectKind: + sel := prev.AsSelect() + if !sel.IsTestOnly() { + prev.SetKindCase(inlined) + return + } + opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType) + default: + prev.SetKindCase(inlined) + } +} + +// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe +// expression appropriate for the inlined type, if possible. +// +// If the rewrite is not possible an error is reported at the inline expression site. +func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) { + // If the input inlined expression is not a select expression it won't work with the has() + // macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error. + ctx.sourceInfo.ClearMacroCall(prev.ID()) + if inlined.Kind() == ast.SelectKind { + inlinedSel := inlined.AsSelect() + prev.SetKindCase( + ctx.NewPresenceTest(prev.ID(), inlinedSel.Operand(), inlinedSel.FieldName())) + return + } + if inlinedType.IsAssignableType(NullType) { + prev.SetKindCase( + ctx.NewCall(operators.NotEquals, + inlined, + ctx.NewLiteral(types.NullValue), + )) + return + } + if inlinedType.HasTrait(traits.SizerType) { + prev.SetKindCase( + ctx.NewCall(operators.NotEquals, + ctx.NewMemberCall(overloads.Size, inlined), + ctx.NewLiteral(types.IntZero), + )) + return + } + ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType) +} + +// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression +// being replaced occurs within a presence test. Value types with a size() method or field selection +// support can be bound. +// +// In future iterations, support may also be added for indexer types which can be rewritten as an `in` +// expression; however, this would imply a rewrite of the inlined expression that may not be necessary +// in most cases. +func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool { + if inlinedType.IsAssignableType(NullType) || + inlinedType.HasTrait(traits.SizerType) || + inlinedType.HasTrait(traits.FieldTesterType) { + return true + } + for _, m := range matches { + if m.Kind() != ast.SelectKind { + continue + } + sel := m.AsSelect() + if sel.IsTestOnly() { + return false + } + } + return true +} + +// matchVariable matches simple identifiers, select expressions, and presence test expressions +// which match the (potentially) qualified variable name provided as input. +// +// Note, this function does not support inlining against select expressions which includes optional +// field selection. This may be a future refinement. +func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher { + return func(e ast.NavigableExpr) bool { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + return true + } + if e.Kind() == ast.SelectKind { + sel := e.AsSelect() + // While the `ToQualifiedName` call could take the select directly, this + // would skip presence tests from possible matches, which we would like + // to include. + qualName, found := containers.ToQualifiedName(sel.Operand()) + return found && qualName+"."+sel.FieldName() == varName + } + return false + } +} diff --git a/cel/inlining_test.go b/cel/inlining_test.go new file mode 100644 index 00000000..389e4625 --- /dev/null +++ b/cel/inlining_test.go @@ -0,0 +1,577 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel_test + +import ( + "testing" + + "github.com/google/cel-go/cel" + + proto3pb "github.com/google/cel-go/test/proto3pb" +) + +func TestInliningOptimizer(t *testing.T) { + type varExpr struct { + name string + alias string + t *cel.Type + expr string + } + tests := []struct { + expr string + vars []varExpr + inlined string + folded string + }{ + { + expr: `a || b`, + vars: []varExpr{ + { + name: "a", + t: cel.BoolType, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + }, + inlined: `a || "hello".contains("lo")`, + folded: `true`, + }, + { + expr: `a + [a]`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.DynType, + expr: `dyn([1, 2])`, + }, + }, + inlined: `cel.bind(alpha, dyn([1, 2]), alpha + [alpha])`, + folded: `[1, 2, [1, 2]]`, + }, + { + expr: `a && (a || b)`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + }, + inlined: `cel.bind(alpha, "hello".contains("lo"), alpha && (alpha || b))`, + folded: `true`, + }, + { + expr: `a && b && a`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + }, + inlined: `cel.bind(alpha, "hello".contains("lo"), alpha && b && alpha)`, + folded: `cel.bind(alpha, true, alpha && b && alpha)`, + }, + { + expr: `(c || d) || (a && (a || b))`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + { + name: "c", + t: cel.BoolType, + }, + { + name: "d", + t: cel.BoolType, + expr: "!false", + }, + }, + inlined: `c || !false || cel.bind(alpha, "hello".contains("lo"), alpha && (alpha || b))`, + folded: `true`, + }, + { + expr: `a && (a || b)`, + vars: []varExpr{ + { + name: "a", + t: cel.BoolType, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + }, + inlined: `a && (a || "hello".contains("lo"))`, + folded: `a`, + }, + { + expr: `a && b`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `!'hello'.contains('lo')`, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + }, + }, + inlined: `!"hello".contains("lo") && b`, + folded: `false`, + }, + { + expr: `operation.system.consumers + operation.destination_consumers`, + vars: []varExpr{ + { + name: "operation.system", + t: cel.DynType, + }, + { + name: "operation.destination_consumers", + t: cel.ListType(cel.IntType), + expr: `productsToConsumers(operation.destination_products)`, + }, + { + name: "operation.destination_products", + t: cel.ListType(cel.IntType), + expr: `operation.system.products`, + }, + }, + inlined: `operation.system.consumers + productsToConsumers(operation.system.products)`, + folded: `operation.system.consumers + productsToConsumers(operation.system.products)`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + opts := []cel.EnvOption{cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + cel.Function("productsToConsumers", + cel.Overload("productsToConsumers_list", + []*cel.Type{cel.ListType(cel.IntType)}, + cel.ListType(cel.IntType)))} + + varDecls := make([]cel.EnvOption, len(tc.vars)) + for i, v := range tc.vars { + varDecls[i] = cel.Variable(v.name, v.t) + } + e, err := cel.NewEnv(append(varDecls, opts...)...) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + inlinedVars := []*cel.InlineVariable{} + for _, v := range tc.vars { + if v.expr == "" { + continue + } + checked, iss := e.Compile(v.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", v.expr, iss.Err()) + } + if v.alias == "" { + inlinedVars = append(inlinedVars, cel.NewInlineVariable(v.name, checked)) + } else { + inlinedVars = append(inlinedVars, cel.NewInlineVariableWithAlias(v.name, v.alias, checked)) + } + } + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + + opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + inlined, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if inlined != tc.inlined { + t.Errorf("got %q, wanted %q", inlined, tc.inlined) + } + folder, err := cel.NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt = cel.NewStaticOptimizer(folder) + optimized, iss = opt.Optimize(e, optimized) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + +func TestInliningOptimizerMultiStage(t *testing.T) { + type varDecl struct { + name string + t *cel.Type + } + type inlineVarExpr struct { + name string + alias string + t *cel.Type + expr string + } + tests := []struct { + expr string + vars []varDecl + inlineVars []inlineVarExpr + inlined string + folded string + }{ + { + expr: `has(a.b) ? a.b : 'default'`, + vars: []varDecl{ + { + name: "a", + t: cel.MapType(cel.StringType, cel.StringType), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "a.b", + alias: "alpha", + t: cel.StringType, + expr: `'hello'`, + }, + }, + inlined: `cel.bind(alpha, "hello", (alpha.size() != 0) ? alpha : "default")`, + folded: `"hello"`, + }, + { + expr: `has(a.b) ? a.b : ['default']`, + vars: []varDecl{ + { + name: "a", + t: cel.MapType(cel.StringType, cel.ListType(cel.StringType)), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "a.b", + alias: "alpha", + t: cel.StringType, + expr: `['hello']`, + }, + }, + inlined: `cel.bind(alpha, ["hello"], (alpha.size() != 0) ? alpha : ["default"])`, + folded: `["hello"]`, + }, + { + expr: `0 in msg.map_int64_nested_type`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "nested_map", + t: cel.MapType(cel.IntType, cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes")), + }, + }, + inlineVars: []inlineVarExpr{ + + { + name: "msg.map_int64_nested_type", + t: cel.MapType(cel.IntType, cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes")), + expr: `nested_map`, + }, + }, + inlined: `0 in nested_map`, + folded: `0 in nested_map`, + }, + { + expr: `has(msg.single_any)`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "unpacked_wrapper", + t: cel.NullableType(cel.StringType), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.NullableType(cel.StringType), + expr: `unpacked_wrapper`, + }, + }, + inlined: `unpacked_wrapper != null`, + folded: `unpacked_wrapper != null`, + }, + { + expr: `has(msg.single_any) ? msg.single_any : '10'`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "unpacked_wrapper", + t: cel.NullableType(cel.StringType), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.NullableType(cel.StringType), + alias: "wrapped", + expr: `unpacked_wrapper`, + }, + }, + inlined: `cel.bind(wrapped, unpacked_wrapper, (wrapped != null) ? wrapped : "10")`, + folded: `cel.bind(wrapped, unpacked_wrapper, (wrapped != null) ? wrapped : "10")`, + }, + { + expr: `has(msg.child.payload.single_int32_wrapper)`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + { + name: "unpacked_child", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.child.payload", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + alias: "payload", + expr: `unpacked_child.payload`, + }, + }, + inlined: `has(unpacked_child.payload.single_int32_wrapper)`, + folded: `has(unpacked_child.payload.single_int32_wrapper)`, + }, + { + expr: `has(msg.child.payload.single_int32_wrapper)`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + { + name: "unpacked_payload", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.child.payload.single_int32_wrapper", + t: cel.NullableType(cel.IntType), + alias: "payload", + expr: `unpacked_payload.single_int32_wrapper`, + }, + }, + inlined: `has(unpacked_payload.single_int32_wrapper)`, + folded: `has(unpacked_payload.single_int32_wrapper)`, + }, + { + expr: `has(msg.child.payload.single_int32_wrapper) ? msg.child.payload.single_int32_wrapper : 1`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + { + name: "unpacked_payload", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.child.payload.single_int32_wrapper", + t: cel.NullableType(cel.IntType), + alias: "nullable_int", + expr: `unpacked_payload.single_int32_wrapper`, + }, + }, + inlined: `cel.bind(nullable_int, unpacked_payload.single_int32_wrapper, (nullable_int != null) ? nullable_int : 1)`, + folded: `cel.bind(nullable_int, unpacked_payload.single_int32_wrapper, (nullable_int != null) ? nullable_int : 1)`, + }, + { + expr: `has(msg.single_value) ? msg.single_value : null`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_value", + t: cel.NullableType(cel.DoubleType), + alias: "nullable_float", + expr: `dyn(1.5)`, + }, + }, + inlined: `cel.bind(nullable_float, dyn(1.5), (nullable_float != null) ? nullable_float : null)`, + folded: `1.5`, + }, + { + expr: `has(msg.single_any) ? msg.single_any : 42`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.IntType, + alias: "unpacked_nested", + expr: `google.expr.proto3.test.NestedTestAllTypes{}.payload.single_int32`, + }, + }, + inlined: `has(google.expr.proto3.test.NestedTestAllTypes{}.payload.single_int32) ? google.expr.proto3.test.NestedTestAllTypes{}.payload.single_int32 : 42`, + folded: `42`, + }, + { + expr: `has(msg.single_any.single_int32) ? msg.single_any.single_int32 : 42`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "unpacked_any", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + alias: "unpacked_any", + expr: `google.expr.proto3.test.NestedTestAllTypes{}.payload`, + }, + }, + inlined: `cel.bind(unpacked_any, google.expr.proto3.test.NestedTestAllTypes{}.payload, has(unpacked_any.single_int32) ? unpacked_any.single_int32 : 42)`, + folded: `42`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + opts := []cel.EnvOption{ + cel.Container("google.expr"), + cel.Types(&proto3pb.TestAllTypes{}), + cel.OptionalTypes(), + cel.EnableMacroCallTracking()} + + varDecls := make([]cel.EnvOption, len(tc.vars)) + for i, v := range tc.vars { + varDecls[i] = cel.Variable(v.name, v.t) + } + e, err := cel.NewEnv(append(varDecls, opts...)...) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + inlinedVars := []*cel.InlineVariable{} + for _, v := range tc.inlineVars { + if v.expr == "" { + continue + } + checked, iss := e.Compile(v.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", v.expr, iss.Err()) + } + if v.alias == "" { + inlinedVars = append(inlinedVars, cel.NewInlineVariable(v.name, checked)) + } else { + inlinedVars = append(inlinedVars, cel.NewInlineVariableWithAlias(v.name, v.alias, checked)) + } + } + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + + opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + inlined, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if inlined != tc.inlined { + t.Errorf("got %q, wanted %q", inlined, tc.inlined) + } + folder, err := cel.NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt = cel.NewStaticOptimizer(folder) + optimized, iss = opt.Optimize(e, optimized) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} diff --git a/cel/optimizer.go b/cel/optimizer.go index a2023ed7..9422a7eb 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -17,6 +17,7 @@ package cel import ( "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" ) @@ -197,6 +198,57 @@ type optimizerExprFactory struct { sourceInfo *ast.SourceInfo } +// CopyExpr copies the structure of the input ast.Expr and renumbers the identifiers in a manner +// consistent with the CEL parser / checker. +func (opt *optimizerExprFactory) CopyExpr(e ast.Expr) ast.Expr { + copy := opt.fac.CopyExpr(e) + copy.RenumberIDs(opt.renumberID) + return copy +} + +// NewBindMacro creates a cel.bind() call with a variable name, initialization expression, and remaining expression. +// +// Note: the macroID indicates the insertion point, the call id that matched the macro signature, which will be used +// for coordinating macro metadata with the bind call. This piece of data is what makes it possible to unparse +// optimized expressions which use the bind() call. +// +// Example: +// +// cel.bind(myVar, a && b || c, !myVar || (myVar && d)) +// - varName: myVar +// - varInit: a && b || c +// - remaining: !myVar || (myVar && d) +func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) ast.Expr { + bindID := opt.nextID() + varID := opt.nextID() + + varInit = opt.CopyExpr(varInit) + varInit.RenumberIDs(opt.renumberID) + + remaining = opt.fac.CopyExpr(remaining) + remaining.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined + // call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewMemberCall(0, "bind", + opt.fac.NewIdent(opt.nextID(), "cel"), + opt.fac.NewIdent(varID, varName), + varInit, + remaining)) + + // Replace the parent node with the intercepted inlining using cel.bind()-like + // generated comprehension AST. + return opt.fac.NewComprehension(bindID, + opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}), + "#unused", + varName, + opt.fac.CopyExpr(varInit), + opt.fac.NewLiteral(opt.nextID(), types.False), + opt.fac.NewIdent(varID, varName), + opt.fac.CopyExpr(remaining)) +} + // NewCall creates a global function call invocation expression. // // Example: @@ -277,6 +329,27 @@ func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional boo return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) } +// NewPresenceTest creates a new presence test macro call. +// +// Example: +// +// has(msg.field_name) +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewPresenceTest(macroID int64, operand ast.Expr, field string) ast.Expr { + // Copy the input operand and renumber it. + operand = opt.CopyExpr(operand) + operand.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewCall(0, "has", + opt.fac.NewSelect(opt.nextID(), operand, field))) + + // Generate a new presence test macro. + return opt.fac.NewPresenceTest(opt.nextID(), opt.CopyExpr(operand), field) +} + // NewSelect creates a select expression where a field value is selected from an operand. // // Example: