diff --git a/pkg/core/assertion/assertion.go b/pkg/core/assertion/assertion.go index a6d01b0e..dff71cfd 100644 --- a/pkg/core/assertion/assertion.go +++ b/pkg/core/assertion/assertion.go @@ -1,14 +1,11 @@ package assertion import ( - "errors" "fmt" "reflect" - "sync" "github.com/jmespath-community/go-jmespath/pkg/binding" "github.com/kyverno/kyverno-json/pkg/core/compilers" - "github.com/kyverno/kyverno-json/pkg/core/expression" "github.com/kyverno/kyverno-json/pkg/core/matching" "github.com/kyverno/kyverno-json/pkg/core/projection" reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect" @@ -19,7 +16,7 @@ type Assertion interface { Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error) } -func Parse(assertion any, compiler compilers.Compilers) (node, error) { +func Parse(assertion any, compiler compilers.Compilers) (Assertion, error) { switch reflectutils.GetKind(assertion) { case reflect.Slice: return parseSlice(assertion, compiler) @@ -30,18 +27,36 @@ func Parse(assertion any, compiler compilers.Compilers) (node, error) { } } -// node implements the Assertion interface using a delegate func -type node func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) +// sliceNode is the assertion represented by a slice. +// it first compares the length of the analysed resource with the length of the descendants. +// if lengths match all descendants are evaluated with their corresponding items. +type sliceNode []Assertion -func (n node) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { - return n(path, value, bindings) +func (node sliceNode) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { + var errs field.ErrorList + if value == nil { + errs = append(errs, field.Invalid(path, value, "value is null")) + } else if reflectutils.GetKind(value) != reflect.Slice { + return nil, field.TypeInvalid(path, value, "expected a slice") + } else { + valueOf := reflect.ValueOf(value) + if valueOf.Len() != len(node) { + errs = append(errs, field.Invalid(path, value, "lengths of slices don't match")) + } else { + for i := range node { + if _errs, err := node[i].Assert(path.Index(i), valueOf.Index(i).Interface(), bindings); err != nil { + return nil, err + } else { + errs = append(errs, _errs...) + } + } + } + } + return errs, nil } -// parseSlice is the assertion represented by a slice. -// it first compares the length of the analysed resource with the length of the descendants. -// if lengths match all descendants are evaluated with their corresponding items. -func parseSlice(assertion any, compiler compilers.Compilers) (node, error) { - var assertions []node +func parseSlice(assertion any, compiler compilers.Compilers) (sliceNode, error) { + var assertions sliceNode valueOf := reflect.ValueOf(assertion) for i := 0; i < valueOf.Len(); i++ { sub, err := Parse(valueOf.Index(i).Interface(), compiler) @@ -50,37 +65,81 @@ func parseSlice(assertion any, compiler compilers.Compilers) (node, error) { } assertions = append(assertions, sub) } - return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { - var errs field.ErrorList + return assertions, nil +} + +// mapNode is the assertion represented by a map. +// it is responsible for projecting the analysed resource and passing the result to the descendant +type mapNode map[any]struct { + projection.Projection + Assertion +} + +func (node mapNode) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { + var errs field.ErrorList + // if we assert against an empty object, value is expected to be not nil + if len(node) == 0 { if value == nil { - errs = append(errs, field.Invalid(path, value, "value is null")) - } else if reflectutils.GetKind(value) != reflect.Slice { - return nil, field.TypeInvalid(path, value, "expected a slice") + errs = append(errs, field.Invalid(path, value, "invalid value, must not be null")) + } + return errs, nil + } + for k, v := range node { + projected, found, err := v.Projection.Handler(value, bindings) + if err != nil { + return nil, field.InternalError(path.Child(fmt.Sprint(k)), err) + } else if !found { + errs = append(errs, field.Required(path.Child(fmt.Sprint(k)), "field not found in the input object")) } else { - valueOf := reflect.ValueOf(value) - if valueOf.Len() != len(assertions) { - errs = append(errs, field.Invalid(path, value, "lengths of slices don't match")) - } else { - for i := range assertions { - if _errs, err := assertions[i].Assert(path.Index(i), valueOf.Index(i).Interface(), bindings); err != nil { - return nil, err - } else { - errs = append(errs, _errs...) + if v.Projection.Binding != "" { + bindings = bindings.Register("$"+v.Projection.Binding, binding.NewBinding(projected)) + } + if v.Projection.Foreach { + projectedKind := reflectutils.GetKind(projected) + if projectedKind == reflect.Slice { + valueOf := reflect.ValueOf(projected) + for i := 0; i < valueOf.Len(); i++ { + bindings := bindings + if v.Projection.ForeachName != "" { + bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(i)) + } + if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings); err != nil { + return nil, err + } else { + errs = append(errs, _errs...) + } } + } else if projectedKind == reflect.Map { + iter := reflect.ValueOf(projected).MapRange() + for iter.Next() { + key := iter.Key().Interface() + bindings := bindings + if v.Projection.ForeachName != "" { + bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(key)) + } + if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings); err != nil { + return nil, err + } else { + errs = append(errs, _errs...) + } + } + } else { + return nil, field.TypeInvalid(path.Child(fmt.Sprint(k)), projected, "expected a slice or a map") + } + } else { + if _errs, err := v.Assert(path.Child(fmt.Sprint(k)), projected, bindings); err != nil { + return nil, err + } else { + errs = append(errs, _errs...) } } } - return errs, nil - }, nil + } + return errs, nil } -// parseMap is the assertion represented by a map. -// it is responsible for projecting the analysed resource and passing the result to the descendant -func parseMap(assertion any, compiler compilers.Compilers) (node, error) { - assertions := map[any]struct { - projection.Projection - node - }{} +func parseMap(assertion any, compiler compilers.Compilers) (mapNode, error) { + assertions := mapNode{} iter := reflect.ValueOf(assertion).MapRange() for iter.Next() { key := iter.Key().Interface() @@ -90,120 +149,36 @@ func parseMap(assertion any, compiler compilers.Compilers) (node, error) { return nil, err } entry := assertions[key] - entry.node = assertion + entry.Assertion = assertion entry.Projection = projection.ParseMapKey(key, compiler) assertions[key] = entry } - return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { - var errs field.ErrorList - // if we assert against an empty object, value is expected to be not nil - if len(assertions) == 0 { - if value == nil { - errs = append(errs, field.Invalid(path, value, "invalid value, must not be null")) - } - return errs, nil - } - for k, v := range assertions { - projected, found, err := v.Projection.Handler(value, bindings) - if err != nil { - return nil, field.InternalError(path.Child(fmt.Sprint(k)), err) - } else if !found { - errs = append(errs, field.Required(path.Child(fmt.Sprint(k)), "field not found in the input object")) - } else { - if v.Projection.Binding != "" { - bindings = bindings.Register("$"+v.Projection.Binding, binding.NewBinding(projected)) - } - if v.Projection.Foreach { - projectedKind := reflectutils.GetKind(projected) - if projectedKind == reflect.Slice { - valueOf := reflect.ValueOf(projected) - for i := 0; i < valueOf.Len(); i++ { - bindings := bindings - if v.Projection.ForeachName != "" { - bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(i)) - } - if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Index(i), valueOf.Index(i).Interface(), bindings); err != nil { - return nil, err - } else { - errs = append(errs, _errs...) - } - } - } else if projectedKind == reflect.Map { - iter := reflect.ValueOf(projected).MapRange() - for iter.Next() { - key := iter.Key().Interface() - bindings := bindings - if v.Projection.ForeachName != "" { - bindings = bindings.Register("$"+v.Projection.ForeachName, binding.NewBinding(key)) - } - if _errs, err := v.Assert(path.Child(fmt.Sprint(k)).Key(fmt.Sprint(key)), iter.Value().Interface(), bindings); err != nil { - return nil, err - } else { - errs = append(errs, _errs...) - } - } - } else { - return nil, field.TypeInvalid(path.Child(fmt.Sprint(k)), projected, "expected a slice or a map") - } - } else { - if _errs, err := v.Assert(path.Child(fmt.Sprint(k)), projected, bindings); err != nil { - return nil, err - } else { - errs = append(errs, _errs...) - } - } - } - } - return errs, nil - }, nil + return assertions, nil } -// parseScalar is the assertion represented by a leaf. +// scalarNode is the assertion represented by a leaf. // it receives a value and compares it with an expected value. // the expected value can be the result of an expression. -func parseScalar(assertion any, compiler compilers.Compilers) (node, error) { - var project func(value any, bindings binding.Bindings) (any, error) - switch typed := assertion.(type) { - case string: - expr := expression.Parse(typed) - if expr.Foreach { - return nil, errors.New("foreach is not supported on the RHS") - } - if expr.Binding != "" { - return nil, errors.New("binding is not supported on the RHS") - } - if compiler := compiler.Compiler(expr.Compiler); compiler != nil { - parse := sync.OnceValues(func() (compilers.Program, error) { - return compiler.Compile(expr.Statement) - }) - project = func(value any, bindings binding.Bindings) (any, error) { - program, err := parse() - if err != nil { - return nil, err - } - return program(value, bindings) - } - } else { - assertion = expr.Statement - } +type scalarNode projection.ScalarHandler + +func (node scalarNode) Assert(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { + var errs field.ErrorList + if projected, err := node(value, bindings); err != nil { + return nil, field.InternalError(path, err) + } else if match, err := matching.Match(projected, value); err != nil { + return nil, field.InternalError(path, err) + } else if !match { + errs = append(errs, field.Invalid(path, value, expectValueMessage(projected))) } - return func(path *field.Path, value any, bindings binding.Bindings) (field.ErrorList, error) { - expected := assertion - if project != nil { - projected, err := project(value, bindings) - if err != nil { - return nil, field.InternalError(path, err) - } - expected = projected - } - var errs field.ErrorList - if match, err := matching.Match(expected, value); err != nil { - return nil, field.InternalError(path, err) - } else if !match { - errs = append(errs, field.Invalid(path, value, expectValueMessage(expected))) - } - return errs, nil - }, nil + return errs, nil +} + +func parseScalar(in any, compiler compilers.Compilers) (scalarNode, error) { + proj, err := projection.ParseScalar(in, compiler) + if err != nil { + return nil, err + } + return proj, err } func expectValueMessage(value any) string {