diff --git a/x/auth/tx/decoder.go b/x/auth/tx/decoder.go index a32a118604f9..87bef54e2702 100644 --- a/x/auth/tx/decoder.go +++ b/x/auth/tx/decoder.go @@ -11,6 +11,15 @@ import ( // DefaultTxDecoder returns a default protobuf TxDecoder using the provided Marshaler. func DefaultTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder { return func(txBytes []byte) (sdk.Tx, error) { +<<<<<<< HEAD +======= + // Make sure txBytes follow ADR-027. + err := rejectNonADR027TxRaw(txBytes) + if err != nil { + return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error()) + } + +>>>>>>> 3771ebdf2 (refactor: Improve txRaw decoder code (#9769)) var raw tx.TxRaw // reject all unknown proto fields in the root TxRaw @@ -79,3 +88,85 @@ func DefaultJSONTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder { }, nil } } +<<<<<<< HEAD +======= + +// rejectNonADR027TxRaw rejects txBytes that do not follow ADR-027. This is NOT +// a generic ADR-027 checker, it only applies decoding TxRaw. Specifically, it +// only checks that: +// - field numbers are in ascending order (1, 2, and potentially multiple 3s), +// - and varints are as short as possible. +// All other ADR-027 edge cases (e.g. default values) are not applicable with +// TxRaw. +func rejectNonADR027TxRaw(txBytes []byte) error { + // Make sure all fields are ordered in ascending order with this variable. + prevTagNum := protowire.Number(0) + + for len(txBytes) > 0 { + tagNum, wireType, m := protowire.ConsumeTag(txBytes) + if m < 0 { + return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) + } + // TxRaw only has bytes fields. + if wireType != protowire.BytesType { + return fmt.Errorf("expected %d wire type, got %d", protowire.BytesType, wireType) + } + // Make sure fields are ordered in ascending order. + if tagNum < prevTagNum { + return fmt.Errorf("txRaw must follow ADR-027, got tagNum %d after tagNum %d", tagNum, prevTagNum) + } + prevTagNum = tagNum + + // All 3 fields of TxRaw have wireType == 2, so their next component + // is a varint, so we can safely call ConsumeVarint here. + // Byte structure: + // Inner fields are verified in `DefaultTxDecoder` + lengthPrefix, m := protowire.ConsumeVarint(txBytes[m:]) + if m < 0 { + return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) + } + // We make sure that this varint is as short as possible. + n := varintMinLength(lengthPrefix) + if n != m { + return fmt.Errorf("length prefix varint for tagNum %d is not as short as possible, read %d, only need %d", tagNum, m, n) + } + + // Skip over the bytes that store fieldNumber and wireType bytes. + _, _, m = protowire.ConsumeField(txBytes) + if m < 0 { + return fmt.Errorf("invalid length; %w", protowire.ParseError(m)) + } + txBytes = txBytes[m:] + } + + return nil +} + +// varintMinLength returns the minimum number of bytes necessary to encode an +// uint using varint encoding. +func varintMinLength(n uint64) int { + switch { + // Note: 1<>>>>>> 3771ebdf2 (refactor: Improve txRaw decoder code (#9769))