diff --git a/common/types/json_value.go b/common/types/json_value.go index cd63b519..13a4efe7 100644 --- a/common/types/json_value.go +++ b/common/types/json_value.go @@ -25,4 +25,5 @@ var ( jsonValueType = reflect.TypeOf(&structpb.Value{}) jsonListValueType = reflect.TypeOf(&structpb.ListValue{}) jsonStructType = reflect.TypeOf(&structpb.Struct{}) + jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE) ) diff --git a/common/types/null.go b/common/types/null.go index cce6b0c0..38927a11 100644 --- a/common/types/null.go +++ b/common/types/null.go @@ -34,13 +34,21 @@ var ( NullType = NewTypeValue("null_type") // NullValue singleton. NullValue = Null(structpb.NullValue_NULL_VALUE) + + // golang reflect type for Null values. + nullReflectType = reflect.TypeOf(NullValue) ) // ConvertToNative implements ref.Val.ConvertToNative. func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) { switch typeDesc.Kind() { case reflect.Int32: - return reflect.ValueOf(n).Convert(typeDesc).Interface(), nil + switch typeDesc { + case jsonNullType: + return structpb.NullValue_NULL_VALUE, nil + case nullReflectType: + return n, nil + } case reflect.Ptr: switch typeDesc { case anyValueType: @@ -53,6 +61,10 @@ func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) { return anypb.New(pb.(proto.Message)) case jsonValueType: return structpb.NewNullValue(), nil + case boolWrapperType, byteWrapperType, doubleWrapperType, floatWrapperType, + int32WrapperType, int64WrapperType, stringWrapperType, uint32WrapperType, + uint64WrapperType: + return nil, nil } case reflect.Interface: nv := n.Value() diff --git a/common/types/null_test.go b/common/types/null_test.go index fb0fa316..3ff92ba5 100644 --- a/common/types/null_test.go +++ b/common/types/null_test.go @@ -15,6 +15,8 @@ package types import ( + "errors" + "fmt" "reflect" "testing" @@ -25,36 +27,60 @@ import ( ) func TestNullConvertToNative(t *testing.T) { - expected := structpb.NewNullValue() - // Json Value - val, err := NullValue.ConvertToNative(jsonValueType) - if err != nil { - t.Error("Fail to convert Null to jsonValueType") - } - if !proto.Equal(expected, val.(proto.Message)) { - t.Errorf("Messages were not equal, got '%v'", val) - } - - // google.protobuf.Any - val, err = NullValue.ConvertToNative(anyValueType) - if err != nil { - t.Fatalf("NullValue.ConvertToNative(%v) failed: %v", anyValueType, err) - } - data, err := val.(*anypb.Any).UnmarshalNew() - if err != nil { - t.Fatalf("val.UnmarshalNew() failed: %v", err) - } - if !proto.Equal(expected, data) { - t.Errorf("Messages were not equal, got '%v'", data) + tests := []struct { + goType reflect.Type + out any + err error + }{ + { + goType: jsonValueType, + out: structpb.NewNullValue(), + }, + { + goType: jsonNullType, + out: structpb.NullValue_NULL_VALUE, + }, + { + goType: anyValueType, + out: testPackAny(t, structpb.NewNullValue()), + }, + { + goType: reflect.TypeOf(NullValue), + out: NullValue, + }, + {goType: boolWrapperType}, + {goType: byteWrapperType}, + {goType: doubleWrapperType}, + {goType: floatWrapperType}, + {goType: int32WrapperType}, + {goType: int64WrapperType}, + {goType: stringWrapperType}, + {goType: uint32WrapperType}, + {goType: uint64WrapperType}, + { + goType: reflect.TypeOf(1), + err: errors.New("type conversion error from 'null_type' to 'int'"), + }, } - // NullValue - val, err = NullValue.ConvertToNative(reflect.TypeOf(structpb.NullValue_NULL_VALUE)) - if err != nil { - t.Error("Fail to convert Null to strcutpb.NullValue") - } - if val != structpb.NullValue_NULL_VALUE { - t.Errorf("Messages were not equal, got '%v'", val) + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + out, err := NullValue.ConvertToNative(tc.goType) + if err != nil { + if tc.err == nil { + t.Fatalf("NullValue.ConvertToType(%v) failed: %v", tc.goType, err) + } + if tc.err.Error() != err.Error() { + t.Errorf("NullValue.ConvertToType(%v) got error %v, wanted error %v", tc.goType, err, tc.err) + } + return + } + pbMsg, isPB := out.(proto.Message) + if (isPB && !proto.Equal(pbMsg, tc.out.(proto.Message))) || (!isPB && out != tc.out) { + t.Errorf("NullValue.ConvertToNative(%v) got %v, wanted %v", tc.goType, pbMsg, tc.out) + } + }) } } @@ -94,3 +120,12 @@ func TestNullValue(t *testing.T) { t.Error("NullValue gets incorrect value.") } } + +func testPackAny(t *testing.T, val proto.Message) *anypb.Any { + t.Helper() + out, err := anypb.New(val) + if err != nil { + t.Fatalf("anypb.New(%v) failed: %v", val, err) + } + return out +} diff --git a/common/types/provider.go b/common/types/provider.go index c8d68c12..e66951f5 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -480,6 +480,9 @@ func msgSetField(target protoreflect.Message, field *pb.FieldDescription, val re if err != nil { return fieldTypeConversionError(field, err) } + if v == nil { + return nil + } switch pv := v.(type) { case proto.Message: v = pv.ProtoReflect() @@ -496,6 +499,9 @@ func msgSetListField(target protoreflect.List, listField *pb.FieldDescription, l if err != nil { return fieldTypeConversionError(listField, err) } + if elemVal == nil { + continue + } switch ev := elemVal.(type) { case proto.Message: elemVal = ev.ProtoReflect() @@ -520,6 +526,9 @@ func msgSetMapField(target protoreflect.Map, mapField *pb.FieldDescription, mapV if err != nil { return fieldTypeConversionError(mapField, err) } + if v == nil { + continue + } switch pv := v.(type) { case proto.Message: v = pv.ProtoReflect() diff --git a/common/types/provider_test.go b/common/types/provider_test.go index 7f596e00..26fc8b64 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -148,6 +148,7 @@ func TestTypeRegistryNewValue_WrapperFields(t *testing.T) { "google.expr.proto3.test.TestAllTypes", map[string]ref.Val{ "single_int32_wrapper": Int(123), + "single_int64_wrapper": NullValue, }) if IsError(exp) { t.Fatalf("reg.NewValue() creation failed: %v", exp) @@ -156,9 +157,12 @@ func TestTypeRegistryNewValue_WrapperFields(t *testing.T) { if err != nil { t.Fatalf("ConvertToNative() failed: %v", err) } - ce := e.(*proto3pb.TestAllTypes) - if ce.GetSingleInt32Wrapper().GetValue() != int32(123) { - t.Errorf("single_int32_wrapper value %v not set to 123", ce) + out := e.(*proto3pb.TestAllTypes) + want := &proto3pb.TestAllTypes{ + SingleInt32Wrapper: wrapperspb.Int32(123), + } + if !proto.Equal(out, want) { + t.Errorf("reg.NewValue() got %v, wanted %v", out, want) } } diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 0b6546ec..6473b4da 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -705,6 +705,55 @@ var ( RepeatedInt32: []int32{0, 2}, }, }, + { + name: "literal_pb_wrapper_assign", + container: "google.expr.proto3.test", + types: []proto.Message{&proto3pb.TestAllTypes{}}, + expr: `TestAllTypes{ + single_int64_wrapper: 10, + single_int32_wrapper: TestAllTypes{}.single_int32_wrapper, + }`, + out: &proto3pb.TestAllTypes{ + SingleInt64Wrapper: wrapperspb.Int64(10), + }, + }, + { + name: "literal_pb_wrapper_assign_roundtrip", + container: "google.expr.proto3.test", + types: []proto.Message{&proto3pb.TestAllTypes{}}, + expr: `TestAllTypes{ + single_int32_wrapper: TestAllTypes{}.single_int32_wrapper, + }.single_int32_wrapper == null`, + out: true, + }, + { + name: "literal_pb_list_assign_null_wrapper", + container: "google.expr.proto3.test", + types: []proto.Message{&proto3pb.TestAllTypes{}}, + expr: `TestAllTypes{ + repeated_int32: [123, 456, TestAllTypes{}.single_int32_wrapper], + }`, + err: "field type conversion error", + }, + { + name: "literal_pb_map_assign_null_entry_value", + container: "google.expr.proto3.test", + types: []proto.Message{&proto3pb.TestAllTypes{}}, + expr: `TestAllTypes{ + map_string_string: { + 'hello': 'world', + 'goodbye': TestAllTypes{}.single_string_wrapper, + }, + }`, + err: "field type conversion error", + }, + { + name: "unset_wrapper_access", + container: "google.expr.proto3.test", + types: []proto.Message{&proto3pb.TestAllTypes{}}, + expr: `TestAllTypes{}.single_string_wrapper`, + out: types.NullValue, + }, { name: "timestamp_eq_timestamp", expr: `timestamp(0) == timestamp(0)`,