diff --git a/codec/types/interface_registry.go b/codec/types/interface_registry.go index 5d7e72e890c0..ac210ade07f3 100644 --- a/codec/types/interface_registry.go +++ b/codec/types/interface_registry.go @@ -1,6 +1,7 @@ package types import ( + "errors" "fmt" "reflect" @@ -9,6 +10,17 @@ import ( "github.com/gogo/protobuf/proto" ) +var ( + + // MaxUnpackAnySubCalls extension point that defines the maximum number of sub-calls allowed during the unpacking + // process of protobuf Any messages. + MaxUnpackAnySubCalls = 100 + + // MaxUnpackAnyRecursionDepth extension point that defines the maximum allowed recursion depth during protobuf Any + // message unpacking. + MaxUnpackAnyRecursionDepth = 10 +) + // AnyUnpacker is an interface which allows safely unpacking types packed // in Any's against a whitelist of registered types type AnyUnpacker interface { @@ -194,6 +206,45 @@ func (registry *interfaceRegistry) ListImplementations(ifaceName string) []strin } func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error { + unpacker := &statefulUnpacker{ + registry: registry, + maxDepth: MaxUnpackAnyRecursionDepth, + maxCalls: &sharedCounter{count: MaxUnpackAnySubCalls}, + } + return unpacker.UnpackAny(any, iface) +} + +// sharedCounter is a type that encapsulates a counter value +type sharedCounter struct { + count int +} + +// statefulUnpacker is a struct that helps in deserializing and unpacking +// protobuf Any messages while maintaining certain stateful constraints. +type statefulUnpacker struct { + registry *interfaceRegistry + maxDepth int + maxCalls *sharedCounter +} + +// cloneForRecursion returns a new statefulUnpacker instance with maxDepth reduced by one, preserving the registry and maxCalls. +func (r statefulUnpacker) cloneForRecursion() *statefulUnpacker { + return &statefulUnpacker{ + registry: r.registry, + maxDepth: r.maxDepth - 1, + maxCalls: r.maxCalls, + } +} + +// UnpackAny deserializes a protobuf Any message into the provided interface, ensuring the interface is a pointer. +// It applies stateful constraints such as max depth and call limits, and unpacks interfaces if required. +func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error { + if r.maxDepth == 0 { + return errors.New("max depth exceeded") + } + if r.maxCalls.count == 0 { + return errors.New("call limit exceeded") + } // here we gracefully handle the case in which `any` itself is `nil`, which may occur in message decoding if any == nil { return nil @@ -204,6 +255,8 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error return nil } + r.maxCalls.count-- + rv := reflect.ValueOf(iface) if rv.Kind() != reflect.Ptr { return fmt.Errorf("UnpackAny expects a pointer") @@ -219,7 +272,7 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error } } - imap, found := registry.interfaceImpls[rt] + imap, found := r.registry.interfaceImpls[rt] if !found { return fmt.Errorf("no registered implementations of type %+v", rt) } @@ -239,7 +292,7 @@ func (registry *interfaceRegistry) UnpackAny(any *Any, iface interface{}) error return err } - err = UnpackInterfaces(msg, registry) + err = UnpackInterfaces(msg, r.cloneForRecursion()) if err != nil { return err } diff --git a/codec/unknownproto/unknown_fields.go b/codec/unknownproto/unknown_fields.go index 3af40ffed15b..7332a76e5c57 100644 --- a/codec/unknownproto/unknown_fields.go +++ b/codec/unknownproto/unknown_fields.go @@ -39,9 +39,23 @@ func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.Any // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. // An AnyResolver must be provided for traversing inside google.protobuf.Any's. func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) { + // recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28 + return doRejectUnknownFields(bz, msg, allowUnknownNonCriticals, resolver, 10_000) +} + +func doRejectUnknownFields( + bz []byte, + msg proto.Message, + allowUnknownNonCriticals bool, + resolver jsonpb.AnyResolver, + recursionLimit int, +) (hasUnknownNonCriticals bool, err error) { if len(bz) == 0 { return hasUnknownNonCriticals, nil } + if recursionLimit == 0 { + return false, errors.New("recursion limit reached") + } desc, ok := msg.(descriptorIface) if !ok { @@ -129,7 +143,7 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals if protoMessageName == ".google.protobuf.Any" { // Firstly typecheck types.Any to ensure nothing snuck in. - hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver) + hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver, recursionLimit-1) hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild if err != nil { return hasUnknownNonCriticals, err @@ -152,7 +166,7 @@ func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals } } - hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver) + hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver, recursionLimit-1) hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild if err != nil { return hasUnknownNonCriticals, err