diff --git a/pkg/core/assertion/assertion.go b/pkg/core/assertion/assertion.go index dff71cfd..ca50d6c1 100644 --- a/pkg/core/assertion/assertion.go +++ b/pkg/core/assertion/assertion.go @@ -17,13 +17,21 @@ type Assertion interface { } func Parse(assertion any, compiler compilers.Compilers) (Assertion, error) { + out, err := parse(nil, assertion, compiler) + if err != nil { + return nil, err + } + return out, nil +} + +func parse(path *field.Path, assertion any, compiler compilers.Compilers) (Assertion, *field.Error) { switch reflectutils.GetKind(assertion) { case reflect.Slice: - return parseSlice(assertion, compiler) + return parseSlice(path, assertion, compiler) case reflect.Map: - return parseMap(assertion, compiler) + return parseMap(path, assertion, compiler) default: - return parseScalar(assertion, compiler) + return parseScalar(path, assertion, compiler) } } @@ -55,11 +63,12 @@ func (node sliceNode) Assert(path *field.Path, value any, bindings binding.Bindi return errs, nil } -func parseSlice(assertion any, compiler compilers.Compilers) (sliceNode, error) { +func parseSlice(path *field.Path, assertion any, compiler compilers.Compilers) (sliceNode, *field.Error) { var assertions sliceNode valueOf := reflect.ValueOf(assertion) for i := 0; i < valueOf.Len(); i++ { - sub, err := Parse(valueOf.Index(i).Interface(), compiler) + path := path.Index(i) + sub, err := parse(path, valueOf.Index(i).Interface(), compiler) if err != nil { return nil, err } @@ -138,13 +147,14 @@ func (node mapNode) Assert(path *field.Path, value any, bindings binding.Binding return errs, nil } -func parseMap(assertion any, compiler compilers.Compilers) (mapNode, error) { +func parseMap(path *field.Path, assertion any, compiler compilers.Compilers) (mapNode, *field.Error) { assertions := mapNode{} iter := reflect.ValueOf(assertion).MapRange() for iter.Next() { key := iter.Key().Interface() value := iter.Value().Interface() - assertion, err := Parse(value, compiler) + path := path.Child(fmt.Sprint(key)) + assertion, err := parse(path, value, compiler) if err != nil { return nil, err } @@ -173,12 +183,12 @@ func (node scalarNode) Assert(path *field.Path, value any, bindings binding.Bind return errs, nil } -func parseScalar(in any, compiler compilers.Compilers) (scalarNode, error) { +func parseScalar(path *field.Path, in any, compiler compilers.Compilers) (scalarNode, *field.Error) { proj, err := projection.ParseScalar(in, compiler) if err != nil { - return nil, err + return nil, field.InternalError(path, err) } - return proj, err + return proj, nil } func expectValueMessage(value any) string { diff --git a/pkg/core/assertion/assertion_test.go b/pkg/core/assertion/assertion_test.go index c4fa6195..5ca2be70 100644 --- a/pkg/core/assertion/assertion_test.go +++ b/pkg/core/assertion/assertion_test.go @@ -61,3 +61,32 @@ func TestAssert(t *testing.T) { }) } } + +func TestParse(t *testing.T) { + tests := []struct { + name string + assertion any + want field.ErrorList + wantErr bool + }{{ + name: "bad scalar", + assertion: map[string]any{ + "foo": map[string]any{ + "bar": "~.(`42`)", + }, + }, + wantErr: true, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + compiler := compilers.DefaultCompilers + parsed, err := Parse(tt.assertion, compiler) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, parsed) + } + }) + } +}