Skip to content

Commit

Permalink
Support for pruning macro call nodes with limited optional support (#705
Browse files Browse the repository at this point in the history
)

* Support for pruning macro call nodes with some limited optional support
* Additional map-related test cases
* Fix expression updating logic to update the macro reference rather than the core expression
  • Loading branch information
TristonianJones authored May 22, 2023
1 parent a68a6de commit b3f4007
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 105 deletions.
227 changes: 137 additions & 90 deletions interpreter/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package interpreter

import (
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
Expand Down Expand Up @@ -94,35 +95,50 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
}

func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
switch val.Type() {
case types.BoolType:
switch v := val.(type) {
case types.Bool:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true
case types.IntType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true
case types.Bytes:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true
case types.UintType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true
case types.Double:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true
case types.StringType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true
case types.Duration:
p.state.SetValue(id, val)
durationString := string(v.ConvertToType(types.StringType).(types.String))
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: overloads.TypeConvertDuration,
Args: []*exprpb.Expr{
p.createLiteral(p.nextID(),
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}),
},
},
},
}, true
case types.Int:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true
case types.DoubleType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true
case types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true
case types.BytesType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true
case types.String:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true
case types.NullType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true
case types.Null:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true
}

// Attempt to build a list literal.
Expand Down Expand Up @@ -196,29 +212,37 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}

func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) {
elemVal, found := p.value(elem.GetId())
if found && elemVal.Type() == types.OptionalType {
opt := elemVal.(*types.Optional)
if !opt.HasValue() {
return nil, true
}
if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned {
return newElem, true
}
}
return elem, false
}

func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
val, valueExists := p.value(call.GetArgs()[1].GetId())
if !valueExists {
v, exists := p.value(call.GetArgs()[1].GetId())
if !exists || types.IsUnknownOrError(v) {
return nil, false
}
if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
if sz, ok := v.(traits.Sizer); ok && sz.Size() == types.IntZero {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return nil, false
}

func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}
call := node.GetCallExpr()
arg := call.GetArgs()[0]
v, exists := p.value(arg.GetId())
if !exists {
if !exists || types.IsUnknownOrError(v) {
return nil, false
}
if b, ok := v.(types.Bool); ok {
Expand All @@ -228,37 +252,28 @@ func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool)
}

func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}

call := node.GetCallExpr()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if p.existsWithKnownValue(call.Args[0].GetId()) {
return call.Args[1], true
if p.existsWithKnownValue(call.GetArgs()[0].GetId()) {
return call.GetArgs()[1], true
}
if p.existsWithKnownValue(call.Args[1].GetId()) {
return call.Args[0], true
if p.existsWithKnownValue(call.GetArgs()[1].GetId()) {
return call.GetArgs()[0], true
}
return nil, false
}

func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}

call := node.GetCallExpr()
condVal, condValueExists := p.value(call.Args[0].GetId())
if !condValueExists || types.IsUnknownOrError(condVal) {
cond, exists := p.value(call.GetArgs()[0].GetId())
if !exists || types.IsUnknownOrError(cond) {
return nil, false
}

if condVal.Value().(bool) {
return call.Args[1], true
if cond.Value().(bool) {
return call.GetArgs()[1], true
}
return call.Args[2], true
return call.GetArgs()[2], true
}

func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
Expand All @@ -279,23 +294,28 @@ func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
}

func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) {
out, pruned := p.prune(node)
if pruned {
delete(p.macroCalls, node.GetId())
}
return out, pruned
return p.prune(node)
}

func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
if node == nil {
return node, false
}
val, valueExists := p.value(node.GetId())
if valueExists && !types.IsUnknownOrError(val) {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok {
v, exists := p.value(node.GetId())
if exists && !types.IsUnknownOrError(v) {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), v); ok {
// if the macro completely evaluated, then delete the reference to it, if one exists.
delete(p.macroCalls, node.GetId())
// return the literal value.
return newNode, true
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
}
}

