From 35783e995ccefef460a18a034af7d4ad044f57b4 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 1 Dec 2022 13:12:17 -0800 Subject: [PATCH] Extension to support Golang structs as CEL types (#612) Implementation of a Golang native type provider --- cel/cel_test.go | 4 +- cel/env.go | 5 - cel/library.go | 2 +- cel/options.go | 10 +- common/types/ref/provider.go | 2 - ext/native.go | 544 +++++++++++++++++++++++++++++++ ext/native_test.go | 597 +++++++++++++++++++++++++++++++++++ interpreter/planner.go | 4 +- 8 files changed, 1151 insertions(+), 17 deletions(-) create mode 100644 ext/native.go create mode 100644 ext/native_test.go diff --git a/cel/cel_test.go b/cel/cel_test.go index be9e06cb..b8b107a9 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1878,7 +1878,7 @@ func TestDynamicDispatch(t *testing.T) { func TestOptionalValues(t *testing.T) { env, err := NewEnv( - OptionalTypes(true), + OptionalTypes(), // Container and test message types. Container("google.expr.proto2.test"), Types(&proto2pb.TestAllTypes{}), @@ -2205,7 +2205,7 @@ func TestOptionalValues(t *testing.T) { func BenchmarkOptionalValues(b *testing.B) { env, err := NewEnv( - OptionalTypes(true), + OptionalTypes(), Variable("x", OptionalType(IntType)), Variable("y", OptionalType(IntType)), Variable("z", IntType), diff --git a/cel/env.go b/cel/env.go index faeaefcc..1a0b11ec 100644 --- a/cel/env.go +++ b/cel/env.go @@ -485,11 +485,6 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) { if err != nil { return nil, err } - // If optional types have been enabled, then configure the optional library. - e, err = e.maybeApplyFeature(featureOptionalTypes, Lib(optionalLibrary{})) - if err != nil { - return nil, err - } // Initialize all of the functions configured within the environment. for _, fn := range e.functions { diff --git a/cel/library.go b/cel/library.go index 8307c6f4..f04fda74 100644 --- a/cel/library.go +++ b/cel/library.go @@ -157,7 +157,7 @@ func (optionalLibrary) CompileOptions() []EnvOption { return opt.GetValue() }))), Function("hasValue", - MemberOverload("optional_hasValue", []*Type{optionalTypeV}, paramTypeV, + MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType, UnaryBinding(func(value ref.Val) ref.Val { opt := value.(*types.Optional) return types.Bool(opt.HasValue()) diff --git a/cel/options.go b/cel/options.go index b61d3e79..23e5223e 100644 --- a/cel/options.go +++ b/cel/options.go @@ -542,11 +542,11 @@ func DefaultUTCTimeZone(enabled bool) EnvOption { return features(featureDefaultUTCTimeZone, enabled) } -// OptionalTypes determines whether CEL's optional value type is enabled. The optional value type makes it -// possible to express whether variables have been provided, whether a result has been computed, and -// in the future whether an object field path, map key value, or list index has a value. -func OptionalTypes(enabled bool) EnvOption { - return features(featureOptionalTypes, enabled) +// OptionalTypes enable support for optional syntax and types in CEL. The optional value type makes +// it possible to express whether variables have been provided, whether a result has been computed, +// and in the future whether an object field path, map key value, or list index has a value. +func OptionalTypes() EnvOption { + return Lib(optionalLibrary{}) } // features sets the given feature flags. See list of Feature constants above. diff --git a/common/types/ref/provider.go b/common/types/ref/provider.go index 9ce2e34b..7eabbb9c 100644 --- a/common/types/ref/provider.go +++ b/common/types/ref/provider.go @@ -39,8 +39,6 @@ type TypeProvider interface { // FieldFieldType returns the field type for a checked type value. Returns // false if the field could not be found. - // - // Used during type-checking only. FindFieldType(messageType string, fieldName string) (*FieldType, bool) // NewValue creates a new type value from a qualified name and map of field diff --git a/ext/native.go b/ext/native.go new file mode 100644 index 00000000..551e342d --- /dev/null +++ b/ext/native.go @@ -0,0 +1,544 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "reflect" + "strings" + "time" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/pb" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +var ( + nativeObjTraitMask = traits.FieldTesterType | traits.IndexerType +) + +// NativeTypes creates a type provider which uses reflect.Type and reflect.Value instances +// to produce type definitions that can be used within CEL. +// +// All struct types in Go are exposed to CEL via their simple package name and struct type name: +// +// ```go +// package identity +// +// type Account struct { +// ID int +// } +// +// ``` +// +// The type `identity.Account` would be exported to CEL using the same qualified name, e.g. +// `identity.Account{ID: 1234}` would create a new `Account` instance with the `ID` field +// populated. +// +// Only exported fields are exposed via NativeTypes, and the type-mapping between Go and CEL +// is as follows: +// +// | Go type | CEL type | +// |-------------------------------------|-----------| +// | bool | bool | +// | []byte | bytes | +// | float32, float64 | double | +// | int, int8, int16, int32, int64 | int | +// | string | string | +// | uint, uint8, uint16, uint32, uint64 | uint | +// | time.Duration | duration | +// | time.Time | timestamp | +// | array, slice | list | +// | map | map | +// +// Please note, if you intend to configure support for proto messages in addition to native +// types, you will need to provide the protobuf types before the golang native types. The +// same advice holds if you are using custom type adapters and type providers. The native type +// provider composes over whichever type adapter and provider is configured in the cel.Env at +// the time that it is invoked. +func NativeTypes(refTypes ...any) cel.EnvOption { + return func(env *cel.Env) (*cel.Env, error) { + tp, err := newNativeTypeProvider(env.TypeAdapter(), env.TypeProvider(), refTypes...) + if err != nil { + return nil, err + } + env, err = cel.CustomTypeAdapter(tp)(env) + if err != nil { + return nil, err + } + return cel.CustomTypeProvider(tp)(env) + } +} + +func newNativeTypeProvider(adapter ref.TypeAdapter, provider ref.TypeProvider, refTypes ...any) (*nativeTypeProvider, error) { + nativeTypes := make(map[string]*nativeType, len(refTypes)) + for _, refType := range refTypes { + switch rt := refType.(type) { + case reflect.Type: + t, err := newNativeType(rt) + if err != nil { + return nil, err + } + nativeTypes[t.TypeName()] = t + case reflect.Value: + t, err := newNativeType(rt.Type()) + if err != nil { + return nil, err + } + nativeTypes[t.TypeName()] = t + default: + return nil, fmt.Errorf("unsupported native type: %v (%T) must be reflect.Type or reflect.Value", rt, rt) + } + } + return &nativeTypeProvider{ + nativeTypes: nativeTypes, + baseAdapter: adapter, + baseProvider: provider, + }, nil +} + +type nativeTypeProvider struct { + nativeTypes map[string]*nativeType + baseAdapter ref.TypeAdapter + baseProvider ref.TypeProvider +} + +// EnumValue proxies to the ref.TypeProvider configured at the times the NativeTypes +// option was configured. +func (tp *nativeTypeProvider) EnumValue(enumName string) ref.Val { + return tp.baseProvider.EnumValue(enumName) +} + +// FindIdent looks up natives type instances by qualified identifier, and if not found +// proxies to the composed ref.TypeProvider. +func (tp *nativeTypeProvider) FindIdent(typeName string) (ref.Val, bool) { + if t, found := tp.nativeTypes[typeName]; found { + return t, true + } + return tp.baseProvider.FindIdent(typeName) +} + +// FindType looks up CEL type-checker type definition by qualified identifier, and if not found +// proxies to the composed ref.TypeProvider. +func (tp *nativeTypeProvider) FindType(typeName string) (*exprpb.Type, bool) { + if _, found := tp.nativeTypes[typeName]; found { + return decls.NewTypeType(decls.NewObjectType(typeName)), true + } + return tp.baseProvider.FindType(typeName) +} + +// FindFieldType looks up a native type's field definition, and if the type name is not a native +// type then proxies to the composed ref.TypeProvider +func (tp *nativeTypeProvider) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) { + t, found := tp.nativeTypes[typeName] + if !found { + return tp.baseProvider.FindFieldType(typeName, fieldName) + } + refField, isDefined := t.hasField(fieldName) + if !found || !isDefined { + return nil, false + } + exprType, ok := convertToExprType(refField.Type) + if !ok { + return nil, false + } + return &ref.FieldType{ + Type: exprType, + IsSet: func(obj any) bool { + refVal := reflect.Indirect(reflect.ValueOf(obj)) + refField := refVal.FieldByName(fieldName) + return !refField.IsZero() + }, + GetFrom: func(obj any) (any, error) { + refVal := reflect.Indirect(reflect.ValueOf(obj)) + refField := refVal.FieldByName(fieldName) + return getFieldValue(tp, refField), nil + }, + }, true +} + +// NewValue implements the ref.TypeProvider interface method. +func (tp *nativeTypeProvider) NewValue(typeName string, fields map[string]ref.Val) ref.Val { + t, found := tp.nativeTypes[typeName] + if !found { + return tp.baseProvider.NewValue(typeName, fields) + } + refPtr := reflect.New(t.refType) + refVal := refPtr.Elem() + for fieldName, val := range fields { + refFieldDef, isDefined := t.hasField(fieldName) + if !isDefined { + return types.NewErr("no such field: %s", fieldName) + } + fieldVal, err := val.ConvertToNative(refFieldDef.Type) + if err != nil { + return types.NewErr(err.Error()) + } + refField := refVal.FieldByIndex(refFieldDef.Index) + refFieldVal := reflect.ValueOf(fieldVal) + refField.Set(refFieldVal) + } + return tp.NativeToValue(refPtr.Interface()) +} + +// NewValue adapts native values to CEL values and will proxy to the composed type adapter +// for non-native types. +func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val { + if val == nil { + return types.NullValue + } + rawVal := reflect.ValueOf(val) + refVal := rawVal + if refVal.Kind() == reflect.Ptr { + refVal = reflect.Indirect(refVal) + } + // This isn't quite right if you're also supporting proto, + // but maybe an acceptable limitation. + switch refVal.Kind() { + case reflect.Array, reflect.Slice: + switch val.(type) { + case []byte: + return tp.baseAdapter.NativeToValue(val) + default: + return types.NewDynamicList(tp, val) + } + case reflect.Map: + return types.NewDynamicMap(tp, val) + case reflect.Struct: + switch val := val.(type) { + case ref.Val: + return val + case proto.Message, *pb.Map, protoreflect.List, protoreflect.Message, protoreflect.Value, + time.Time: + return tp.baseAdapter.NativeToValue(val) + default: + return newNativeObject(tp, val, rawVal) + } + default: + return tp.baseAdapter.NativeToValue(val) + } +} + +// convertToExprType converts the Golang reflect.Type to a protobuf exprpb.Type. +func convertToExprType(refType reflect.Type) (*exprpb.Type, bool) { + switch refType.Kind() { + case reflect.Bool: + return decls.Bool, true + case reflect.Float32, reflect.Float64: + return decls.Double, true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if refType == durationType { + return decls.Duration, true + } + return decls.Int, true + case reflect.String: + return decls.String, true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return decls.Uint, true + case reflect.Array, reflect.Slice: + refElem := refType.Elem() + if refElem == reflect.TypeOf(byte(0)) { + return decls.Bytes, true + } + elemType, ok := convertToExprType(refElem) + if !ok { + return nil, false + } + return decls.NewListType(elemType), true + case reflect.Map: + keyType, ok := convertToExprType(refType.Key()) + if !ok { + return nil, false + } + // Ensure the key type is a int, bool, uint, string + elemType, ok := convertToExprType(refType.Elem()) + if !ok { + return nil, false + } + return decls.NewMapType(keyType, elemType), true + case reflect.Struct: + if refType == timestampType { + return decls.Timestamp, true + } + return decls.NewObjectType( + fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), + ), true + case reflect.Pointer: + if refType.Implements(pbMsgInterfaceType) { + pbMsg := reflect.New(refType.Elem()).Interface().(protoreflect.ProtoMessage) + return decls.NewObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true + } + return convertToExprType(refType.Elem()) + } + return nil, false +} + +func newNativeObject(adapter ref.TypeAdapter, val any, refValue reflect.Value) ref.Val { + valType, err := newNativeType(refValue.Type()) + if err != nil { + return types.NewErr(err.Error()) + } + return &nativeObj{ + TypeAdapter: adapter, + val: val, + valType: valType, + refValue: refValue, + } +} + +type nativeObj struct { + ref.TypeAdapter + val any + valType *nativeType + refValue reflect.Value +} + +// ConvertToNative implements the ref.Val interface method. +// +// CEL does not have a notion of pointers, so whether a field is a pointer or value +// is handled as part of this converstion step. +func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) { + if o.refValue.Type() == typeDesc { + return o.val, nil + } + if o.refValue.Kind() == reflect.Pointer && o.refValue.Type().Elem() == typeDesc { + return o.refValue.Elem().Interface(), nil + } + if typeDesc.Kind() == reflect.Pointer && o.refValue.Type() == typeDesc.Elem() { + ptr := reflect.New(typeDesc.Elem()) + ptr.Elem().Set(o.refValue) + return ptr.Interface(), nil + } + return nil, fmt.Errorf("type conversion error from '%v' to '%v'", o.Type(), typeDesc) +} + +// ConvertToType implements the ref.Val interface method. +func (o *nativeObj) ConvertToType(typeVal ref.Type) ref.Val { + switch typeVal { + case types.TypeType: + return o.valType + default: + if typeVal.TypeName() == o.valType.typeName { + return o + } + } + return types.NewErr("type conversion error from '%s' to '%s'", o.Type(), typeVal) +} + +// Equal implements the ref.Val interface method. +// +// Note, that in Golang a pointer to a value is not equal to the value it contains. +// In CEL pointers and values to which they point are equal. +func (o *nativeObj) Equal(other ref.Val) ref.Val { + otherNtv, ok := other.(*nativeObj) + if !ok { + return types.False + } + val := o.val + otherVal := otherNtv.val + refVal := o.refValue + otherRefVal := otherNtv.refValue + if refVal.Kind() != otherRefVal.Kind() { + if refVal.Kind() == reflect.Pointer { + val = refVal.Elem().Interface() + } else if otherRefVal.Kind() == reflect.Pointer { + otherVal = otherRefVal.Elem().Interface() + } + } + return types.Bool(reflect.DeepEqual(val, otherVal)) +} + +// IsZeroValue indicates whether the contained Golang value is a zero value. +// +// Golang largely follows proto3 semantics for zero values. +func (o *nativeObj) IsZeroValue() bool { + return reflect.Indirect(o.refValue).IsZero() +} + +// IsSet tests whether a field which is defined is set to a non-default value. +func (o *nativeObj) IsSet(field ref.Val) ref.Val { + refField, refErr := o.getReflectedField(field) + if refErr != nil { + return refErr + } + return types.Bool(!refField.IsZero()) +} + +// Get returns the value fo a field name. +func (o *nativeObj) Get(field ref.Val) ref.Val { + refField, refErr := o.getReflectedField(field) + if refErr != nil { + return refErr + } + return adaptFieldValue(o, refField) +} + +func (o *nativeObj) getReflectedField(field ref.Val) (reflect.Value, ref.Val) { + fieldName, ok := field.(types.String) + if !ok { + return reflect.Value{}, types.MaybeNoSuchOverloadErr(field) + } + fieldNameStr := string(fieldName) + refField, isDefined := o.valType.hasField(fieldNameStr) + if !isDefined { + return reflect.Value{}, types.NewErr("no such field: %s", fieldName) + } + refVal := reflect.Indirect(o.refValue) + return refVal.FieldByIndex(refField.Index), nil +} + +// Type implements the ref.Val interface method. +func (o *nativeObj) Type() ref.Type { + return o.valType +} + +// Value implements the ref.Val interface method. +func (o *nativeObj) Value() any { + return o.val +} + +func newNativeType(rawType reflect.Type) (*nativeType, error) { + refType := rawType + if refType.Kind() == reflect.Pointer { + refType = refType.Elem() + } + if !isValidObjectType(refType) { + return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType) + } + return &nativeType{ + typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), + refType: refType, + }, nil +} + +type nativeType struct { + typeName string + refType reflect.Type +} + +// ConvertToNative implements ref.Val.ConvertToNative. +func (t *nativeType) ConvertToNative(typeDesc reflect.Type) (any, error) { + return nil, fmt.Errorf("type conversion error for type to '%v'", typeDesc) +} + +// ConvertToType implements ref.Val.ConvertToType. +func (t *nativeType) ConvertToType(typeVal ref.Type) ref.Val { + switch typeVal { + case types.TypeType: + return types.TypeType + } + return types.NewErr("type conversion error from '%s' to '%s'", types.TypeType, typeVal) +} + +// Equal returns true of both type names are equal to each other. +func (t *nativeType) Equal(other ref.Val) ref.Val { + otherType, ok := other.(ref.Type) + return types.Bool(ok && t.TypeName() == otherType.TypeName()) +} + +// HasTrait implements the ref.Type interface method. +func (t *nativeType) HasTrait(trait int) bool { + return nativeObjTraitMask&trait == trait +} + +// String implements the strings.Stringer interface method. +func (t *nativeType) String() string { + return t.typeName +} + +// Type implements the ref.Val interface method. +func (t *nativeType) Type() ref.Type { + return types.TypeType +} + +// TypeName implements the ref.Type interface method. +func (t *nativeType) TypeName() string { + return t.typeName +} + +// Value implements the ref.Val interface method. +func (t *nativeType) Value() any { + return t.typeName +} + +// hasField returns whether a field name has a corresponding Golang reflect.StructField +func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) { + f, found := t.refType.FieldByName(fieldName) + if !found || !f.IsExported() || !isSupportedType(f.Type) { + return reflect.StructField{}, false + } + return f, true +} + +func adaptFieldValue(adapter ref.TypeAdapter, refField reflect.Value) ref.Val { + return adapter.NativeToValue(getFieldValue(adapter, refField)) +} + +func getFieldValue(adapter ref.TypeAdapter, refField reflect.Value) any { + if refField.IsZero() { + switch refField.Kind() { + case reflect.Array, reflect.Slice: + return types.NewDynamicList(adapter, []ref.Val{}) + case reflect.Map: + return types.NewDynamicMap(adapter, map[ref.Val]ref.Val{}) + case reflect.Struct: + if refField.Type() == timestampType { + return types.Timestamp{Time: time.Unix(0, 0)} + } + return reflect.New(refField.Type()).Elem().Interface() + case reflect.Pointer: + return reflect.New(refField.Type().Elem()).Interface() + } + } + return refField.Interface() +} + +func simplePkgAlias(pkgPath string) string { + paths := strings.Split(pkgPath, "/") + if len(paths) == 0 { + return "" + } + return paths[len(paths)-1] +} + +func isValidObjectType(refType reflect.Type) bool { + return refType.Kind() == reflect.Struct +} + +func isSupportedType(refType reflect.Type) bool { + switch refType.Kind() { + case reflect.Chan, reflect.Complex64, reflect.Complex128, reflect.Func, reflect.UnsafePointer, reflect.Uintptr: + return false + case reflect.Array, reflect.Slice: + return isSupportedType(refType.Elem()) + case reflect.Map: + return isSupportedType(refType.Key()) && isSupportedType(refType.Elem()) + } + return true +} + +var ( + pbMsgInterfaceType = reflect.TypeOf((*protoreflect.ProtoMessage)(nil)).Elem() + timestampType = reflect.TypeOf(time.Now()) + durationType = reflect.TypeOf(time.Nanosecond) +) diff --git a/ext/native_test.go b/ext/native_test.go new file mode 100644 index 00000000..346cd2d6 --- /dev/null +++ b/ext/native_test.go @@ -0,0 +1,597 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" + + "google.golang.org/protobuf/proto" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/pb" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/test" + + proto3pb "github.com/google/cel-go/test/proto3pb" +) + +func TestNativeTypes(t *testing.T) { + var nativeTests = []struct { + expr string + out any + in any + }{ + { + expr: `ext.TestAllTypes{ + NestedVal: ext.TestNestedType{NestedMapVal: {1: false}}, + BoolVal: true, + BytesVal: b'hello', + DurationVal: duration('5s'), + DoubleVal: 1.5, + FloatVal: 2.5, + Int32Val: 10, + Int64Val: 20, + StringVal: 'hello world', + TimestampVal: timestamp('2011-08-06T01:23:45Z'), + Uint32Val: 100u, + Uint64Val: 200u, + ListVal: [ + ext.TestNestedType{ + NestedListVal:['goodbye', 'cruel', 'world'], + NestedMapVal: {42: true}, + }, + ], + MapVal: {'map-key': ext.TestAllTypes{BoolVal: true}}, + }`, + out: &TestAllTypes{ + NestedVal: &TestNestedType{NestedMapVal: map[int64]bool{1: false}}, + BoolVal: true, + BytesVal: []byte("hello"), + DurationVal: time.Second * 5, + DoubleVal: 1.5, + FloatVal: 2.5, + Int32Val: 10, + Int64Val: 20, + StringVal: "hello world", + TimestampVal: mustParseTime(t, "2011-08-06T01:23:45Z"), + Uint32Val: uint32(100), + Uint64Val: uint64(200), + ListVal: []*TestNestedType{ + { + NestedListVal: []string{"goodbye", "cruel", "world"}, + NestedMapVal: map[int64]bool{42: true}, + }, + }, + MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}}, + }, + }, + { + expr: `ext.TestAllTypes{ + PbVal: test.TestAllTypes{single_int32: 123} + }.PbVal`, + out: &proto3pb.TestAllTypes{SingleInt32: 123}, + }, + { + expr: `ext.TestAllTypes{PbVal: test.TestAllTypes{}} == + ext.TestAllTypes{PbVal: test.TestAllTypes{single_bool: false}}`, + }, + {expr: `ext.TestNestedType{} == TestNestedType{}`}, + {expr: `ext.TestAllTypes{}.BoolVal != true`}, + {expr: `!has(ext.TestAllTypes{}.BoolVal) && !has(ext.TestAllTypes{}.NestedVal)`}, + {expr: `type(ext.TestAllTypes) == type`}, + {expr: `type(ext.TestAllTypes{}) == ext.TestAllTypes`}, + {expr: `type(ext.TestAllTypes{}) == ext.TestAllTypes`}, + {expr: `ext.TestAllTypes != test.TestAllTypes`}, + {expr: `ext.TestAllTypes{BoolVal: true} != dyn(test.TestAllTypes{single_bool: true})`}, + {expr: `ext.TestAllTypes{}.NestedVal == ext.TestNestedType{}`}, + {expr: `ext.TestNestedType{} == ext.TestAllTypes{}.NestedStructVal`}, + {expr: `ext.TestAllTypes{}.NestedStructVal == ext.TestNestedType{}`}, + {expr: `ext.TestAllTypes{}.ListVal.size() == 0`}, + {expr: `ext.TestAllTypes{}.MapVal.size() == 0`}, + {expr: `ext.TestAllTypes{}.TimestampVal == timestamp(0)`}, + {expr: `test.TestAllTypes{}.single_timestamp == timestamp(0)`}, + {expr: `[TestAllTypes{BoolVal: true}, TestAllTypes{BoolVal: false}].exists(t, t.BoolVal == true)`}, + { + expr: `tests.all(t, t.Int32Val > 17)`, + in: map[string]any{ + "tests": []*TestAllTypes{{Int32Val: 18}, {Int32Val: 19}, {Int32Val: 20}}, + }, + }, + } + env := testNativeEnv(t) + for i, tst := range nativeTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + var asts []*cel.Ast + pAst, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, pAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, cAst) + for _, ast := range asts { + prg, err := env.Program(ast) + if err != nil { + t.Fatal(err) + } + in := tc.in + if in == nil { + in = cel.NoVars() + } + out, _, err := prg.Eval(in) + if err != nil { + t.Fatal(err) + } + want := tc.out + if want == nil { + want = true + } + wantPB, isPB := want.(proto.Message) + if isPB && !pb.Equal(wantPB, out.Value().(proto.Message)) { + t.Errorf("got %v, wanted %v for expr: %s", out.Value(), want, tc.expr) + } + if !isPB && !reflect.DeepEqual(out.Value(), want) { + t.Errorf("got %v, wanted %v for expr: %s", out.Value(), want, tc.expr) + } + } + }) + } +} + +func TestNativeTypesStaticErrors(t *testing.T) { + var nativeTests = []struct { + expr string + err string + }{ + { + expr: `TestAllTypos{}`, + err: `ERROR: :1:13: undeclared reference to 'TestAllTypos' (in container 'ext') + | TestAllTypos{} + | ............^`, + }, + { + expr: `ext.TestAllTypes{bool_val: false}`, + err: `ERROR: :1:26: undefined field 'bool_val' + | ext.TestAllTypes{bool_val: false} + | .........................^`, + }, + { + expr: `ext.TestAllTypes{UnsupportedVal: null}`, + err: `ERROR: :1:32: undefined field 'UnsupportedVal' + | ext.TestAllTypes{UnsupportedVal: null} + | ...............................^`, + }, + { + expr: `ext.TestAllTypes{UnsupportedListVal: null}`, + err: `ERROR: :1:36: undefined field 'UnsupportedListVal' + | ext.TestAllTypes{UnsupportedListVal: null} + | ...................................^`, + }, + { + expr: `ext.TestAllTypes{UnsupportedMapVal: null}`, + err: `ERROR: :1:35: undefined field 'UnsupportedMapVal' + | ext.TestAllTypes{UnsupportedMapVal: null} + | ..................................^`, + }, + } + env := testNativeEnv(t) + for i, tst := range nativeTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + _, iss := env.Compile(tc.expr) + if iss.Err() == nil { + t.Fatalf("env.Compile(%v) succeeded, wanted error", tc.expr) + } + if !test.Compare(iss.Err().Error(), tc.err) { + t.Errorf("env.Compile(%v) got %v, wanted error %s", tc.expr, iss.Err(), tc.err) + } + }) + } +} + +func TestNativeTypesRuntimeErrors(t *testing.T) { + var nativeTests = []struct { + expr string + err string + }{ + { + expr: `TestAllTypos{}`, + err: `unknown type: TestAllTypos`, + }, + { + expr: `ext.TestAllTypes{bool_val: false}`, + err: `no such field: bool_val`, + }, + { + expr: `ext.TestAllTypes{UnsupportedVal: null}`, + err: `no such field: UnsupportedVal`, + }, + { + expr: `ext.TestAllTypes{UnsupportedListVal: null}`, + err: `no such field: UnsupportedListVal`, + }, + { + expr: `ext.TestAllTypes{UnsupportedMapVal: null}`, + err: `no such field: UnsupportedMapVal`, + }, + { + expr: `ext.TestAllTypes{privateVal: null}`, + err: `no such field: privateVal`, + }, + { + expr: `ext.TestAllTypes{}.UnsupportedMapVal`, + err: `no such field: UnsupportedMapVal`, + }, + { + expr: `ext.TestAllTypes{}.privateVal`, + err: `no such field: privateVal`, + }, + { + expr: `ext.TestAllTypes{BoolVal: 'false'}`, + err: `unsupported native conversion from string to 'bool'`, + }, + { + expr: `has(ext.TestAllTypes{}.BadFieldName)`, + err: `no such field: BadFieldName`, + }, + { + expr: `ext.TestAllTypes{}[42]`, + err: `no such overload`, + }, + { + expr: `ext.TestAllTypes{Int32Val: 9223372036854775807}`, + err: `integer overflow`, + }, + { + expr: `ext.TestAllTypes{Uint32Val: 9223372036854775807u}`, + err: `unsigned integer overflow`, + }, + } + env := testNativeEnv(t) + for i, tst := range nativeTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + ast, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + if !strings.Contains(err.Error(), tc.err) { + t.Fatal(err) + } + return + } + out, _, err := prg.Eval(cel.NoVars()) + if err == nil || !strings.Contains(err.Error(), tc.err) { + var got any = err + if err == nil { + got = out + } + t.Fatalf("prg.Eval() got %v, wanted error %v", got, tc.err) + } + }) + } +} + +func TestNativeTypesErrors(t *testing.T) { + envTests := []struct { + nativeType any + err string + }{ + { + nativeType: reflect.TypeOf(1), + err: "unsupported reflect.Type", + }, + { + nativeType: reflect.ValueOf(1), + err: "unsupported reflect.Type", + }, + { + nativeType: 1, + err: "must be reflect.Type or reflect.Value", + }, + } + for i, tst := range envTests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + _, err := cel.NewEnv(NativeTypes(tc.nativeType)) + if err == nil || !strings.Contains(err.Error(), tc.err) { + t.Errorf("cel.NewEnv(NativeTypes(%v)) got error %v, wanted %v", tc.nativeType, err, tc.err) + } + }) + } +} + +func TestNativeTypesConvertToNative(t *testing.T) { + env := testNativeEnv(t, NativeTypes(reflect.TypeOf(TestNestedType{}))) + adapter := env.TypeAdapter() + conversions := []struct { + in any + out any + err string + }{ + { + in: &TestAllTypes{BoolVal: true}, + out: &TestAllTypes{BoolVal: true}, + }, + { + in: TestAllTypes{BoolVal: true}, + out: &TestAllTypes{BoolVal: true}, + }, + { + in: &TestAllTypes{BoolVal: true}, + out: TestAllTypes{BoolVal: true}, + }, + { + in: nil, + out: types.NullValue, + }, + { + in: &TestAllTypes{BoolVal: true}, + out: &proto3pb.TestAllTypes{}, + err: "type conversion error", + }, + } + for _, c := range conversions { + inVal := adapter.NativeToValue(c.in) + if types.IsError(inVal) { + t.Fatalf("adapter.NativeToValue(%v) failed: %v", c.in, inVal) + } + out, err := inVal.ConvertToNative(reflect.TypeOf(c.out)) + if err != nil { + if c.err != "" { + if !strings.Contains(err.Error(), c.err) { + t.Fatalf("%v.ConvertToNative(%T) got %v, wanted error %v", c.in, c.out, err, c.err) + } + return + } + t.Fatalf("%v.ConvertToNative(%T) failed: %v", c.in, c.out, err) + } + if !reflect.DeepEqual(out, c.out) { + t.Errorf("%v.ConvertToNative(%T) got %v, wanted %v", c.in, c.out, out, c.out) + } + } +} + +func TestNativeTypesConvertToExprTypeErrors(t *testing.T) { + unsupportedTypes := []reflect.Type{ + reflect.TypeOf(make(map[string]chan string)), + reflect.TypeOf(make([]chan int, 0)), + reflect.TypeOf(make(map[chan int]bool, 0)), + } + for _, ut := range unsupportedTypes { + if _, converted := convertToExprType(ut); converted { + t.Errorf("convertToExprType(%v) succeeded when it should have failed", ut) + } + } +} + +func TestConvertToTypeErrors(t *testing.T) { + env := testNativeEnv(t, NativeTypes(reflect.TypeOf(TestNestedType{}))) + adapter := env.TypeAdapter() + conversions := []struct { + in any + out any + err string + }{ + { + in: &TestAllTypes{BoolVal: true}, + out: &TestAllTypes{BoolVal: true}, + }, + { + in: TestAllTypes{BoolVal: true}, + out: &TestAllTypes{BoolVal: true}, + }, + { + in: &TestAllTypes{BoolVal: true}, + out: TestAllTypes{BoolVal: true}, + }, + { + in: &TestAllTypes{BoolVal: true}, + out: &proto3pb.TestAllTypes{}, + err: "type conversion error", + }, + } + for _, c := range conversions { + inVal := adapter.NativeToValue(c.in) + outVal := adapter.NativeToValue(c.out) + if types.IsError(inVal) { + t.Fatalf("adapter.NativeToValue(%v) failed: %v", c.in, inVal) + } + if types.IsError(outVal) { + t.Fatalf("adapter.NativeToValue(%v) failed: %v", c.out, outVal) + } + conv := inVal.ConvertToType(outVal.Type()) + if c.err != "" { + if !types.IsError(conv) { + t.Fatalf("%v.ConvertToType(%v) got %v, wanted error %v", c.in, outVal.Type(), conv, c.err) + } + convErr := conv.(*types.Err) + if !strings.Contains(convErr.Error(), c.err) { + t.Fatalf("%v.ConvertToType(%v) got %v, wanted error %v", c.in, outVal.Type(), conv, c.err) + } + return + } + if conv != inVal { + t.Errorf("%v.ConvertToType(%v) got %v, wanted %v", c.in, outVal.Type(), conv, c.err) + } + conv = inVal.ConvertToType(types.TypeType) + if conv.Type() != types.TypeType || conv.(ref.Type) != inVal.Type() { + t.Errorf("%v.ConvertToType(Type) got %v, wanted %v", inVal, conv, inVal.Type()) + } + } +} + +func TestNativeTypesWithOptional(t *testing.T) { + var nativeTests = []struct { + expr string + }{ + {expr: `!optional.ofNonZeroValue(ext.TestAllTypes{}).hasValue()`}, + {expr: `!ext.TestAllTypes{}.?BoolVal.orValue(false)`}, + {expr: `!ext.TestAllTypes{}.?BoolVal.hasValue()`}, + {expr: `!ext.TestAllTypes{BoolVal: false}.?BoolVal.hasValue()`}, + {expr: `ext.TestAllTypes{BoolVal: true}.?BoolVal.hasValue()`}, + {expr: `ext.TestAllTypes{}.NestedVal.?NestedMapVal.orValue({}).size() == 0`}, + } + env := testNativeEnv(t, cel.OptionalTypes()) + for i, tst := range nativeTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + var asts []*cel.Ast + pAst, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, pAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, cAst) + for _, ast := range asts { + prg, err := env.Program(ast) + if err != nil { + t.Fatal(err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out.Value(), true) { + t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr) + } + } + }) + } +} + +func TestNativeTypeConvertToType(t *testing.T) { + nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + if err != nil { + t.Fatalf("newNativeType() failed: %v", err) + } + if nt.ConvertToType(types.TypeType) != types.TypeType { + t.Error("ConvertToType(Type) failed") + } + if !types.IsError(nt.ConvertToType(types.StringType)) { + t.Errorf("ConvertToType(String) got %v, wanted error", nt.ConvertToType(types.StringType)) + } +} + +func TestNativeTypeConvertToNative(t *testing.T) { + nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + if err != nil { + t.Fatalf("newNativeType() failed: %v", err) + } + out, err := nt.ConvertToNative(reflect.TypeOf(1)) + if err == nil { + t.Errorf("nt.ConvertToNative(1) produced %v, wanted error", out) + } +} + +func TestNativeTypeHasTrait(t *testing.T) { + nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + if err != nil { + t.Fatalf("newNativeType() failed: %v", err) + } + if !nt.HasTrait(traits.IndexerType) || !nt.HasTrait(traits.FieldTesterType) { + t.Error("nt.HasTrait() failed indicate support for presence test and field access.") + } +} + +func TestNativeTypeValue(t *testing.T) { + nt, err := newNativeType(reflect.TypeOf(&TestAllTypes{})) + if err != nil { + t.Fatalf("newNativeType() failed: %v", err) + } + if nt.Value() != nt.String() { + t.Errorf("nt.Value() got %v, wanted %v", nt.Value(), nt.String()) + } +} + +// testEnv initializes the test environment common to all tests. +func testNativeEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { + t.Helper() + envOpts := []cel.EnvOption{ + cel.Container("ext"), + cel.Abbrevs("google.expr.proto3.test"), + cel.Types(&proto3pb.TestAllTypes{}), + cel.Variable("tests", cel.ListType(cel.ObjectType("ext.TestAllTypes"))), + } + envOpts = append(envOpts, opts...) + envOpts = append(envOpts, + NativeTypes( + reflect.TypeOf(&TestNestedType{}), + reflect.ValueOf(&TestAllTypes{}), + ), + ) + env, err := cel.NewEnv(envOpts...) + if err != nil { + t.Fatalf("cel.NewEnv(NativeTypes()) failed: %v", err) + } + return env +} + +func mustParseTime(t *testing.T, timestamp string) time.Time { + t.Helper() + out, err := time.Parse(time.RFC3339, timestamp) + if err != nil { + t.Fatalf("time.Parse(%q) failed: %v", timestamp, err) + } + return out +} + +type TestNestedType struct { + NestedListVal []string + NestedMapVal map[int64]bool +} + +type TestAllTypes struct { + NestedVal *TestNestedType + NestedStructVal TestNestedType + BoolVal bool + BytesVal []byte + DurationVal time.Duration + DoubleVal float64 + FloatVal float32 + Int32Val int32 + Int64Val int32 + StringVal string + TimestampVal time.Time + Uint32Val uint32 + Uint64Val uint64 + ListVal []*TestNestedType + MapVal map[string]TestAllTypes + PbVal *proto3pb.TestAllTypes + + // channel types are not supported + UnsupportedVal chan string + UnsupportedListVal []chan string + UnsupportedMapVal map[int]chan string + + // unexported types can be found but not set or accessed + privateVal map[string]string +} diff --git a/interpreter/planner.go b/interpreter/planner.go index bd916df7..9c6bbb62 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -566,9 +566,9 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) { // planCreateObj generates an object construction Interpretable. func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) { obj := expr.GetStructExpr() - typeName, defined := p.resolveTypeName(obj.MessageName) + typeName, defined := p.resolveTypeName(obj.GetMessageName()) if !defined { - return nil, fmt.Errorf("unknown type: %s", typeName) + return nil, fmt.Errorf("unknown type: %s", obj.GetMessageName()) } entries := obj.GetEntries() optionals := make([]bool, len(entries))