From b70bf41c211da875242de8ac7fd5fd442dddcd46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Charles-Edouard=20Br=C3=A9t=C3=A9ch=C3=A9?= Date: Fri, 20 Sep 2024 09:46:55 +0200 Subject: [PATCH] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Charles-Edouard Brétéché --- pkg/core/templating/cel/cel.go | 18 +++++++++ pkg/core/templating/cel/env.go | 41 +++++++++++++++++++ pkg/core/templating/cel/val.go | 48 ++++++++++++++++++++++ pkg/core/templating/compiler.go | 72 ++------------------------------- pkg/policy/load_test.go | 3 +- 5 files changed, 112 insertions(+), 70 deletions(-) create mode 100644 pkg/core/templating/cel/cel.go create mode 100644 pkg/core/templating/cel/env.go create mode 100644 pkg/core/templating/cel/val.go diff --git a/pkg/core/templating/cel/cel.go b/pkg/core/templating/cel/cel.go new file mode 100644 index 00000000..fb2e31f2 --- /dev/null +++ b/pkg/core/templating/cel/cel.go @@ -0,0 +1,18 @@ +package cel + +import ( + "github.com/google/cel-go/cel" + "github.com/jmespath-community/go-jmespath/pkg/binding" +) + +func Execute(program cel.Program, value any, bindings binding.Bindings) (any, error) { + data := map[string]interface{}{ + "object": value, + "bindings": NewVal(bindings, BindingsType), + } + out, _, err := program.Eval(data) + if err != nil { + return nil, err + } + return out.Value(), nil +} diff --git a/pkg/core/templating/cel/env.go b/pkg/core/templating/cel/env.go new file mode 100644 index 00000000..e3e0be81 --- /dev/null +++ b/pkg/core/templating/cel/env.go @@ -0,0 +1,41 @@ +package cel + +import ( + "sync" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/jmespath-community/go-jmespath/pkg/binding" +) + +var ( + BindingsType = cel.OpaqueType("bindings") + DefaultEnv = sync.OnceValues(func() (*cel.Env, error) { + return cel.NewEnv( + cel.Variable("object", cel.DynType), + cel.Variable("bindings", BindingsType), + cel.Function("resolve", + cel.MemberOverload("bindings_resolve_string", + []*cel.Type{BindingsType, cel.StringType}, + cel.AnyType, + cel.BinaryBinding(func(lhs, rhs ref.Val) ref.Val { + bindings, ok := lhs.(Val[binding.Bindings]) + if !ok { + return types.ValOrErr(bindings, "invalid bindings type") + } + name, ok := rhs.(types.String) + if !ok { + return types.ValOrErr(name, "invalid name type") + } + value, err := binding.Resolve("$"+string(name), bindings.Unwrap()) + if err != nil { + return types.WrapErr(err) + } + return types.DefaultTypeAdapter.NativeToValue(value) + }), + ), + ), + ) + }) +) diff --git a/pkg/core/templating/cel/val.go b/pkg/core/templating/cel/val.go new file mode 100644 index 00000000..a2145844 --- /dev/null +++ b/pkg/core/templating/cel/val.go @@ -0,0 +1,48 @@ +package cel + +import ( + "reflect" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +type Val[T comparable] struct { + inner T + celType ref.Type +} + +func NewVal[T comparable](value T, celType ref.Type) Val[T] { + return Val[T]{ + inner: value, + celType: celType, + } +} + +func (w Val[T]) Unwrap() T { + return w.inner +} + +func (w Val[T]) Value() interface{} { + return w.Unwrap() +} + +func (w Val[T]) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + panic("not required") +} + +func (w Val[T]) ConvertToType(typeVal ref.Type) ref.Val { + panic("not required") +} + +func (w Val[T]) Equal(other ref.Val) ref.Val { + o, ok := other.Value().(Val[T]) + if !ok { + return types.ValOrErr(other, "no such overload") + } + return types.Bool(o == w) +} + +func (w Val[T]) Type() ref.Type { + return w.celType +} diff --git a/pkg/core/templating/compiler.go b/pkg/core/templating/compiler.go index d936b541..7d9e1c04 100644 --- a/pkg/core/templating/compiler.go +++ b/pkg/core/templating/compiler.go @@ -1,16 +1,11 @@ package templating import ( - "reflect" - "sync" - - "github.com/google/cel-go/cel" - "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/jmespath-community/go-jmespath/pkg/interpreter" "github.com/jmespath-community/go-jmespath/pkg/parsing" "github.com/kyverno/kyverno-json/pkg/core/expression" + "github.com/kyverno/kyverno-json/pkg/core/templating/cel" "github.com/kyverno/kyverno-json/pkg/core/templating/jp" "k8s.io/apimachinery/pkg/util/validation/field" ) @@ -40,58 +35,8 @@ func (c Compiler) Options() CompilerOptions { return c.options } -var bindingsType = cel.OpaqueType("bindings") - -type b struct { - binding.Bindings -} - -func (b b) Value() interface{} { - return b -} -func (b b) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { - panic("not required") -} - -func (b b) ConvertToType(typeVal ref.Type) ref.Val { - panic("not required") -} -func (x b) Equal(other ref.Val) ref.Val { - o, ok := other.Value().(b) - if !ok { - return types.ValOrErr(other, "no such overload xxx") - } - return types.Bool(o == x) -} - -func (b b) Type() ref.Type { - return bindingsType -} - -var newEnv = sync.OnceValues(func() (*cel.Env, error) { - return cel.NewEnv( - cel.Variable("object", cel.DynType), - cel.Variable("bindings", bindingsType), - cel.Function("resolve", - cel.MemberOverload("bindings_resolve_string", - []*cel.Type{bindingsType, cel.StringType}, - cel.AnyType, - cel.BinaryBinding(func(lhs, rhs ref.Val) ref.Val { - bindings := lhs.(b) - name := rhs.(types.String) - value, err := binding.Resolve("$"+string(name), bindings) - if err != nil { - return types.WrapErr(err) - } - return types.DefaultTypeAdapter.NativeToValue(value) - }), - ), - ), - ) -}) - func (c Compiler) CompileCEL(statement string) (Program, error) { - env, err := newEnv() + env, err := cel.DefaultEnv() if err != nil { return nil, err } @@ -99,21 +44,12 @@ func (c Compiler) CompileCEL(statement string) (Program, error) { if iss.Err() != nil { return nil, iss.Err() } - prg, err := env.Program(ast) + program, err := env.Program(ast) if err != nil { return nil, err } return func(value any, bindings binding.Bindings) (any, error) { - out, _, err := prg.Eval( - map[string]interface{}{ - "object": value, - "bindings": b{bindings}, - }, - ) - if err != nil { - return nil, err - } - return out.Value(), nil + return cel.Execute(program, value, bindings) }, nil } diff --git a/pkg/policy/load_test.go b/pkg/policy/load_test.go index 80826316..dd06dbef 100644 --- a/pkg/policy/load_test.go +++ b/pkg/policy/load_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/kyverno/kyverno-json/pkg/apis/policy/v1alpha1" "github.com/stretchr/testify/assert" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -124,7 +123,7 @@ func TestLoad(t *testing.T) { } else { assert.NoError(t, err) } - assert.True(t, cmp.Equal(tt.want, got, cmp.AllowUnexported(v1alpha1.AssertionTree{}), cmpopts.IgnoreFields(v1alpha1.AssertionTree{}, "_assertion"))) + assert.True(t, cmp.Equal(tt.want, got, cmp.AllowUnexported(v1alpha1.AssertionTree{}))) }) } }