// We have either an unknown/error value, or something we don't want to
// transform, or expression was not evaluated. If possible, drill down
Expand Down Expand Up @@ -351,21 +371,51 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
newElems := make([]*exprpb.Expr, len(elems))
optIndices := node.GetListExpr().GetOptionalIndices()
optIndexMap := map[int32]bool{}
for _, i := range optIndices {
optIndexMap[i] = true
}
newOptIndexMap := make(map[int32]bool, len(optIndexMap))
newElems := make([]*exprpb.Expr, 0, len(elems))
var prunedList bool

prunedIdx := 0
for i, elem := range elems {
newElems[i] = elem
_, isOpt := optIndexMap[int32(i)]
if isOpt {
newElem, pruned := p.maybePruneOptional(elem)
if pruned {
if newElem != nil {
newElems = append(newElems, newElem)
prunedIdx++
}
prunedList = true
continue
}
newOptIndexMap[int32(prunedIdx)] = true
}
if newElem, prunedElem := p.maybePrune(elem); prunedElem {
newElems[i] = newElem
newElems = append(newElems, newElem)
prunedList = true
} else {
newElems = append(newElems, elem)
}
prunedIdx++
}
optIndices = make([]int32, len(newOptIndexMap))
idx := 0
for i := range newOptIndexMap {
optIndices[idx] = i
idx++
}
if prunedList {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: newElems,
Elements: newElems,
OptionalIndices: optIndices,
},
},
}, true
Expand Down Expand Up @@ -395,6 +445,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
MapKey: newKey,
}
}
newEntry.OptionalEntry = entry.GetOptionalEntry()
newEntries[i] = newEntry
}
if prunedStruct {
Expand All @@ -408,27 +459,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
},
}, true
}
case *exprpb.Expr_ComprehensionExpr:
compre := node.GetComprehensionExpr()
// Only the range of the comprehension is pruned since the state tracking only records
// the last iteration of the comprehension and not each step in the evaluation which
// means that the any residuals computed in between might be inaccurate.
if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: compre.GetIterVar(),
IterRange: newRange,
AccuVar: compre.GetAccuVar(),
AccuInit: compre.GetAccuInit(),
LoopCondition: compre.GetLoopCondition(),
LoopStep: compre.GetLoopStep(),
Result: compre.GetResult(),
},
},
}, true
}
}
return node, false
}
Expand All @@ -438,14 +468,9 @@ func (p *astPruner) value(id int64) (ref.Val, bool) {
return val, (found && val != nil)
}

func (p *astPruner) existsWithUnknownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && types.IsUnknown(val)
}

func (p *astPruner) existsWithKnownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && !types.IsUnknown(val)
return valueExists && !types.IsUnknownOrError(val)
}

func (p *astPruner) nextID() int64 {
Expand All @@ -460,14 +485,39 @@ func (p *astPruner) nextID() int64 {
}
}

type astVisitor struct {
// visitEntry is called on every expr node, including those within a map/struct entry.
visitExpr func(expr *exprpb.Expr)
// visitEntry is called before entering the key, value of a map/struct entry.
visitEntry func(entry *exprpb.Expr_CreateStruct_Entry)
}

func getMaxID(expr *exprpb.Expr) int64 {
maxID := int64(1)
visit(expr, maxIDVisitor(&maxID))
return maxID
}

func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
},
visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
},
}
}

func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
if e.GetId() >= maxID {
maxID = e.GetId() + 1
}
visitor.visitExpr(e)
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
Expand All @@ -490,16 +540,13 @@ func getMaxID(expr *exprpb.Expr) int64 {
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range expr.GetStructExpr().GetEntries() {
for _, entry := range e.GetStructExpr().GetEntries() {
visitor.visitEntry(entry)
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}
exprs = append(exprs, entry.GetValue())
if entry.GetId() >= maxID {
maxID = entry.GetId() + 1
}
}
}
}
return maxID
}
Loading

0 comments on commit b3f4007

Please sign in to comment.