Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(codec/unknownproto): use official descriptorpb types, do not rely on gogo code gen types. #18541

Merged
merged 9 commits into from
Nov 29, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 115 additions & 49 deletions codec/unknownproto/unknown_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"strings"
"sync"

"github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/gogoproto/jsonpb"
"github.com/cosmos/gogoproto/proto"
"github.com/cosmos/gogoproto/protoc-gen-gogo/descriptor"
"google.golang.org/protobuf/encoding/protowire"

"github.com/cosmos/cosmos-sdk/codec/types"
protov2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"
)

const bit11NonCritical = 1 << 10
Expand Down Expand Up @@ -68,10 +68,9 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
GotWireType: wireType,
WantWireType: protowire.Type(fieldDescProto.WireType()),
WantWireType: findWireTypeFromFieldDescriptorProtoType(fieldDescProto.GetType()),
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error struct errMismatchedWireType should include the expected and actual wire types using descriptorpb.FieldDescriptorProto_Type instead of protowire.Type to align with the refactor.

}

testinginprod marked this conversation as resolved.
Show resolved Hide resolved
default:
isCriticalField := tagNum&bit11NonCritical == 0

Expand Down Expand Up @@ -101,14 +100,14 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals
bz = bz[n:]

// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
if fieldDescProto == nil || fieldDescProto.IsScalar() {
if fieldDescProto == nil || isScalar(fieldDescProto) {
continue
}

protoMessageName := fieldDescProto.GetTypeName()
if protoMessageName == "" {
switch typ := fieldDescProto.GetType(); typ {
case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES:
case descriptorpb.FieldDescriptorProto_TYPE_STRING, descriptorpb.FieldDescriptorProto_TYPE_BYTES:
// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
// TYPE_BYTES and TYPE_STRING as per
// https://github.com/cosmos/gogoproto/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
Expand Down Expand Up @@ -199,68 +198,68 @@ func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
// checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type.
// it is implemented this way so as to have constant time lookups and avoid the overhead
// from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{
var checks = [...]map[descriptorpb.FieldDescriptorProto_Type]bool{
// "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
0: {
descriptor.FieldDescriptorProto_TYPE_INT32: true,
descriptor.FieldDescriptorProto_TYPE_INT64: true,
descriptor.FieldDescriptorProto_TYPE_UINT32: true,
descriptor.FieldDescriptorProto_TYPE_UINT64: true,
descriptor.FieldDescriptorProto_TYPE_SINT32: true,
descriptor.FieldDescriptorProto_TYPE_SINT64: true,
descriptor.FieldDescriptorProto_TYPE_BOOL: true,
descriptor.FieldDescriptorProto_TYPE_ENUM: true,
descriptorpb.FieldDescriptorProto_TYPE_INT32: true,
descriptorpb.FieldDescriptorProto_TYPE_INT64: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_BOOL: true,
descriptorpb.FieldDescriptorProto_TYPE_ENUM: true,
},

// "1 64-bit: fixed64, sfixed64, double"
1: {
descriptor.FieldDescriptorProto_TYPE_FIXED64: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptor.FieldDescriptorProto_TYPE_DOUBLE: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true,
},

// "2 Length-delimited: string, bytes, embedded messages, packed repeated fields"
2: {
descriptor.FieldDescriptorProto_TYPE_STRING: true,
descriptor.FieldDescriptorProto_TYPE_BYTES: true,
descriptor.FieldDescriptorProto_TYPE_MESSAGE: true,
descriptorpb.FieldDescriptorProto_TYPE_STRING: true,
descriptorpb.FieldDescriptorProto_TYPE_BYTES: true,
descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: true,
// The following types can be packed repeated.
// ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"."
// ref: https://developers.google.com/protocol-buffers/docs/encoding#packed
descriptor.FieldDescriptorProto_TYPE_INT32: true,
descriptor.FieldDescriptorProto_TYPE_INT64: true,
descriptor.FieldDescriptorProto_TYPE_UINT32: true,
descriptor.FieldDescriptorProto_TYPE_UINT64: true,
descriptor.FieldDescriptorProto_TYPE_SINT32: true,
descriptor.FieldDescriptorProto_TYPE_SINT64: true,
descriptor.FieldDescriptorProto_TYPE_BOOL: true,
descriptor.FieldDescriptorProto_TYPE_ENUM: true,
descriptor.FieldDescriptorProto_TYPE_FIXED64: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptor.FieldDescriptorProto_TYPE_DOUBLE: true,
descriptorpb.FieldDescriptorProto_TYPE_INT32: true,
descriptorpb.FieldDescriptorProto_TYPE_INT64: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_UINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT32: true,
descriptorpb.FieldDescriptorProto_TYPE_SINT64: true,
descriptorpb.FieldDescriptorProto_TYPE_BOOL: true,
descriptorpb.FieldDescriptorProto_TYPE_ENUM: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true,
},

// "3 Start group: groups (deprecated)"
3: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
},

// "4 End group: groups (deprecated)"
4: {
descriptor.FieldDescriptorProto_TYPE_GROUP: true,
descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
},

// "5 32-bit: fixed32, sfixed32, float"
5: {
descriptor.FieldDescriptorProto_TYPE_FIXED32: true,
descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
descriptor.FieldDescriptorProto_TYPE_FLOAT: true,
descriptorpb.FieldDescriptorProto_TYPE_FIXED32: true,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: true,
descriptorpb.FieldDescriptorProto_TYPE_FLOAT: true,
},
}

// canEncodeType returns true if the wireType is suitable for encoding the descriptor type.
// See https://developers.google.com/protocol-buffers/docs/encoding#structure.
func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool {
func canEncodeType(wireType protowire.Type, descType descriptorpb.FieldDescriptorProto_Type) bool {
if iwt := int(wireType); iwt < 0 || iwt >= len(checks) {
return false
}
Expand Down Expand Up @@ -330,21 +329,21 @@ func (twt *errUnknownField) Error() string {
var _ error = (*errUnknownField)(nil)

var (
protoFileToDesc = make(map[string]*descriptor.FileDescriptorProto)
protoFileToDesc = make(map[string]*descriptorpb.FileDescriptorProto)
protoFileToDescMu sync.RWMutex
)

func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto {
func unnestDesc(mdescs []*descriptorpb.DescriptorProto, indices []int) *descriptorpb.DescriptorProto {
mdesc := mdescs[indices[0]]
for _, index := range indices[1:] {
mdesc = mdesc.NestedType[index]
}
return mdesc
}

// Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// Invoking descriptorpb.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
// for every single message, thus the need for a hand-rolled custom version that's performant and cacheable.
func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) {
func extractFileDescMessageDesc(desc descriptorIface) (*descriptorpb.FileDescriptorProto, *descriptorpb.DescriptorProto, error) {
gzippedPb, indices := desc.Descriptor()

protoFileToDescMu.RLock()
Expand All @@ -365,8 +364,8 @@ func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescripto
return nil, nil, err
}

fdesc := new(descriptor.FileDescriptorProto)
if err := proto.Unmarshal(protoBlob, fdesc); err != nil {
fdesc := new(descriptorpb.FileDescriptorProto)
if err := protov2.Unmarshal(protoBlob, fdesc); err != nil {
return nil, nil, err
}

Expand All @@ -380,8 +379,8 @@ func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescripto
}

type descriptorMatch struct {
cache map[int32]*descriptor.FieldDescriptorProto
desc *descriptor.DescriptorProto
cache map[int32]*descriptorpb.FieldDescriptorProto
desc *descriptorpb.DescriptorProto
}

var (
Expand All @@ -390,7 +389,7 @@ var (
)

// getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors.
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) {
func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptorpb.FieldDescriptorProto, *descriptorpb.DescriptorProto, error) {
key := reflect.ValueOf(msg).Type()

descprotoCacheMu.RLock()
Expand All @@ -407,7 +406,7 @@ func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*desc
return nil, nil, err
}

tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto)
tagNumToTypeIndex := make(map[int32]*descriptorpb.FieldDescriptorProto)
for _, field := range md.Field {
tagNumToTypeIndex[field.GetNumber()] = field
}
Expand Down Expand Up @@ -441,3 +440,70 @@ func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) {
}
return reflect.New(mt.Elem()).Interface().(proto.Message), nil
}

func findWireTypeFromFieldDescriptorProtoType(fieldType descriptorpb.FieldDescriptorProto_Type) protowire.Type {
testinginprod marked this conversation as resolved.
Show resolved Hide resolved
switch fieldType {
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
return 1
case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
return 5
case descriptorpb.FieldDescriptorProto_TYPE_INT64:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_UINT64:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_INT32:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_UINT32:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
return 1
case descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
return 5
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_STRING:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_GROUP:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
return 2
case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
return 5
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
return 1
case descriptorpb.FieldDescriptorProto_TYPE_SINT32:
return 0
case descriptorpb.FieldDescriptorProto_TYPE_SINT64:
return 0
}
panic("unreachable")
testinginprod marked this conversation as resolved.
Show resolved Hide resolved
}

func isScalar(field *descriptorpb.FieldDescriptorProto) bool {
testinginprod marked this conversation as resolved.
Show resolved Hide resolved
if field.Type == nil {
return false
}
switch *field.Type {
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
descriptorpb.FieldDescriptorProto_TYPE_INT64,
descriptorpb.FieldDescriptorProto_TYPE_UINT64,
descriptorpb.FieldDescriptorProto_TYPE_INT32,
descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
descriptorpb.FieldDescriptorProto_TYPE_BOOL,
descriptorpb.FieldDescriptorProto_TYPE_UINT32,
descriptorpb.FieldDescriptorProto_TYPE_ENUM,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
descriptorpb.FieldDescriptorProto_TYPE_SINT32,
descriptorpb.FieldDescriptorProto_TYPE_SINT64:
return true
default:
return false
}
}
Loading