Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for struct tag name overrides in native types #941

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 144 additions & 22 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package ext

import (
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -77,12 +78,45 @@ var (
// same advice holds if you are using custom type adapters and type providers. The native type
// provider composes over whichever type adapter and provider is configured in the cel.Env at
// the time that it is invoked.
func NativeTypes(refTypes ...any) cel.EnvOption {
//
// There is also the possibility to rename the fields of native structs by setting the `cel` tag
// for fields you want to override. In order to enable this feature, pass in the `EnableStructTag`
// option. Here is an example to see it in action:
//
// ```go
// package identity
//
// type Account struct {
// ID int
// OwnerName string `cel:"owner"`
// }
//
// ```
//
// The `OwnerName` field is now accessible in CEL via `owner`, e.g. `identity.Account{owner: 'bob'}`.
// In case there are duplicated field names in the struct, an error will be returned.
func NativeTypes(args ...any) cel.EnvOption {
return func(env *cel.Env) (*cel.Env, error) {
tp, err := newNativeTypeProvider(env.CELTypeAdapter(), env.CELTypeProvider(), refTypes...)
nativeTypes := make([]any, 0, len(args))
tpOptions := nativeTypeOptions{}

for _, v := range args {
switch v := v.(type) {
case NativeTypesOption:
err := v(&tpOptions)
if err != nil {
return nil, err
}
default:
nativeTypes = append(nativeTypes, v)
}
}

tp, err := newNativeTypeProvider(tpOptions, env.CELTypeAdapter(), env.CELTypeProvider(), nativeTypes...)
if err != nil {
return nil, err
}

env, err = cel.CustomTypeAdapter(tp)(env)
if err != nil {
return nil, err
Expand All @@ -91,20 +125,37 @@ func NativeTypes(refTypes ...any) cel.EnvOption {
}
}

func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) {
// NativeTypesOption is a functional interface for configuring handling of native types.
type NativeTypesOption func(*nativeTypeOptions) error
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved

type nativeTypeOptions struct {
// parseStructTags controls if CEL should support struct field renames, by parsing
// struct field tags.
parseStructTags bool
}

// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
func ParseStructTags(enabled bool) NativeTypesOption {
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
return func(ntp *nativeTypeOptions) error {
ntp.parseStructTags = true
return nil
}
}

func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) {
nativeTypes := make(map[string]*nativeType, len(refTypes))
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
result, err := newNativeTypes(rt)
result, err := newNativeTypes(tpOptions.parseStructTags, rt)
if err != nil {
return nil, err
}
for idx := range result {
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
result, err := newNativeTypes(rt.Type())
result, err := newNativeTypes(tpOptions.parseStructTags, rt.Type())
if err != nil {
return nil, err
}
Expand All @@ -119,13 +170,15 @@ func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTy
nativeTypes: nativeTypes,
baseAdapter: adapter,
baseProvider: provider,
options: tpOptions,
}, nil
}

type nativeTypeProvider struct {
nativeTypes map[string]*nativeType
baseAdapter types.Adapter
baseProvider types.Provider
options nativeTypeOptions
}

// EnumValue proxies to the types.Provider configured at the times the NativeTypes
Expand Down Expand Up @@ -155,6 +208,18 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}

TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
func toFieldName(parseStructTag bool, f reflect.StructField) string {
if !parseStructTag {
return f.Name
}

if name, found := f.Tag.Lookup("cel"); found {
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
return name
}

return f.Name
}

// FindStructFieldNames looks up the type definition first from the native types, then from
// the backing provider type set. If found, a set of field names corresponding to the type
// will be returned.
Expand All @@ -163,7 +228,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = t.refType.Field(i).Name
fields[i] = toFieldName(tp.options.parseStructTags, t.refType.Field(i))
}
return fields, true
}
Expand All @@ -173,6 +238,22 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
return tp.baseProvider.FindStructFieldNames(typeName)
}

// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by
// searching for a matching field tag value or field name.
func valueFieldByName(parseStructTags bool, target reflect.Value, fieldName string) reflect.Value {
if !parseStructTags {
return target.FieldByName(fieldName)
}

for i := 0; i < target.Type().NumField(); i++ {
f := target.Type().Field(i)
if toFieldName(parseStructTags, f) == fieldName {
return target.FieldByIndex(f.Index)
}
}
return reflect.Value{}
}

// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
Expand All @@ -192,12 +273,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*
Type: celType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
return getFieldValue(tp, refField), nil
},
}, true
Expand Down Expand Up @@ -259,7 +340,7 @@ func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
time.Time:
return tp.baseAdapter.NativeToValue(val)
default:
return newNativeObject(tp, val, rawVal)
return tp.newNativeObject(val, rawVal)
}
default:
return tp.baseAdapter.NativeToValue(val)
Expand Down Expand Up @@ -319,13 +400,13 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
return nil, false
}

func newNativeObject(adapter types.Adapter, val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(refValue.Type())
func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(tp.options.parseStructTags, refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
return &nativeObj{
Adapter: adapter,
Adapter: tp,
val: val,
valType: valType,
refValue: refValue,
Expand Down Expand Up @@ -372,12 +453,13 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldName := toFieldName(o.valType.parseStructTags, fieldType)
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
return nil, err
}
fields[fieldType.Name] = fieldJSONVal.(*structpb.Value)
fields[fieldName] = fieldJSONVal.(*structpb.Value)
}
return &structpb.Struct{Fields: fields}, nil
}
Expand Down Expand Up @@ -469,8 +551,8 @@ func (o *nativeObj) Value() any {
return o.val
}

func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(rawType)
func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(parseStructTags, rawType)
if err != nil {
return nil, err
}
Expand All @@ -489,7 +571,7 @@ func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(t)
nt, ntErr := newNativeType(parseStructTags, t)
if ntErr != nil {
err = ntErr
return
Expand All @@ -505,23 +587,46 @@ func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
return result, err
}

func newNativeType(rawType reflect.Type) (*nativeType, error) {
var (
errDuplicatedFieldName = errors.New("field name already exists in struct")
)

func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
refType = refType.Elem()
}
if !isValidObjectType(refType) {
return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType)
}

TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
if parseStructTags {
fieldNames := make(map[string]struct{})

for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(parseStructTags, field)

if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
} else {
fieldNames[fieldName] = struct{}{}
}
}
}

return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
parseStructTags: parseStructTags,
}, nil
}

type nativeType struct {
typeName string
refType reflect.Type
typeName string
refType reflect.Type
parseStructTags bool
}

// ConvertToNative implements ref.Val.ConvertToNative.
Expand Down Expand Up @@ -569,9 +674,26 @@ func (t *nativeType) Value() any {
return t.typeName
}

// fieldByName returns the corresponding reflect.StructField for the give name either by matching
// field tag or field name.
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
if !t.parseStructTags {
return t.refType.FieldByName(fieldName)
}

for i := 0; i < t.refType.NumField(); i++ {
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
f := t.refType.Field(i)
if toFieldName(t.parseStructTags, f) == fieldName {
return f, true
}
}

return reflect.StructField{}, false
}

// hasField returns whether a field name has a corresponding Golang reflect.StructField
func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) {
f, found := t.refType.FieldByName(fieldName)
f, found := t.fieldByName(fieldName)
if !found || !f.IsExported() || !isSupportedType(f.Type) {
return reflect.StructField{}, false
}
Expand Down
Loading