diff --git a/spanner/value.go b/spanner/value.go index 3f806f8b7657..8e75dfc15987 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -2718,7 +2718,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb result = &NullString{x, !isNull} } case spannerTypeByteArray: - if code != sppb.TypeCode_BYTES { + if code != sppb.TypeCode_BYTES && code != sppb.TypeCode_PROTO { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -2735,7 +2735,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb } result = y case spannerTypeNonNullInt64, spannerTypeNullInt64: - if code != sppb.TypeCode_INT64 { + if code != sppb.TypeCode_INT64 && code != sppb.TypeCode_ENUM { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -2913,7 +2913,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb } result = y case spannerTypeArrayOfByteArray: - if acode != sppb.TypeCode_BYTES { + if acode != sppb.TypeCode_BYTES && acode != sppb.TypeCode_PROTO { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -2930,7 +2930,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb } result = y case spannerTypeArrayOfNonNullInt64, spannerTypeArrayOfNullInt64: - if acode != sppb.TypeCode_INT64 { + if acode != sppb.TypeCode_INT64 && acode != sppb.TypeCode_ENUM { return errTypeMismatch(code, acode, ptr) } if isNull { diff --git a/spanner/value_test.go b/spanner/value_test.go index 95d12643c812..6e7ac34c6f61 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -18,6 +18,7 @@ package spanner import ( "bytes" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -3289,3 +3290,74 @@ func TestNullJson(t *testing.T) { t.Fatalf("expected null, got %s", v) } } + +// Test decode for PROTO type when custom type is a variant of a base type +func TestDecodeProtoUsingBaseVariant(t *testing.T) { + // nullBytes is custom type from []byte base type. + type nullBytes []byte + + var b []byte + var nb nullBytes + + gcv := &GenericColumnValue{ + Type: &sppb.Type{ + Code: sppb.TypeCode_PROTO, + ProtoTypeFqn: "examples.ProtoType", + }, + Value: structpb.NewStringValue("Zm9vCg=="), + } + if err := gcv.Decode(&nb); err != nil { + t.Error(err) + } + if err := gcv.Decode(&b); err != nil { + t.Error(err) + } + + // Convert []byte and nullBytes to base64 encoding and then compare the contents. + if !testutil.Equal(base64.StdEncoding.EncodeToString(b), base64.StdEncoding.EncodeToString(nb)) { + t.Errorf("%s: got %+v, want %+v", "Test PROTO decode to []byte custom type", nb, b) + } +} + +// Test decode for PROTO type when custom type is a variant of a base type +func TestDecodeProtoArrayUsingBaseVariant(t *testing.T) { + // nullBytes is custom type from []byte base type. + type nullBytes [][]byte + + var b [][]byte + var nb nullBytes + + gcv := &GenericColumnValue{ + Type: &sppb.Type{ + Code: sppb.TypeCode_ARRAY, + ArrayElementType: &sppb.Type{ + Code: sppb.TypeCode_PROTO, + ProtoTypeFqn: "examples.ProtoType", + }, + }, + Value: structpb.NewListValue( + &structpb.ListValue{ + Values: []*structpb.Value{ + structpb.NewStringValue("Zm9vCg=="), + }, + }), + } + if err := gcv.Decode(&nb); err != nil { + t.Error(err) + } + if err := gcv.Decode(&b); err != nil { + t.Error(err) + } + + if len(b) != 1 { + t.Errorf("Expected length to be 1") + } + + if len(nb) != 1 { + t.Errorf("Expected length to be 1") + } + // Convert to base64 encoding and then compare the contents. + if !testutil.Equal(base64.StdEncoding.EncodeToString(b[0]), base64.StdEncoding.EncodeToString(nb[0])) { + t.Errorf("%s: got %+v, want %+v", "Test PROTO decode to [][]byte custom type", nb, b) + } +}