Skip to content

Commit

Permalink
Support assignment of a wrapper field to null (#604)
Browse files Browse the repository at this point in the history
* Support for runtime assignment of 'null' to a wrapper field where the 'null' value indicates the wrapper will remain unset.
* Additional test for roundtripping unset wrappers
  • Loading branch information
TristonianJones authored Nov 7, 2022
1 parent 10141a6 commit 051835c
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 32 deletions.
1 change: 1 addition & 0 deletions common/types/json_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ var (
jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonListValueType = reflect.TypeOf(&structpb.ListValue{})
jsonStructType = reflect.TypeOf(&structpb.Struct{})
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
)
14 changes: 13 additions & 1 deletion common/types/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,21 @@ var (
NullType = NewTypeValue("null_type")
// NullValue singleton.
NullValue = Null(structpb.NullValue_NULL_VALUE)

// golang reflect type for Null values.
nullReflectType = reflect.TypeOf(NullValue)
)

// ConvertToNative implements ref.Val.ConvertToNative.
func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Int32:
return reflect.ValueOf(n).Convert(typeDesc).Interface(), nil
switch typeDesc {
case jsonNullType:
return structpb.NullValue_NULL_VALUE, nil
case nullReflectType:
return n, nil
}
case reflect.Ptr:
switch typeDesc {
case anyValueType:
Expand All @@ -53,6 +61,10 @@ func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) {
return anypb.New(pb.(proto.Message))
case jsonValueType:
return structpb.NewNullValue(), nil
case boolWrapperType, byteWrapperType, doubleWrapperType, floatWrapperType,
int32WrapperType, int64WrapperType, stringWrapperType, uint32WrapperType,
uint64WrapperType:
return nil, nil
}
case reflect.Interface:
nv := n.Value()
Expand Down
91 changes: 63 additions & 28 deletions common/types/null_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package types

import (
"errors"
"fmt"
"reflect"
"testing"

Expand All @@ -25,36 +27,60 @@ import (
)

func TestNullConvertToNative(t *testing.T) {
expected := structpb.NewNullValue()
// Json Value
val, err := NullValue.ConvertToNative(jsonValueType)
if err != nil {
t.Error("Fail to convert Null to jsonValueType")
}
if !proto.Equal(expected, val.(proto.Message)) {
t.Errorf("Messages were not equal, got '%v'", val)
}

// google.protobuf.Any
val, err = NullValue.ConvertToNative(anyValueType)
if err != nil {
t.Fatalf("NullValue.ConvertToNative(%v) failed: %v", anyValueType, err)
}
data, err := val.(*anypb.Any).UnmarshalNew()
if err != nil {
t.Fatalf("val.UnmarshalNew() failed: %v", err)
}
if !proto.Equal(expected, data) {
t.Errorf("Messages were not equal, got '%v'", data)
tests := []struct {
goType reflect.Type
out any
err error
}{
{
goType: jsonValueType,
out: structpb.NewNullValue(),
},
{
goType: jsonNullType,
out: structpb.NullValue_NULL_VALUE,
},
{
goType: anyValueType,
out: testPackAny(t, structpb.NewNullValue()),
},
{
goType: reflect.TypeOf(NullValue),
out: NullValue,
},
{goType: boolWrapperType},
{goType: byteWrapperType},
{goType: doubleWrapperType},
{goType: floatWrapperType},
{goType: int32WrapperType},
{goType: int64WrapperType},
{goType: stringWrapperType},
{goType: uint32WrapperType},
{goType: uint64WrapperType},
{
goType: reflect.TypeOf(1),
err: errors.New("type conversion error from 'null_type' to 'int'"),
},
}

// NullValue
val, err = NullValue.ConvertToNative(reflect.TypeOf(structpb.NullValue_NULL_VALUE))
if err != nil {
t.Error("Fail to convert Null to strcutpb.NullValue")
}
if val != structpb.NullValue_NULL_VALUE {
t.Errorf("Messages were not equal, got '%v'", val)
for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
out, err := NullValue.ConvertToNative(tc.goType)
if err != nil {
if tc.err == nil {
t.Fatalf("NullValue.ConvertToType(%v) failed: %v", tc.goType, err)
}
if tc.err.Error() != err.Error() {
t.Errorf("NullValue.ConvertToType(%v) got error %v, wanted error %v", tc.goType, err, tc.err)
}
return
}
pbMsg, isPB := out.(proto.Message)
if (isPB && !proto.Equal(pbMsg, tc.out.(proto.Message))) || (!isPB && out != tc.out) {
t.Errorf("NullValue.ConvertToNative(%v) got %v, wanted %v", tc.goType, pbMsg, tc.out)
}
})
}
}

Expand Down Expand Up @@ -94,3 +120,12 @@ func TestNullValue(t *testing.T) {
t.Error("NullValue gets incorrect value.")
}
}

