Skip to content

Commit

Permalink
Fix unmarshalling of registered types
Browse files Browse the repository at this point in the history
This is a faithful backport of #53

---

In 7d3d258 I inadvterantly removed
support for json unmarshalling for the case when a type implements a
protobuf message but is a type that is registered.

For types that are registered through the `Register` function typeurl
is supposed to ignore the interfaces that the type implements and just
use json.

This change restores that behavior.

Signed-off-by: Brian Goff <[email protected]>
  • Loading branch information
cpuguy83 committed Nov 6, 2024
1 parent 1666bdb commit a044d53
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
32 changes: 17 additions & 15 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type handler interface {
Marshaller(interface{}) func() ([]byte, error)
Unmarshaller(interface{}) func([]byte) error
TypeURL(interface{}) string
GetType(url string) reflect.Type
GetType(url string) (reflect.Type, bool)
}

// Definitions of common error types used throughout typeurl.
Expand Down Expand Up @@ -240,7 +240,7 @@ func MarshalAnyToProto(from interface{}) (*anypb.Any, error) {
}

func unmarshal(typeURL string, value []byte, v interface{}) (interface{}, error) {
t, err := getTypeByUrl(typeURL)
t, isProto, err := getTypeByUrl(typeURL)
if err != nil {
return nil, err
}
Expand All @@ -258,43 +258,45 @@ func unmarshal(typeURL string, value []byte, v interface{}) (interface{}, error)
}
}

pm, ok := v.(proto.Message)
if ok {
return v, proto.Unmarshal(value, pm)
}
if isProto {
pm, ok := v.(proto.Message)
if ok {
return v, proto.Unmarshal(value, pm)
}

for _, h := range handlers {
if unmarshal := h.Unmarshaller(v); unmarshal != nil {
return v, unmarshal(value)
for _, h := range handlers {
if unmarshal := h.Unmarshaller(v); unmarshal != nil {
return v, unmarshal(value)
}
}
}

// fallback to json unmarshaller
return v, json.Unmarshal(value, v)
}

func getTypeByUrl(url string) (reflect.Type, error) {
func getTypeByUrl(url string) (_ reflect.Type, isProto bool, _ error) {
mu.RLock()
for t, u := range registry {
if u == url {
mu.RUnlock()
return t, nil
return t, false, nil
}
}
mu.RUnlock()
mt, err := protoregistry.GlobalTypes.FindMessageByURL(url)
if err != nil {
if errors.Is(err, protoregistry.NotFound) {
for _, h := range handlers {
if t := h.GetType(url); t != nil {
return t, nil
if t, isProto := h.GetType(url); t != nil {
return t, isProto, nil
}
}
}
return nil, fmt.Errorf("type with url %s: %w", url, ErrNotFound)
return nil, false, fmt.Errorf("type with url %s: %w", url, ErrNotFound)
}
empty := mt.New().Interface()
return reflect.TypeOf(empty).Elem(), nil
return reflect.TypeOf(empty).Elem(), true, nil
}

func tryDereference(v interface{}) reflect.Type {
Expand Down
6 changes: 3 additions & 3 deletions types_gogo.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ func (gogoHandler) TypeURL(v interface{}) string {
return gogoproto.MessageName(pm)
}

func (gogoHandler) GetType(url string) reflect.Type {
func (gogoHandler) GetType(url string) (reflect.Type, bool) {
t := gogoproto.MessageType(url)
if t == nil {
return nil
return nil, false
}
return t.Elem()
return t.Elem(), true
}
22 changes: 22 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package typeurl

import (
"bytes"
"encoding/json"
"errors"
"reflect"
"testing"
Expand Down Expand Up @@ -241,3 +242,24 @@ func TestUnmarshalNotFound(t *testing.T) {
t.Fatalf("unexpected error unmarshalling type which does not exist: %v", err)
}
}

func TestUnmarshalJSON(t *testing.T) {
url := t.Name()
Register(&timestamppb.Timestamp{}, url)

expected := timestamppb.Now()

dt, err := json.Marshal(expected)
if err != nil {
t.Fatal(err)
}

var actual timestamppb.Timestamp
if err := UnmarshalToByTypeURL(url, dt, &actual); err != nil {
t.Fatal(err)
}

if !expected.AsTime().Equal(actual.AsTime()) {
t.Fatalf("expected value to be %q, got: %q", expected.AsTime(), actual.AsTime())
}
}

0 comments on commit a044d53

Please sign in to comment.