Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for struct tag name overrides in native types #941

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package ext

import (
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -155,6 +156,14 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}

TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
func toFieldName(f reflect.StructField) string {
if name, found := f.Tag.Lookup("cel"); found {
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
return name
}

return f.Name
}

// FindStructFieldNames looks up the type definition first from the native types, then from
// the backing provider type set. If found, a set of field names corresponding to the type
// will be returned.
Expand All @@ -163,7 +172,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = t.refType.Field(i).Name
fields[i] = toFieldName(t.refType.Field(i))
}
return fields, true
}
Expand All @@ -173,6 +182,18 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
return tp.baseProvider.FindStructFieldNames(typeName)
}

// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by
// searching for a matching field tag value or field name.
func valueFieldByName(target reflect.Value, fieldName string) reflect.Value {
for i := 0; i < target.Type().NumField(); i++ {
f := target.Type().Field(i)
if toFieldName(f) == fieldName {
return target.FieldByIndex(f.Index)
}
}
return reflect.Value{}
}

// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
Expand All @@ -192,12 +213,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*
Type: celType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
refField := valueFieldByName(refVal, fieldName)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
refField := valueFieldByName(refVal, fieldName)
return getFieldValue(tp, refField), nil
},
}, true
Expand Down Expand Up @@ -372,12 +393,13 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldName := toFieldName(fieldType)
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
return nil, err
}
fields[fieldType.Name] = fieldJSONVal.(*structpb.Value)
fields[fieldName] = fieldJSONVal.(*structpb.Value)
}
return &structpb.Struct{Fields: fields}, nil
}
Expand Down Expand Up @@ -505,6 +527,10 @@ func newNativeTypes(rawType reflect.Type) ([]*nativeType, error) {
return result, err
}

var (
errDuplicatedFieldName = errors.New("field name already exists in struct")
)

func newNativeType(rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
Expand All @@ -513,6 +539,18 @@ func newNativeType(rawType reflect.Type) (*nativeType, error) {
if !isValidObjectType(refType) {
return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType)
}
fieldNames := make(map[string]struct{})

TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(field)

if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
} else {
fieldNames[fieldName] = struct{}{}
}
}
return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
Expand Down Expand Up @@ -569,9 +607,22 @@ func (t *nativeType) Value() any {
return t.typeName
}

// fieldByName returns the corresponding reflect.StructField for the give name either by matching
// field tag or field name.
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
for i := 0; i < t.refType.NumField(); i++ {
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
f := t.refType.Field(i)
if toFieldName(f) == fieldName {
return f, true
}
}

return reflect.StructField{}, false
}