func testPackAny(t *testing.T, val proto.Message) *anypb.Any {
t.Helper()
out, err := anypb.New(val)
if err != nil {
t.Fatalf("anypb.New(%v) failed: %v", val, err)
}
return out
}
9 changes: 9 additions & 0 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,9 @@ func msgSetField(target protoreflect.Message, field *pb.FieldDescription, val re
if err != nil {
return fieldTypeConversionError(field, err)
}
if v == nil {
return nil
}
switch pv := v.(type) {
case proto.Message:
v = pv.ProtoReflect()
Expand All @@ -496,6 +499,9 @@ func msgSetListField(target protoreflect.List, listField *pb.FieldDescription, l
if err != nil {
return fieldTypeConversionError(listField, err)
}
if elemVal == nil {
continue
}
switch ev := elemVal.(type) {
case proto.Message:
elemVal = ev.ProtoReflect()
Expand All @@ -520,6 +526,9 @@ func msgSetMapField(target protoreflect.Map, mapField *pb.FieldDescription, mapV
if err != nil {
return fieldTypeConversionError(mapField, err)
}
if v == nil {
continue
}
switch pv := v.(type) {
case proto.Message:
v = pv.ProtoReflect()
Expand Down
10 changes: 7 additions & 3 deletions common/types/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ func TestTypeRegistryNewValue_WrapperFields(t *testing.T) {
"google.expr.proto3.test.TestAllTypes",
map[string]ref.Val{
"single_int32_wrapper": Int(123),
"single_int64_wrapper": NullValue,
})
if IsError(exp) {
t.Fatalf("reg.NewValue() creation failed: %v", exp)
Expand All @@ -156,9 +157,12 @@ func TestTypeRegistryNewValue_WrapperFields(t *testing.T) {
if err != nil {
t.Fatalf("ConvertToNative() failed: %v", err)
}
ce := e.(*proto3pb.TestAllTypes)
if ce.GetSingleInt32Wrapper().GetValue() != int32(123) {
t.Errorf("single_int32_wrapper value %v not set to 123", ce)
out := e.(*proto3pb.TestAllTypes)
want := &proto3pb.TestAllTypes{
SingleInt32Wrapper: wrapperspb.Int32(123),
}
if !proto.Equal(out, want) {
t.Errorf("reg.NewValue() got %v, wanted %v", out, want)
}
}

Expand Down
49 changes: 49 additions & 0 deletions interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,55 @@ var (
RepeatedInt32: []int32{0, 2},
},
},
{
name: "literal_pb_wrapper_assign",
container: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes{}},
expr: `TestAllTypes{
single_int64_wrapper: 10,
single_int32_wrapper: TestAllTypes{}.single_int32_wrapper,
}`,
out: &proto3pb.TestAllTypes{
SingleInt64Wrapper: wrapperspb.Int64(10),
},
},
{
name: "literal_pb_wrapper_assign_roundtrip",
container: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes{}},
expr: `TestAllTypes{
single_int32_wrapper: TestAllTypes{}.single_int32_wrapper,
}.single_int32_wrapper == null`,
out: true,
},
{
name: "literal_pb_list_assign_null_wrapper",
container: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes{}},
expr: `TestAllTypes{
repeated_int32: [123, 456, TestAllTypes{}.single_int32_wrapper],
}`,
err: "field type conversion error",
},
{
name: "literal_pb_map_assign_null_entry_value",
container: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes{}},
expr: `TestAllTypes{
map_string_string: {
'hello': 'world',
'goodbye': TestAllTypes{}.single_string_wrapper,
},
}`,
err: "field type conversion error",
},
{
name: "unset_wrapper_access",
container: "google.expr.proto3.test",
types: []proto.Message{&proto3pb.TestAllTypes{}},
expr: `TestAllTypes{}.single_string_wrapper`,
out: types.NullValue,
},
{
name: "timestamp_eq_timestamp",
expr: `timestamp(0) == timestamp(0)`,
Expand Down

0 comments on commit 051835c

Please sign in to comment.