Skip to content

Commit

Permalink
feat: add cel language support (#483)
Browse files Browse the repository at this point in the history
* feat: add cel language support

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* go mod

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* bindings

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* constant

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* fix

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* cel

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* cel binding

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* fix

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* templating

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

* fix

Signed-off-by: Charles-Edouard Brétéché <[email protected]>

---------

Signed-off-by: Charles-Edouard Brétéché <[email protected]>
  • Loading branch information
eddycharly authored Sep 20, 2024
1 parent 4c135bc commit 06f036d
Show file tree
Hide file tree
Showing 29 changed files with 396 additions and 186 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/gin-contrib/cors v1.7.2
github.com/gin-gonic/gin v1.10.0
github.com/google/cel-go v0.17.8
github.com/google/go-cmp v0.6.0
github.com/jmespath-community/go-jmespath v1.1.2-0.20240117150817-e430401a2172
github.com/kyverno/pkg/ext v0.0.0-20240418121121-df8add26c55c
Expand Down Expand Up @@ -62,7 +63,6 @@ require (
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/cel-go v0.17.8 // indirect
github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/gxui v0.0.0-20151028112939-f85e0a97b3a4 // indirect
Expand Down
18 changes: 4 additions & 14 deletions pkg/apis/policy/v1alpha1/assertion_tree.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package v1alpha1

import (
"context"
"sync"

"github.com/kyverno/kyverno-json/pkg/core/assertion"
"github.com/kyverno/kyverno-json/pkg/core/templating"
"k8s.io/apimachinery/pkg/util/json"
)

Expand All @@ -13,24 +11,20 @@ import (
// +kubebuilder:validation:Type:=""
// AssertionTree represents an assertion tree.
type AssertionTree struct {
_tree any
_assertion func() (assertion.Assertion, error)
_tree any
}

func NewAssertionTree(value any) AssertionTree {
return AssertionTree{
_tree: value,
_assertion: sync.OnceValues(func() (assertion.Assertion, error) {
return assertion.Parse(context.Background(), value)
}),
}
}

func (t *AssertionTree) Assertion() (assertion.Assertion, error) {
func (t *AssertionTree) Assertion(compiler templating.Compiler) (assertion.Assertion, error) {
if t._tree == nil {
return nil, nil
}
return t._assertion()
return assertion.Parse(t._tree, compiler)
}

func (a *AssertionTree) MarshalJSON() ([]byte, error) {
Expand All @@ -44,13 +38,9 @@ func (a *AssertionTree) UnmarshalJSON(data []byte) error {
return err
}
a._tree = v
a._assertion = sync.OnceValues(func() (assertion.Assertion, error) {
return assertion.Parse(context.Background(), v)
})
return nil
}

func (in *AssertionTree) DeepCopyInto(out *AssertionTree) {
out._tree = deepCopy(in._tree)
out._assertion = in._assertion
}
6 changes: 3 additions & 3 deletions pkg/commands/jp/query/command.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package query

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -11,7 +10,7 @@ import (

"github.com/jmespath-community/go-jmespath/pkg/parsing"
"github.com/kyverno/kyverno-json/pkg/command"
"github.com/kyverno/kyverno-json/pkg/engine/template"
"github.com/kyverno/kyverno-json/pkg/core/templating"
"github.com/spf13/cobra"
"sigs.k8s.io/yaml"
)
Expand Down Expand Up @@ -156,7 +155,8 @@ func loadInput(cmd *cobra.Command, file string) (any, error) {
}

func evaluate(input any, query string) (any, error) {
result, err := template.ExecuteJP(context.Background(), query, input, nil)
compiler := templating.NewCompiler(templating.CompilerOptions{})
result, err := templating.ExecuteJP(query, input, nil, compiler)
if err != nil {
if syntaxError, ok := err.(parsing.SyntaxError); ok {
return nil, fmt.Errorf("%s\n%s", syntaxError, syntaxError.HighlightLocation())
Expand Down
6 changes: 6 additions & 0 deletions pkg/commands/scan/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ func Test_Execute(t *testing.T) {
policies: []string{"../../../test/commands/scan/foo-bar/policy.yaml"},
out: "../../../test/commands/scan/foo-bar/out.txt",
wantErr: false,
}, {
name: "cel",
payload: "../../../test/commands/scan/cel/payload.yaml",
policies: []string{"../../../test/commands/scan/cel/policy.yaml"},
out: "../../../test/commands/scan/cel/out.txt",
wantErr: false,
}, {
name: "wildcard",
payload: "../../../test/commands/scan/wildcard/payload.json",
Expand Down
5 changes: 3 additions & 2 deletions pkg/commands/scan/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"

"github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1"
"github.com/kyverno/kyverno-json/pkg/engine/template"
"github.com/kyverno/kyverno-json/pkg/core/templating"
jsonengine "github.com/kyverno/kyverno-json/pkg/json-engine"
"github.com/kyverno/kyverno-json/pkg/payload"
"github.com/kyverno/kyverno-json/pkg/policy"
Expand Down Expand Up @@ -76,8 +76,9 @@ func (c *options) run(cmd *cobra.Command, _ []string) error {
return errors.New("payload is `null`")
}
out.println("Pre processing ...")
compiler := templating.NewCompiler(templating.CompilerOptions{})
for _, preprocessor := range c.preprocessors {
result, err := template.ExecuteJP(context.Background(), preprocessor, payload, nil)
result, err := templating.ExecuteJP(preprocessor, payload, nil, compiler)
if err != nil {
return err
}
Expand Down
73 changes: 38 additions & 35 deletions pkg/core/assertion/assertion.go
Original file line number Diff line number Diff line change
@@ -1,58 +1,56 @@
package assertion

import (
"context"
"errors"
"fmt"
"reflect"
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/jmespath-community/go-jmespath/pkg/parsing"
"github.com/kyverno/kyverno-json/pkg/core/expression"
"github.com/kyverno/kyverno-json/pkg/core/projection"
"github.com/kyverno/kyverno-json/pkg/core/templating"
"github.com/kyverno/kyverno-json/pkg/engine/match"
"github.com/kyverno/kyverno-json/pkg/engine/template"
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
"k8s.io/apimachinery/pkg/util/validation/field"
)

type Assertion interface {
Assert(context.Context, *field.Path, any, binding.Bindings, ...template.Option) (field.ErrorList, error)
Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error)
}

func Parse(ctx context.Context, assertion any) (node, error) {
func Parse(assertion any, compiler templating.Compiler) (node, error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
return parseSlice(ctx, assertion)
return parseSlice(assertion, compiler)
case reflect.Map:
return parseMap(ctx, assertion)
return parseMap(assertion, compiler)
default:
return parseScalar(ctx, assertion)
return parseScalar(assertion, compiler)
}
}

// node implements the Assertion interface using a delegate func
type node func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error)
type node func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error)

func (n node) Assert(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return n(ctx, path, value, bindings, opts...)
func (n node) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
return n(path, value, bindings)
}

// parseSlice is the assertion represented by a slice.
// it first compares the length of the analysed resource with the length of the descendants.
// if lengths match all descendants are evaluated with their corresponding items.
func parseSlice(ctx context.Context, assertion any) (node, error) {
func parseSlice(assertion any, compiler templating.Compiler) (node, error) {
var assertions []node
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
sub, err := Parse(ctx, valueOf.Index(i).Interface())
sub, err := Parse(valueOf.Index(i).Interface(), compiler)
if err != nil {
return nil, err
}
assertions = append(assertions, sub)
}
return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
if value == nil {
errs = append(errs, field.Invalid(path, value, "value is null"))
Expand All @@ -64,7 +62,7 @@ func parseSlice(ctx context.Context, assertion any) (node, error) {
errs = append(errs, field.Invalid(path, value, "lengths of slices don't match"))
} else {
for i := range assertions {
if _errs, err := assertions[i].Assert(ctx, path.Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil {
if _errs, err := assertions[i].Assert(path.Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -78,7 +76,7 @@ func parseSlice(ctx context.Context, assertion any) (node, error) {

// parseMap is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
func parseMap(ctx context.Context, assertion any) (node, error) {
func parseMap(assertion any, compiler templating.Compiler) (node, error) {
assertions := map[any]struct {
projection.Projection
node
Expand All @@ -87,16 +85,16 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
for iter.Next() {
key := iter.Key().Interface()
value := iter.Value().Interface()
assertion, err := Parse(ctx, value)
assertion, err := Parse(value, compiler)
if err != nil {
return nil, err
}
entry := assertions[key]
entry.node = assertion
entry.Projection = projection.Parse(key)
entry.Projection = projection.Parse(key, compiler)
assertions[key] = entry
}
return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
var errs field.ErrorList
// if we assert against an empty object, value is expected to be not nil
if len(assertions) == 0 {
Expand All @@ -106,7 +104,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
return errs, nil
}
for k, v := range assertions {
projected, found, err := v.Projection.Handler(ctx, value, bindings, opts...)
projected, found, err := v.Projection.Handler(value, bindings)
if err != nil {
return nil, field.InternalError(path.Child(fmt.Sprint(k)), err)
} else if !found {
Expand All @@ -124,7 +122,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(i))
}
if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings, opts...); err != nil {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -138,7 +136,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
if v.Projection.ForeachName != "" {
bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(key))
}
if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings, opts...); err != nil {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -148,7 +146,7 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
return nil, field.TypeInvalid(path.Child(fmt.Sprint(k)), projected, "expected a slice or a map")
}
} else {
if _errs, err := v.Assert(ctx, path.Child(fmt.Sprint(k)), projected, bindings, opts...); err != nil {
if _errs, err := v.Assert(path.Child(fmt.Sprint(k)), projected, bindings); err != nil {
return nil, err
} else {
errs = append(errs, _errs...)
Expand All @@ -163,8 +161,8 @@ func parseMap(ctx context.Context, assertion any) (node, error) {
// parseScalar is the assertion represented by a leaf.
// it receives a value and compares it with an expected value.
// the expected value can be the result of an expression.
func parseScalar(_ context.Context, assertion any) (node, error) {
var project func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error)
func parseScalar(assertion any, compiler templating.Compiler) (node, error) {
var project func(value any, bindings binding.Bindings) (any, error)
switch typed := assertion.(type) {
case string:
expr := expression.Parse(typed)
Expand All @@ -176,34 +174,39 @@ func parseScalar(_ context.Context, assertion any) (node, error) {
}
switch expr.Engine {
case expression.EngineJP:
parse := sync.OnceValues(func() (parsing.ASTNode, error) {
parser := parsing.NewParser()
return parser.Parse(expr.Statement)
parse := sync.OnceValues(func() (templating.Program, error) {
return compiler.CompileJP(expr.Statement)
})
project = func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error) {
ast, err := parse()
project = func(value any, bindings binding.Bindings) (any, error) {
program, err := parse()
if err != nil {
return nil, err
}
return template.ExecuteAST(ctx, ast, value, bindings, opts...)
return program(value, bindings)
}
case expression.EngineCEL:
return nil, errors.New("engine not supported")
project = func(value any, bindings binding.Bindings) (any, error) {
program, err := compiler.CompileCEL(expr.Statement)
if err != nil {
return nil, err
}
return program(value, bindings)
}
default:
assertion = expr.Statement
}
}
return func(ctx context.Context, path *field.Path, value any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) {
expected := assertion
if project != nil {
projected, err := project(ctx, value, bindings, opts...)
projected, err := project(value, bindings)
if err != nil {
return nil, field.InternalError(path, err)
}
expected = projected
}
var errs field.ErrorList
if match, err := match.Match(ctx, expected, value); err != nil {
if match, err := match.Match(expected, value); err != nil {
return nil, field.InternalError(path, err)
} else if !match {
errs = append(errs, field.Invalid(path, value, expectValueMessage(expected)))
Expand Down
7 changes: 4 additions & 3 deletions pkg/core/assertion/assertion_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package assertion

import (
"context"
"testing"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/templating"
tassert "github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/util/validation/field"
)
Expand Down Expand Up @@ -48,9 +48,10 @@ func TestAssert(t *testing.T) {
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsed, err := Parse(context.TODO(), tt.assertion)
compiler := templating.NewCompiler(templating.CompilerOptions{})
parsed, err := Parse(tt.assertion, compiler)
tassert.NoError(t, err)
got, err := parsed.Assert(context.TODO(), nil, tt.value, tt.bindings)
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
tassert.Error(t, err)
} else {
Expand Down
Loading

0 comments on commit 06f036d

Please sign in to comment.