Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: TxRaw must follow ADR 027 (backport #9743) #9754

Merged
merged 1 commit into from
Jul 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion x/auth/tx/decoder.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package tx

import (
"fmt"

"google.golang.org/protobuf/encoding/protowire"

"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/codec/unknownproto"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -11,10 +15,16 @@ import (
// DefaultTxDecoder returns a default protobuf TxDecoder using the provided Marshaler.
func DefaultTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder {
return func(txBytes []byte) (sdk.Tx, error) {
// Make sure txBytes follow ADR-027.
err := rejectNonADR027(txBytes)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}

var raw tx.TxRaw

// reject all unknown proto fields in the root TxRaw
err := unknownproto.RejectUnknownFieldsStrict(txBytes, &raw, cdc.InterfaceRegistry())
err = unknownproto.RejectUnknownFieldsStrict(txBytes, &raw, cdc.InterfaceRegistry())
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrTxDecode, err.Error())
}
Expand Down Expand Up @@ -79,3 +89,77 @@ func DefaultJSONTxDecoder(cdc codec.ProtoCodecMarshaler) sdk.TxDecoder {
}, nil
}
}

// rejectNonADR027 rejects txBytes that do not follow ADR-027. This function
// only checks that:
// - field numbers are in ascending order (1, 2, and potentially multiple 3s),
// - and varints as as short as possible.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// - and varints as as short as possible.
// - and varints are as short as possible.

// All other ADR-027 edge cases (e.g. TxRaw fields having default values) will
// not happen with TxRaw.
Comment on lines +97 to +98
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// All other ADR-027 edge cases (e.g. TxRaw fields having default values) will
// not happen with TxRaw.
// All other ADR-027 edge cases (e.g. TxRaw fields having default values) are
// not applicable with TxRaw.

func rejectNonADR027(txBytes []byte) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about moving the new functions to a separate package? It seams it's not necessary tide to auth/tx/decoder

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is actually and maybe should just be renamed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is not a generic ADR-027 checker, but a simpler version of it that checks it for txBytes

// Make sure all fields are ordered in ascending order with this variable.
prevTagNum := protowire.Number(0)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we better document the process?

  • moving comment from line 116
  • define the bytes structure: <varint of bytes lenght><bytes sequence>

for len(txBytes) > 0 {
tagNum, wireType, m := protowire.ConsumeTag(txBytes)
if m < 0 {
return fmt.Errorf("invalid length; %w", protowire.ParseError(m))
}
if wireType != protowire.BytesType {
return fmt.Errorf("expected %d wire type, got %d", protowire.VarintType, wireType)
}
if tagNum < prevTagNum {
return fmt.Errorf("txRaw must follow ADR-027, got tagNum %d after tagNum %d", tagNum, prevTagNum)
}
prevTagNum = tagNum

// All 3 fields of TxRaw have wireType == 2, so their next component
// is a varint.
// We make sure that the varint is as short as possible.
lengthPrefix, m := protowire.ConsumeVarint(txBytes[m:])
if m < 0 {
return fmt.Errorf("invalid length; %w", protowire.ParseError(m))
}
n := varintMinLength(lengthPrefix)
if n != m {
return fmt.Errorf("length prefix varint for tagNum %d is not as short as possible, read %d, only need %d", tagNum, m, n)
}

// Skip over the bytes that store fieldNumber and wireType bytes.
_, _, m = protowire.ConsumeField(txBytes)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also verify the "inner" type serialization?

Copy link
Contributor

@amaury1093 amaury1093 Jul 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imo it's not needed, because it's done further down here and here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I made a suggestion in the new PR to add a short note about it.

if m < 0 {
return fmt.Errorf("invalid length; %w", protowire.ParseError(m))
}
txBytes = txBytes[m:]
}

return nil
}

// varintMinLength returns the minimum number of bytes necessary to encode an
// uint using varint encoding.
func varintMinLength(n uint64) int {
switch {
// Note: 1<<N == 2**N.
case n < 1<<7:
return 1
case n < 1<<14:
return 2
case n < 1<<21:
return 3
case n < 1<<28:
return 4
case n < 1<<35:
return 5
case n < 1<<42:
return 6
case n < 1<<49:
return 7
case n < 1<<56:
return 8
case n < 1<<63:
Comment on lines +146 to +160
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
case n < 1<<14:
return 2
case n < 1<<21:
return 3
case n < 1<<28:
return 4
case n < 1<<35:
return 5
case n < 1<<42:
return 6
case n < 1<<49:
return 7
case n < 1<<56:
return 8
case n < 1<<63:
case n < 1<<7*2:
return 2
case n < 1<<7*3:
return 3
case n < 1<<7*4:
return 4
case n < 1<<7*5:
return 5
case n < 1<<7*6:
return 6
case n < 1<<7*7:
return 7
case n < 1<<7*8:
return 8
case n < 1<<7*9:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. let's let the compiler to do multiplication
  2. so protobuf is limited to 2**70 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. protobuf varints go until uint64, and that's max 70 bytes yes

return 9
default:
return 10
}
}
134 changes: 128 additions & 6 deletions x/auth/tx/encode_decode_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
package tx

