From b3f40070aa2c9dbf60bef933a3705e71dae92a1a Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 22 May 2023 13:38:08 -0700 Subject: [PATCH] Support for pruning macro call nodes with limited optional support (#705) * Support for pruning macro call nodes with some limited optional support * Additional map-related test cases * Fix expression updating logic to update the macro reference rather than the core expression --- interpreter/prune.go | 227 +++++++++++++++++++++++--------------- interpreter/prune_test.go | 172 ++++++++++++++++++++++++++--- 2 files changed, 294 insertions(+), 105 deletions(-) diff --git a/interpreter/prune.go b/interpreter/prune.go index fc2c8135..b7f3a4d2 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -16,6 +16,7 @@ package interpreter import ( "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -94,35 +95,50 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr { } func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) { - switch val.Type() { - case types.BoolType: + switch v := val.(type) { + case types.Bool: p.state.SetValue(id, val) return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true - case types.IntType: + &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true + case types.Bytes: p.state.SetValue(id, val) return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true - case types.UintType: + &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true + case types.Double: p.state.SetValue(id, val) return p.createLiteral(id, - &exprpb.Constant{ConstantKind: 
&exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true - case types.StringType: + &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true + case types.Duration: + p.state.SetValue(id, val) + durationString := string(v.ConvertToType(types.StringType).(types.String)) + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: overloads.TypeConvertDuration, + Args: []*exprpb.Expr{ + p.createLiteral(p.nextID(), + &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}), + }, + }, + }, + }, true + case types.Int: p.state.SetValue(id, val) return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true - case types.DoubleType: + &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true + case types.Uint: p.state.SetValue(id, val) return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true - case types.BytesType: + &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true + case types.String: p.state.SetValue(id, val) return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true - case types.NullType: + &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true + case types.Null: p.state.SetValue(id, val) return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true + &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true } // Attempt to build a list literal. 
@@ -196,29 +212,37 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo return nil, false } -func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { - if !p.existsWithUnknownValue(node.GetId()) { - return nil, false +func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) { + elemVal, found := p.value(elem.GetId()) + if found && elemVal.Type() == types.OptionalType { + opt := elemVal.(*types.Optional) + if !opt.HasValue() { + return nil, true + } + if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned { + return newElem, true + } } + return elem, false +} + +func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { call := node.GetCallExpr() - val, valueExists := p.value(call.GetArgs()[1].GetId()) - if !valueExists { + v, exists := p.value(call.GetArgs()[1].GetId()) + if !exists || types.IsUnknownOrError(v) { return nil, false } - if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { + if sz, ok := v.(traits.Sizer); ok && sz.Size() == types.IntZero { return p.maybeCreateLiteral(node.GetId(), types.False) } return nil, false } func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { - if !p.existsWithUnknownValue(node.GetId()) { - return nil, false - } call := node.GetCallExpr() arg := call.GetArgs()[0] v, exists := p.value(arg.GetId()) - if !exists { + if !exists || types.IsUnknownOrError(v) { return nil, false } if b, ok := v.(types.Bool); ok { @@ -228,37 +252,28 @@ func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) } func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) { - if !p.existsWithUnknownValue(node.GetId()) { - return nil, false - } - call := node.GetCallExpr() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. 
- if p.existsWithKnownValue(call.Args[0].GetId()) { - return call.Args[1], true + if p.existsWithKnownValue(call.GetArgs()[0].GetId()) { + return call.GetArgs()[1], true } - if p.existsWithKnownValue(call.Args[1].GetId()) { - return call.Args[0], true + if p.existsWithKnownValue(call.GetArgs()[1].GetId()) { + return call.GetArgs()[0], true } return nil, false } func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) { - if !p.existsWithUnknownValue(node.GetId()) { - return nil, false - } - call := node.GetCallExpr() - condVal, condValueExists := p.value(call.Args[0].GetId()) - if !condValueExists || types.IsUnknownOrError(condVal) { + cond, exists := p.value(call.GetArgs()[0].GetId()) + if !exists || types.IsUnknownOrError(cond) { return nil, false } - - if condVal.Value().(bool) { - return call.Args[1], true + if cond.Value().(bool) { + return call.GetArgs()[1], true } - return call.Args[2], true + return call.GetArgs()[2], true } func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { @@ -279,23 +294,28 @@ func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { } func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) { - out, pruned := p.prune(node) - if pruned { - delete(p.macroCalls, node.GetId()) - } - return out, pruned + return p.prune(node) } func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { if node == nil { return node, false } - val, valueExists := p.value(node.GetId()) - if valueExists && !types.IsUnknownOrError(val) { - if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok { + v, exists := p.value(node.GetId()) + if exists && !types.IsUnknownOrError(v) { + if newNode, ok := p.maybeCreateLiteral(node.GetId(), v); ok { + // if the macro completely evaluated, then delete the reference to it, if one exists. + delete(p.macroCalls, node.GetId()) + // return the literal value. 
return newNode, true } } + if macro, found := p.macroCalls[node.GetId()]; found { + // prune the expression in terms of the macro call instead of the expanded form. + if newMacro, pruned := p.prune(macro); pruned { + p.macroCalls[node.GetId()] = newMacro + } + } // We have either an unknown/error value, or something we don't want to // transform, or expression was not evaluated. If possible, drill down @@ -351,21 +371,51 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { } case *exprpb.Expr_ListExpr: elems := node.GetListExpr().GetElements() - newElems := make([]*exprpb.Expr, len(elems)) + optIndices := node.GetListExpr().GetOptionalIndices() + optIndexMap := map[int32]bool{} + for _, i := range optIndices { + optIndexMap[i] = true + } + newOptIndexMap := make(map[int32]bool, len(optIndexMap)) + newElems := make([]*exprpb.Expr, 0, len(elems)) var prunedList bool + + prunedIdx := 0 for i, elem := range elems { - newElems[i] = elem + _, isOpt := optIndexMap[int32(i)] + if isOpt { + newElem, pruned := p.maybePruneOptional(elem) + if pruned { + if newElem != nil { + newElems = append(newElems, newElem) + prunedIdx++ + } + prunedList = true + continue + } + newOptIndexMap[int32(prunedIdx)] = true + } if newElem, prunedElem := p.maybePrune(elem); prunedElem { - newElems[i] = newElem + newElems = append(newElems, newElem) prunedList = true + } else { + newElems = append(newElems, elem) } + prunedIdx++ + } + optIndices = make([]int32, len(newOptIndexMap)) + idx := 0 + for i := range newOptIndexMap { + optIndices[idx] = i + idx++ } if prunedList { return &exprpb.Expr{ Id: node.GetId(), ExprKind: &exprpb.Expr_ListExpr{ ListExpr: &exprpb.Expr_CreateList{ - Elements: newElems, + Elements: newElems, + OptionalIndices: optIndices, }, }, }, true @@ -395,6 +445,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { MapKey: newKey, } } + newEntry.OptionalEntry = entry.GetOptionalEntry() newEntries[i] = newEntry } if prunedStruct { @@ -408,27 
+459,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { }, }, true } - case *exprpb.Expr_ComprehensionExpr: - compre := node.GetComprehensionExpr() - // Only the range of the comprehension is pruned since the state tracking only records - // the last iteration of the comprehension and not each step in the evaluation which - // means that the any residuals computed in between might be inaccurate. - if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_ComprehensionExpr{ - ComprehensionExpr: &exprpb.Expr_Comprehension{ - IterVar: compre.GetIterVar(), - IterRange: newRange, - AccuVar: compre.GetAccuVar(), - AccuInit: compre.GetAccuInit(), - LoopCondition: compre.GetLoopCondition(), - LoopStep: compre.GetLoopStep(), - Result: compre.GetResult(), - }, - }, - }, true - } } return node, false } @@ -438,14 +468,9 @@ func (p *astPruner) value(id int64) (ref.Val, bool) { return val, (found && val != nil) } -func (p *astPruner) existsWithUnknownValue(id int64) bool { - val, valueExists := p.value(id) - return valueExists && types.IsUnknown(val) -} - func (p *astPruner) existsWithKnownValue(id int64) bool { val, valueExists := p.value(id) - return valueExists && !types.IsUnknown(val) + return valueExists && !types.IsUnknownOrError(val) } func (p *astPruner) nextID() int64 { @@ -460,14 +485,39 @@ func (p *astPruner) nextID() int64 { } } +type astVisitor struct { + // visitExpr is called on every expr node, including those within a map/struct entry. + visitExpr func(expr *exprpb.Expr) + // visitEntry is called before entering the key, value of a map/struct entry. 
+ visitEntry func(entry *exprpb.Expr_CreateStruct_Entry) +} + func getMaxID(expr *exprpb.Expr) int64 { maxID := int64(1) + visit(expr, maxIDVisitor(&maxID)) + return maxID +} + +func maxIDVisitor(maxID *int64) astVisitor { + return astVisitor{ + visitExpr: func(e *exprpb.Expr) { + if e.GetId() >= *maxID { + *maxID = e.GetId() + 1 + } + }, + visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) { + if e.GetId() >= *maxID { + *maxID = e.GetId() + 1 + } + }, + } +} + +func visit(expr *exprpb.Expr, visitor astVisitor) { exprs := []*exprpb.Expr{expr} for len(exprs) != 0 { e := exprs[0] - if e.GetId() >= maxID { - maxID = e.GetId() + 1 - } + visitor.visitExpr(e) exprs = exprs[1:] switch e.GetExprKind().(type) { case *exprpb.Expr_SelectExpr: @@ -490,16 +540,13 @@ func getMaxID(expr *exprpb.Expr) int64 { list := e.GetListExpr() exprs = append(exprs, list.GetElements()...) case *exprpb.Expr_StructExpr: - for _, entry := range expr.GetStructExpr().GetEntries() { + for _, entry := range e.GetStructExpr().GetEntries() { + visitor.visitEntry(entry) if entry.GetMapKey() != nil { exprs = append(exprs, entry.GetMapKey()) } exprs = append(exprs, entry.GetValue()) - if entry.GetId() >= maxID { - maxID = entry.GetId() + 1 - } } } } - return maxID } diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 302d0bf5..fe60df78 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -19,6 +19,9 @@ import ( "github.com/google/cel-go/common" "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/interpreter/functions" "github.com/google/cel-go/parser" "github.com/google/cel-go/test" ) @@ -30,6 +33,10 @@ type testInfo struct { } var testCases = []testInfo{ + { + expr: `{{'nested_key': true}.nested_key: true}`, + out: `{true: true}`, + }, { in: map[string]any{ "msg": map[string]string{"foo": "bar"}, @@ -78,8 +85,8 @@ var testCases = []testInfo{ }, { in: 
partialActivation(map[string]any{"rules": map[string]any{"not_in": []string{}}}, "this"), - expr: `this.size() > 0 ? this in rules.not_in : - !(this in rules.not_in) ? true : false`, + expr: `this.size() > 0 ? this in rules.not_in : + !(this in rules.not_in) ? true : false`, out: `(this.size() > 0) ? false : true`, }, { @@ -102,6 +109,107 @@ var testCases = []testInfo{ expr: `2 < 3`, out: `true`, }, + { + expr: `!false`, + out: `true`, + }, + { + in: unknownActivation("y"), + expr: `!y`, + out: `!y`, + }, + { + in: partialActivation(map[string]any{"y": 10}), + expr: `optional.of(y)`, + out: `optional.of(10)`, + }, + { + in: unknownActivation("a"), + expr: `a.?b`, + out: `a.?b`, + }, + { + in: unknownActivation("a"), + expr: `a[?"b"]`, + out: `a[?"b"]`, + }, + { + in: unknownActivation(), + expr: `[1, 2, 3, ?optional.none()]`, + out: `[1, 2, 3]`, + }, + { + in: unknownActivation(), + expr: `[1, 2, 3, ?optional.of(10)]`, + out: `[1, 2, 3, 10]`, + }, + { + in: unknownActivation(), + expr: `{1: 2, ?3: optional.none()}`, + out: `{1: 2}`, + }, + { + in: unknownActivation("a"), + expr: `[?optional.none(), a, 2, 3]`, + out: `[a, 2, 3]`, + }, + { + in: unknownActivation("a"), + expr: `[?optional.of(10), ?a, 2, 3]`, + out: `[10, ?a, 2, 3]`, + }, + { + in: partialActivation(map[string]any{"a": "hi"}, "b"), + expr: `{?a: b.?c}`, + out: `{?"hi": b.?c}`, + }, + { + in: partialActivation(map[string]any{"a": "hi"}, "b"), + expr: `"hi" in {?a: b.?c}`, + out: `"hi" in {?"hi": b.?c}`, + }, + { + in: partialActivation(map[string]any{"a": "hi"}, "b"), + expr: `"hi" in {?a: optional.of("world")}`, + out: `true`, + }, + { + in: partialActivation(map[string]any{"a": "hi"}, "b"), + expr: `{?a: optional.of("world")}[b]`, + out: `{"hi": "world"}[b]`, + }, + { + in: unknownActivation("y"), + expr: `duration('1h') + duration('2h') > y`, + out: `duration("10800s") > y`, + }, + { + in: unknownActivation("x"), + expr: `[x, timestamp(0)]`, + out: `[x, timestamp(0)]`, + }, + { + expr: 
`[timestamp(0), timestamp(1)]`, + out: `[timestamp(0), timestamp(1)]`, + }, + { + expr: `{"epoch": timestamp(0)}`, + out: `{"epoch": timestamp(0)}`, + }, + { + in: partialActivation(map[string]any{"x": false}, "y"), + expr: `!y && !x`, + out: `!y`, + }, + { + expr: `!y && !(1/0 < 0)`, + out: `!y && !(1/0 < 0)`, + }, + { + in: partialActivation(map[string]any{"y": false}), + expr: `!y && !(1/0 < 0)`, + out: `!(1/0 < 0)`, + }, { in: unknownActivation(), expr: `test == null`, @@ -163,23 +271,37 @@ var testCases = []testInfo{ {"name": "alice", "role": "EMPLOYEE"}, {"name": "bob", "role": "MANAGER"}, {"name": "eve", "role": "CUSTOMER"}, - }}, "r.attr.*"), + }}, "r.attr"), expr: `users.filter(u, u.role=="MANAGER").map(u, u.name) == r.attr.authorized["managers"]`, out: `["bob"] == r.attr.authorized["managers"]`, }, - // TODO: the output of an expression like this relies on either - // a) doing replacements on the original macro call, or - // b) mutating the macro call tracking data rather than the core - // expression in order to render the partial correctly. 
- // { - // in: unknownActivation(), - // expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`, - // out: `[4, 4, 4, four].exists(x, x == four)`, - // }, + { + in: unknownActivation("four"), + expr: `[1+3, 2+2, 3+1, four]`, + out: `[4, 4, 4, four]`, + }, + { + in: unknownActivation("four"), + expr: `[1+3, 2+2, 3+1, four].exists(x, x == four)`, + out: `[4, 4, 4, four].exists(x, x == four)`, + }, + { + in: unknownActivation("a", "c"), + expr: `[has(a.b), has(c.d)].exists(x, x == true)`, + out: `[has(a.b), has(c.d)].exists(x, x == true)`, + }, + { + in: partialActivation(map[string]any{ + "a": map[string]any{}, + }, "c"), + expr: `[has(a.b), has(c.d)].exists(x, x == true)`, + out: `[false, has(c.d)].exists(x, x == true)`, + }, } func TestPrune(t *testing.T) { p, err := parser.NewParser( + parser.EnableOptionalSyntax(true), parser.PopulateMacroCalls(true), parser.Macros(parser.AllMacros...), ) @@ -194,7 +316,10 @@ func TestPrune(t *testing.T) { state := NewEvalState() reg := newTestRegistry(t) attrs := NewPartialAttributeFactory(containers.DefaultContainer, reg, reg) - interp := NewStandardInterpreter(containers.DefaultContainer, reg, reg, attrs) + dispatcher := NewDispatcher() + dispatcher.Add(functions.StandardOverloads()...) + dispatcher.Add(optionalFunctions()...) 
+ interp := NewInterpreter(dispatcher, containers.DefaultContainer, reg, reg, attrs) interpretable, _ := interp.NewUncheckedInterpretable( ast.GetExpr(), @@ -212,7 +337,7 @@ func TestPrune(t *testing.T) { } func unknownActivation(vars ...string) PartialActivation { - pats := make([]*AttributePattern, len(vars), len(vars)) + pats := make([]*AttributePattern, len(vars)) for i, v := range vars { pats[i] = NewAttributePattern(v) } @@ -221,7 +346,7 @@ func unknownActivation(vars ...string) PartialActivation { } func partialActivation(in map[string]any, vars ...string) PartialActivation { - pats := make([]*AttributePattern, len(vars), len(vars)) + pats := make([]*AttributePattern, len(vars)) for i, v := range vars { pats[i] = NewAttributePattern(v) } @@ -240,3 +365,20 @@ func testActivation(t *testing.T, in any) Activation { } return a } + +func optionalFunctions() []*functions.Overload { + return []*functions.Overload{ + { + Operator: "optional.none", + Function: func(args ...ref.Val) ref.Val { + return types.OptionalNone + }, + }, + { + Operator: "optional.of", + Unary: func(val ref.Val) ref.Val { + return types.OptionalOf(val) + }, + }, + } +}