Skip to content

Commit

Permalink
refactor: use proto.Message instead of ProtoMarshler (#14208)
Browse files Browse the repository at this point in the history
  • Loading branch information
tac0turtle authored Dec 8, 2022
1 parent 09bff2f commit 755c99a
Show file tree
Hide file tree
Showing 18 changed files with 131 additions and 175 deletions.
16 changes: 8 additions & 8 deletions codec/amino_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,42 @@ func NewAminoCodec(codec *LegacyAmino) *AminoCodec {
}

// Marshal implements BinaryMarshaler.Marshal method.
func (ac *AminoCodec) Marshal(o ProtoMarshaler) ([]byte, error) {
func (ac *AminoCodec) Marshal(o proto.Message) ([]byte, error) {
return ac.LegacyAmino.Marshal(o)
}

// MustMarshal implements BinaryMarshaler.MustMarshal method.
func (ac *AminoCodec) MustMarshal(o ProtoMarshaler) []byte {
func (ac *AminoCodec) MustMarshal(o proto.Message) []byte {
return ac.LegacyAmino.MustMarshal(o)
}

// MarshalLengthPrefixed implements BinaryMarshaler.MarshalLengthPrefixed method.
func (ac *AminoCodec) MarshalLengthPrefixed(o ProtoMarshaler) ([]byte, error) {
func (ac *AminoCodec) MarshalLengthPrefixed(o proto.Message) ([]byte, error) {
return ac.LegacyAmino.MarshalLengthPrefixed(o)
}

// MustMarshalLengthPrefixed implements BinaryMarshaler.MustMarshalLengthPrefixed method.
func (ac *AminoCodec) MustMarshalLengthPrefixed(o ProtoMarshaler) []byte {
func (ac *AminoCodec) MustMarshalLengthPrefixed(o proto.Message) []byte {
return ac.LegacyAmino.MustMarshalLengthPrefixed(o)
}

// Unmarshal implements BinaryMarshaler.Unmarshal method.
func (ac *AminoCodec) Unmarshal(bz []byte, ptr ProtoMarshaler) error {
func (ac *AminoCodec) Unmarshal(bz []byte, ptr proto.Message) error {
return ac.LegacyAmino.Unmarshal(bz, ptr)
}

// MustUnmarshal implements BinaryMarshaler.MustUnmarshal method.
func (ac *AminoCodec) MustUnmarshal(bz []byte, ptr ProtoMarshaler) {
func (ac *AminoCodec) MustUnmarshal(bz []byte, ptr proto.Message) {
ac.LegacyAmino.MustUnmarshal(bz, ptr)
}

// UnmarshalLengthPrefixed implements BinaryMarshaler.UnmarshalLengthPrefixed method.
func (ac *AminoCodec) UnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler) error {
func (ac *AminoCodec) UnmarshalLengthPrefixed(bz []byte, ptr proto.Message) error {
return ac.LegacyAmino.UnmarshalLengthPrefixed(bz, ptr)
}

// MustUnmarshalLengthPrefixed implements BinaryMarshaler.MustUnmarshalLengthPrefixed method.
func (ac *AminoCodec) MustUnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler) {
func (ac *AminoCodec) MustUnmarshalLengthPrefixed(bz []byte, ptr proto.Message) {
ac.LegacyAmino.MustUnmarshalLengthPrefixed(bz, ptr)
}

Expand Down
18 changes: 10 additions & 8 deletions codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,28 @@ type (

BinaryCodec interface {
// Marshal returns binary encoding of v.
Marshal(o ProtoMarshaler) ([]byte, error)
Marshal(o proto.Message) ([]byte, error)
// MustMarshal calls Marshal and panics if error is returned.
MustMarshal(o ProtoMarshaler) []byte
MustMarshal(o proto.Message) []byte

// MarshalLengthPrefixed returns binary encoding of v with bytes length prefix.
MarshalLengthPrefixed(o ProtoMarshaler) ([]byte, error)
MarshalLengthPrefixed(o proto.Message) ([]byte, error)
// MustMarshalLengthPrefixed calls MarshalLengthPrefixed and panics if
// error is returned.
MustMarshalLengthPrefixed(o ProtoMarshaler) []byte
MustMarshalLengthPrefixed(o proto.Message) []byte

// Unmarshal parses the data encoded with Marshal method and stores the result
// in the value pointed to by v.
Unmarshal(bz []byte, ptr ProtoMarshaler) error
Unmarshal(bz []byte, ptr proto.Message) error
// MustUnmarshal calls Unmarshal and panics if error is returned.
MustUnmarshal(bz []byte, ptr ProtoMarshaler)
MustUnmarshal(bz []byte, ptr proto.Message)

// Unmarshal parses the data encoded with UnmarshalLengthPrefixed method and stores
// the result in the value pointed to by v.
UnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler) error
UnmarshalLengthPrefixed(bz []byte, ptr proto.Message) error
// MustUnmarshalLengthPrefixed calls UnmarshalLengthPrefixed and panics if error
// is returned.
MustUnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler)
MustUnmarshalLengthPrefixed(bz []byte, ptr proto.Message)

// MarshalInterface is a helper method which will wrap `i` into `Any` for correct
// binary interface (de)serialization.
Expand Down Expand Up @@ -78,6 +78,8 @@ type (

// ProtoMarshaler defines an interface a type must implement to serialize itself
// as a protocol buffer defined message.
//
// Deprecated: Use proto.Message instead from github.com/cosmos/gogoproto/proto.
ProtoMarshaler interface {
proto.Message // for JSON serialization

Expand Down
20 changes: 10 additions & 10 deletions codec/codec_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ func testInterfaceMarshaling(require *require.Assertions, cdc interfaceMarshaler
}

type mustMarshaler struct {
marshal func(i codec.ProtoMarshaler) ([]byte, error)
mustMarshal func(i codec.ProtoMarshaler) []byte
unmarshal func(bz []byte, ptr codec.ProtoMarshaler) error
mustUnmarshal func(bz []byte, ptr codec.ProtoMarshaler)
marshal func(i proto.Message) ([]byte, error)
mustMarshal func(i proto.Message) []byte
unmarshal func(bz []byte, ptr proto.Message) error
mustUnmarshal func(bz []byte, ptr proto.Message)
}

type testCase struct {
name string
input codec.ProtoMarshaler
recv codec.ProtoMarshaler
input proto.Message
recv proto.Message
marshalErr bool
unmarshalErr bool
}
Expand Down Expand Up @@ -121,10 +121,10 @@ func testMarshaling(t *testing.T, cdc codec.Codec) {
m1 := mustMarshaler{cdc.Marshal, cdc.MustMarshal, cdc.Unmarshal, cdc.MustUnmarshal}
m2 := mustMarshaler{cdc.MarshalLengthPrefixed, cdc.MustMarshalLengthPrefixed, cdc.UnmarshalLengthPrefixed, cdc.MustUnmarshalLengthPrefixed}
m3 := mustMarshaler{
func(i codec.ProtoMarshaler) ([]byte, error) { return cdc.MarshalJSON(i) },
func(i codec.ProtoMarshaler) []byte { return cdc.MustMarshalJSON(i) },
func(bz []byte, ptr codec.ProtoMarshaler) error { return cdc.UnmarshalJSON(bz, ptr) },
func(bz []byte, ptr codec.ProtoMarshaler) { cdc.MustUnmarshalJSON(bz, ptr) },
func(i proto.Message) ([]byte, error) { return cdc.MarshalJSON(i) },
func(i proto.Message) []byte { return cdc.MustMarshalJSON(i) },
func(bz []byte, ptr proto.Message) error { return cdc.UnmarshalJSON(bz, ptr) },
func(bz []byte, ptr proto.Message) { cdc.MustUnmarshalJSON(bz, ptr) },
}

t.Run(tc.name+"_BinaryBare",
Expand Down
50 changes: 22 additions & 28 deletions codec/proto_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"strings"

legacyproto "github.com/golang/protobuf/proto" //nolint:staticcheck
"google.golang.org/grpc/encoding"
"google.golang.org/protobuf/proto"

Expand Down Expand Up @@ -42,19 +41,20 @@ func NewProtoCodec(interfaceRegistry types.InterfaceRegistry) *ProtoCodec {
// Marshal implements BinaryMarshaler.Marshal method.
// NOTE: this function must be used with a concrete type which
// implements proto.Message. For interface please use the codec.MarshalInterface
func (pc *ProtoCodec) Marshal(o ProtoMarshaler) ([]byte, error) {
func (pc *ProtoCodec) Marshal(o gogoproto.Message) ([]byte, error) {
// Size() check can catch the typed nil value.
if o == nil || o.Size() == 0 {
if o == nil || gogoproto.Size(o) == 0 {
// return empty bytes instead of nil, because nil has special meaning in places like store.Set
return []byte{}, nil
}
return o.Marshal()

return gogoproto.Marshal(o)
}

// MustMarshal implements BinaryMarshaler.MustMarshal method.
// NOTE: this function must be used with a concrete type which
// implements proto.Message. For interface please use the codec.MarshalInterface
func (pc *ProtoCodec) MustMarshal(o ProtoMarshaler) []byte {
func (pc *ProtoCodec) MustMarshal(o gogoproto.Message) []byte {
bz, err := pc.Marshal(o)
if err != nil {
panic(err)
Expand All @@ -64,19 +64,19 @@ func (pc *ProtoCodec) MustMarshal(o ProtoMarshaler) []byte {
}

// MarshalLengthPrefixed implements BinaryMarshaler.MarshalLengthPrefixed method.
func (pc *ProtoCodec) MarshalLengthPrefixed(o ProtoMarshaler) ([]byte, error) {
func (pc *ProtoCodec) MarshalLengthPrefixed(o gogoproto.Message) ([]byte, error) {
bz, err := pc.Marshal(o)
if err != nil {
return nil, err
}

var sizeBuf [binary.MaxVarintLen64]byte
n := binary.PutUvarint(sizeBuf[:], uint64(o.Size()))
n := binary.PutUvarint(sizeBuf[:], uint64(len(bz)))
return append(sizeBuf[:n], bz...), nil
}

// MustMarshalLengthPrefixed implements BinaryMarshaler.MustMarshalLengthPrefixed method.
func (pc *ProtoCodec) MustMarshalLengthPrefixed(o ProtoMarshaler) []byte {
func (pc *ProtoCodec) MustMarshalLengthPrefixed(o gogoproto.Message) []byte {
bz, err := pc.MarshalLengthPrefixed(o)
if err != nil {
panic(err)
Expand All @@ -88,8 +88,8 @@ func (pc *ProtoCodec) MustMarshalLengthPrefixed(o ProtoMarshaler) []byte {
// Unmarshal implements BinaryMarshaler.Unmarshal method.
// NOTE: this function must be used with a concrete type which
// implements proto.Message. For interface please use the codec.UnmarshalInterface
func (pc *ProtoCodec) Unmarshal(bz []byte, ptr ProtoMarshaler) error {
err := ptr.Unmarshal(bz)
func (pc *ProtoCodec) Unmarshal(bz []byte, ptr gogoproto.Message) error {
err := gogoproto.Unmarshal(bz, ptr)
if err != nil {
return err
}
Expand All @@ -103,14 +103,14 @@ func (pc *ProtoCodec) Unmarshal(bz []byte, ptr ProtoMarshaler) error {
// MustUnmarshal implements BinaryMarshaler.MustUnmarshal method.
// NOTE: this function must be used with a concrete type which
// implements proto.Message. For interface please use the codec.UnmarshalInterface
func (pc *ProtoCodec) MustUnmarshal(bz []byte, ptr ProtoMarshaler) {
func (pc *ProtoCodec) MustUnmarshal(bz []byte, ptr gogoproto.Message) {
if err := pc.Unmarshal(bz, ptr); err != nil {
panic(err)
}
}

// UnmarshalLengthPrefixed implements BinaryMarshaler.UnmarshalLengthPrefixed method.
func (pc *ProtoCodec) UnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler) error {
func (pc *ProtoCodec) UnmarshalLengthPrefixed(bz []byte, ptr gogoproto.Message) error {
size, n := binary.Uvarint(bz)
if n < 0 {
return fmt.Errorf("invalid number of bytes read from length-prefixed encoding: %d", n)
Expand All @@ -127,7 +127,7 @@ func (pc *ProtoCodec) UnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler) err
}

// MustUnmarshalLengthPrefixed implements BinaryMarshaler.MustUnmarshalLengthPrefixed method.
func (pc *ProtoCodec) MustUnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler) {
func (pc *ProtoCodec) MustUnmarshalLengthPrefixed(bz []byte, ptr gogoproto.Message) {
if err := pc.UnmarshalLengthPrefixed(bz, ptr); err != nil {
panic(err)
}
Expand All @@ -137,13 +137,13 @@ func (pc *ProtoCodec) MustUnmarshalLengthPrefixed(bz []byte, ptr ProtoMarshaler)
// it marshals to JSON using proto codec.
// NOTE: this function must be used with a concrete type which
// implements proto.Message. For interface please use the codec.MarshalInterfaceJSON
//
//nolint:stdmethods
func (pc *ProtoCodec) MarshalJSON(o gogoproto.Message) ([]byte, error) {
m, ok := o.(ProtoMarshaler)
if !ok {
return nil, fmt.Errorf("cannot protobuf JSON encode unsupported type: %T", o)
if o == nil {
return nil, fmt.Errorf("cannot protobuf JSON encode nil")
}

return ProtoMarshalJSON(m, pc.interfaceRegistry)
return ProtoMarshalJSON(o, pc.interfaceRegistry)
}

// MustMarshalJSON implements JSONCodec.MustMarshalJSON method,
Expand All @@ -164,13 +164,11 @@ func (pc *ProtoCodec) MustMarshalJSON(o gogoproto.Message) []byte {
// NOTE: this function must be used with a concrete type which
// implements proto.Message. For interface please use the codec.UnmarshalInterfaceJSON
func (pc *ProtoCodec) UnmarshalJSON(bz []byte, ptr gogoproto.Message) error {
m, ok := ptr.(ProtoMarshaler)
if !ok {
if ptr == nil {
return fmt.Errorf("cannot protobuf JSON decode unsupported type: %T", ptr)
}

unmarshaler := jsonpb.Unmarshaler{AnyResolver: pc.interfaceRegistry}
err := unmarshaler.Unmarshal(strings.NewReader(string(bz)), m)
err := unmarshaler.Unmarshal(strings.NewReader(string(bz)), ptr)
if err != nil {
return err
}
Expand Down Expand Up @@ -283,10 +281,8 @@ func (g grpcProtoCodec) Marshal(v interface{}) ([]byte, error) {
switch m := v.(type) {
case proto.Message:
return proto.Marshal(m)
case ProtoMarshaler:
case gogoproto.Message:
return g.cdc.Marshal(m)
case legacyproto.Message:
return legacyproto.Marshal(m)
default:
return nil, fmt.Errorf("%w: cannot marshal type %T", errUnknownProtoType, v)
}
Expand All @@ -296,10 +292,8 @@ func (g grpcProtoCodec) Unmarshal(data []byte, v interface{}) error {
switch m := v.(type) {
case proto.Message:
return proto.Unmarshal(data, m)
case ProtoMarshaler:
case gogoproto.Message:
return g.cdc.Unmarshal(data, m)
case legacyproto.Message:
return legacyproto.Unmarshal(data, m)
default:
return fmt.Errorf("%w: cannot unmarshal type %T", errUnknownProtoType, v)
}
Expand Down
55 changes: 0 additions & 55 deletions codec/proto_codec_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package codec_test

import (
"errors"
"fmt"
"math"
"reflect"
"testing"
Expand Down Expand Up @@ -43,15 +41,6 @@ func TestProtoCodec(t *testing.T) {
testMarshaling(t, cdc)
}

type lyingProtoMarshaler struct {
codec.ProtoMarshaler
falseSize int
}

func (lpm *lyingProtoMarshaler) Size() int {
return lpm.falseSize
}

func TestEnsureRegistered(t *testing.T) {
interfaceRegistry := types.NewInterfaceRegistry()
cat := &testdata.Cat{Moniker: "Garfield"}
Expand Down Expand Up @@ -138,50 +127,6 @@ func grpcServerEncode(c encoding.Codec, msg interface{}) ([]byte, error) {
return b, nil
}

func TestProtoCodecUnmarshalLengthPrefixedChecks(t *testing.T) {
cdc := codec.NewProtoCodec(createTestInterfaceRegistry())

truth := &testdata.Cat{Lives: 9, Moniker: "glowing"}
realSize := len(cdc.MustMarshal(truth))

falseSizes := []int{
100,
5,
}

for _, falseSize := range falseSizes {
falseSize := falseSize

t.Run(fmt.Sprintf("ByMarshaling falseSize=%d", falseSize), func(t *testing.T) {
lpm := &lyingProtoMarshaler{
ProtoMarshaler: &testdata.Cat{Lives: 9, Moniker: "glowing"},
falseSize: falseSize,
}
var serialized []byte
require.NotPanics(t, func() { serialized = cdc.MustMarshalLengthPrefixed(lpm) })

recv := new(testdata.Cat)
gotErr := cdc.UnmarshalLengthPrefixed(serialized, recv)
var wantErr error
if falseSize > realSize {
wantErr = fmt.Errorf("not enough bytes to read; want: %d, got: %d", falseSize, realSize)
} else {
wantErr = fmt.Errorf("too many bytes to read; want: %d, got: %d", falseSize, realSize)
}
require.Equal(t, gotErr, wantErr)
})
}

t.Run("Crafted bad uvarint size", func(t *testing.T) {
crafted := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}
recv := new(testdata.Cat)
gotErr := cdc.UnmarshalLengthPrefixed(crafted, recv)
require.Equal(t, gotErr, errors.New("invalid number of bytes read from length-prefixed encoding: -10"))

require.Panics(t, func() { cdc.MustUnmarshalLengthPrefixed(crafted, recv) })
})
}

func mustAny(msg proto.Message) *types.Any {
any, err := types.NewAnyWithValue(msg)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion tests/e2e/gov/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ func (s *IntegrationTestSuite) TestGetParamsGRPC() {
s.Run(tc.name, func() {
resp, err := testutil.GetRequest(tc.url)
s.Require().NoError(err)

err = val.ClientCtx.Codec.UnmarshalJSON(resp, tc.respType)

if tc.expErr {
Expand Down
Loading

0 comments on commit 755c99a

Please sign in to comment.