import (
"encoding/binary"
"fmt"
"math"
"testing"

sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"

"github.com/cosmos/cosmos-sdk/types/tx"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protowire"

"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing"
"github.com/cosmos/cosmos-sdk/x/auth/signing"
)

func TestDefaultTxDecoderError(t *testing.T) {
Expand Down Expand Up @@ -159,3 +160,124 @@ func TestUnknownFields(t *testing.T) {
_, err = decoder(txBz)
require.Error(t, err)
}

func TestRejectNonADR027(t *testing.T) {
registry := codectypes.NewInterfaceRegistry()
cdc := codec.NewProtoCodec(registry)
decoder := DefaultTxDecoder(cdc)

body := &testdata.TestUpdatedTxBody{Memo: "AAA"} // Look for "65 65 65" when debugging the bytes stream.
bodyBz, err := body.Marshal()
require.NoError(t, err)
authInfo := &testdata.TestUpdatedAuthInfo{Fee: &tx.Fee{GasLimit: 127}} // Look for "127" when debugging the bytes stream.
authInfoBz, err := authInfo.Marshal()
txRaw := &tx.TxRaw{
BodyBytes: bodyBz,
AuthInfoBytes: authInfoBz,
Signatures: [][]byte{{41}, {42}, {43}}, // Look for "42" when debugging the bytes stream.
}

// We know these bytes are ADR-027-compliant.
txBz, err := txRaw.Marshal()

// From the `txBz`, we extract the 3 components:
// bodyBz, authInfoBz, sigsBz.
// In our tests, we will try to decode txs with those 3 components in all
// possible orders.
//
// Consume "BodyBytes" field.
_, _, m := protowire.ConsumeField(txBz)
bodyBz = append([]byte{}, txBz[:m]...)
txBz = txBz[m:] // Skip over "BodyBytes" bytes.
// Consume "AuthInfoBytes" field.
_, _, m = protowire.ConsumeField(txBz)
authInfoBz = append([]byte{}, txBz[:m]...)
txBz = txBz[m:] // Skip over "AuthInfoBytes" bytes.
// Consume "Signature" field, it's the remaining bytes.
sigsBz := append([]byte{}, txBz...)

// bodyBz's length prefix is 5, with `5` as varint encoding. We also try a
// longer varint encoding for 5: `133 00`.
longVarintBodyBz := append(append([]byte{bodyBz[0]}, byte(133), byte(00)), bodyBz[2:]...)

tests := []struct {
name string
txBz []byte
shouldErr bool
}{
{
"authInfo, body, sigs",
append(append(authInfoBz, bodyBz...), sigsBz...),
true,
},
{
"authInfo, sigs, body",
append(append(authInfoBz, sigsBz...), bodyBz...),
true,
},
{
"sigs, body, authInfo",
append(append(sigsBz, bodyBz...), authInfoBz...),
true,
},
{
"sigs, authInfo, body",
append(append(sigsBz, authInfoBz...), bodyBz...),
true,
},
{
"body, sigs, authInfo",
append(append(bodyBz, sigsBz...), authInfoBz...),
true,
},
{
"body, authInfo, sigs (valid txRaw)",
append(append(bodyBz, authInfoBz...), sigsBz...),
false,
},
{
"longer varint than needed",
append(append(longVarintBodyBz, authInfoBz...), sigsBz...),
true,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
_, err = decoder(tt.txBz)
if tt.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

func TestVarintMinLength(t *testing.T) {
tests := []struct {
n uint64
}{
{1<<7 - 1}, {1 << 7},
{1<<14 - 1}, {1 << 14},
{1<<21 - 1}, {1 << 21},
{1<<28 - 1}, {1 << 28},
{1<<35 - 1}, {1 << 35},
{1<<42 - 1}, {1 << 42},
{1<<49 - 1}, {1 << 49},
{1<<56 - 1}, {1 << 56},
{1<<63 - 1}, {1 << 63},
{math.MaxUint64},
}

for _, tt := range tests {
tt := tt
t.Run(fmt.Sprintf("test %d", tt.n), func(t *testing.T) {
l1 := varintMinLength(tt.n)
buf := make([]byte, binary.MaxVarintLen64)
l2 := binary.PutUvarint(buf, tt.n)
require.Equal(t, l2, l1)
})
}
}