// hasField returns whether a field name has a corresponding Golang reflect.StructField
func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) {
f, found := t.refType.FieldByName(fieldName)
f, found := t.fieldByName(fieldName)
if !found || !f.IsExported() || !isSupportedType(f.Type) {
return reflect.StructField{}, false
}
Expand Down
45 changes: 36 additions & 9 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package ext

import (
"errors"
"fmt"
"reflect"
"sort"
Expand Down Expand Up @@ -60,17 +61,20 @@ func TestNativeTypes(t *testing.T) {
ext.TestNestedType{
NestedListVal:['goodbye', 'cruel', 'world'],
NestedMapVal: {42: true},
custom_name: 'name',
},
],
ArrayVal: [
ext.TestNestedType{
NestedListVal:['goodbye', 'cruel', 'world'],
NestedMapVal: {42: true},
custom_name: 'name',
},
],
MapVal: {'map-key': ext.TestAllTypes{BoolVal: true}},
CustomSliceVal: [ext.TestNestedSliceType{Value: 'none'}],
CustomMapVal: {'even': ext.TestMapVal{Value: 'more'}},
custom_name: 'name',
}`,
out: &TestAllTypes{
NestedVal: &TestNestedType{NestedMapVal: map[int64]bool{1: false}},
Expand All @@ -87,17 +91,20 @@ func TestNativeTypes(t *testing.T) {
Uint64Val: uint64(200),
ListVal: []*TestNestedType{
{
NestedListVal: []string{"goodbye", "cruel", "world"},
NestedMapVal: map[int64]bool{42: true},
NestedListVal: []string{"goodbye", "cruel", "world"},
NestedMapVal: map[int64]bool{42: true},
NestedCustomName: "name",
},
},
ArrayVal: [1]*TestNestedType{{
NestedListVal: []string{"goodbye", "cruel", "world"},
NestedMapVal: map[int64]bool{42: true},
NestedListVal: []string{"goodbye", "cruel", "world"},
NestedMapVal: map[int64]bool{42: true},
NestedCustomName: "name",
}},
MapVal: map[string]TestAllTypes{"map-key": {BoolVal: true}},
CustomSliceVal: []TestNestedSliceType{{Value: "none"}},
CustomMapVal: map[string]TestMapVal{"even": {Value: "more"}},
CustomName: "name",
},
},
{
Expand Down Expand Up @@ -126,6 +133,7 @@ func TestNativeTypes(t *testing.T) {
{expr: `ext.TestAllTypes{}.TimestampVal == timestamp(0)`},
{expr: `test.TestAllTypes{}.single_timestamp == timestamp(0)`},
{expr: `[TestAllTypes{BoolVal: true}, TestAllTypes{BoolVal: false}].exists(t, t.BoolVal == true)`},
{expr: `[TestAllTypes{custom_name: 'Alice'}, TestAllTypes{custom_name: 'Bob'}].exists(t, t.custom_name == 'Alice')`},
{
expr: `tests.all(t, t.Int32Val > 17)`,
in: map[string]any{
Expand Down Expand Up @@ -186,7 +194,7 @@ func TestNativeFindStructFieldNames(t *testing.T) {
}{
{
typeName: "ext.TestNestedType",
fields: []string{"NestedListVal", "NestedMapVal"},
fields: []string{"NestedListVal", "NestedMapVal", "custom_name"},
},
{
typeName: "google.expr.proto3.test.TestAllTypes.NestedMessage",
Expand Down Expand Up @@ -287,7 +295,8 @@ func TestNativeTypesJsonSerialization(t *testing.T) {
NestedVal: TestNestedType{
NestedListVal: ["first", "second"],
},
StringVal: "string"
StringVal: "string",
custom_name: "name",
}`,
out: `{
"BoolVal": true,
Expand All @@ -307,7 +316,8 @@ func TestNativeTypesJsonSerialization(t *testing.T) {
"second"
]
},
"StringVal": "string"
"StringVal": "string",
"custom_name": "name"
}`,
},
}
Expand Down Expand Up @@ -647,6 +657,16 @@ func TestNativeTypeValue(t *testing.T) {
}
}

func TestNativeStructWithMultileSameFieldNames(t *testing.T) {
_, err := newNativeType(reflect.TypeOf(TestStructWithMultipleSameNames{}))
if err == nil {
t.Fatal("newNativeType() did not fail as expected")
}
if !errors.Is(err, errDuplicatedFieldName) {
t.Fatalf("newNativeType() exepected duplicated field name error, but got: %v", err)
}
}

// testEnv initializes the test environment common to all tests.
func testNativeEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
t.Helper()
Expand Down Expand Up @@ -678,9 +698,15 @@ func mustParseTime(t *testing.T, timestamp string) time.Time {
return out
}

type TestStructWithMultipleSameNames struct {
Name string
custom_name string `cel:"Name"`
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
}

type TestNestedType struct {
NestedListVal []string
NestedMapVal map[int64]bool
NestedListVal []string
NestedMapVal map[int64]bool
NestedCustomName string `cel:"custom_name"`
}

type TestAllTypes struct {
Expand All @@ -703,6 +729,7 @@ type TestAllTypes struct {
PbVal *proto3pb.TestAllTypes
CustomSliceVal []TestNestedSliceType
CustomMapVal map[string]TestMapVal
CustomName string `cel:"custom_name"`

// channel types are not supported
UnsupportedVal chan string
Expand Down