Skip to content

Commit

Permalink
fix: add a limit of number of bytes while scale decoding a slice (#3733)
Browse files Browse the repository at this point in the history
While scale decoding we first read the length of bytes to decode and
then we decode that many bytes.

Someone could ask us to decode some malicious bytes such that the read
length is unreasonably big. In such case, we would have to create a byte
slice as big as the length. The length in byte slice is an encoded as `Compact<u32>`. 

Current we are just reading length as uint and not checking if it goes beyond the bounds of uint32.
So, we would either panic because of `makeslice: len out of range` or because the asked length would be
more than the memory we have available in our machine.

We are going to put a check to makes sure that this length is less than max of uint32.
  • Loading branch information
kishansagathiya authored Feb 12, 2024
1 parent f5b9c4c commit 5edbf89
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
11 changes: 11 additions & 0 deletions dot/types/block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,14 @@ func TestMustEncodeBlock(t *testing.T) {
})
}
}

func TestScaleUnmarshal(t *testing.T) {
block := NewBlock(*NewEmptyHeader(), Body{})
err := scale.Unmarshal(
[]byte{48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 4, 48, 48, 48, 48, 19, 48, 48, 48, 48, 48, 48, 48, 48}, //nolint
&block,
)

require.EqualError(t, err,
"decoding struct: unmarshalling field at index 0: decoding struct: unmarshalling field at index 4: decoding struct: unmarshalling field at index 1: byte array length 3472328296227680304 exceeds max value of uint32") //nolint
}
8 changes: 4 additions & 4 deletions internal/trie/node/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func Test_decodeBranch(t *testing.T) {
nodeVariant: branchVariant,
partialKeyLength: 1,
errWrapped: ErrDecodeChildHash,
errMessage: "cannot decode child hash: at index 10: reading byte: EOF",
errMessage: "cannot decode child hash: at index 10: decoding uint: reading byte: EOF",
},
"success_for_branch_variant": {
reader: bytes.NewBuffer(
Expand Down Expand Up @@ -246,7 +246,7 @@ func Test_decodeBranch(t *testing.T) {
nodeVariant: branchWithValueVariant,
partialKeyLength: 1,
errWrapped: ErrDecodeStorageValue,
errMessage: "cannot decode storage value: reading byte: EOF",
errMessage: "cannot decode storage value: decoding uint: reading byte: EOF",
},
"success_for_branch_with_value": {
reader: bytes.NewBuffer(concatByteSlices([][]byte{
Expand Down Expand Up @@ -372,7 +372,7 @@ func Test_decodeLeaf(t *testing.T) {
variant: leafVariant,
partialKeyLength: 1,
errWrapped: ErrDecodeStorageValue,
errMessage: "cannot decode storage value: unknown prefix for compact uint: 255",
errMessage: "cannot decode storage value: decoding uint: unknown prefix for compact uint: 255",
},
"missing_storage_value_data": {
reader: bytes.NewBuffer([]byte{
Expand All @@ -382,7 +382,7 @@ func Test_decodeLeaf(t *testing.T) {
variant: leafVariant,
partialKeyLength: 1,
errWrapped: ErrDecodeStorageValue,
errMessage: "cannot decode storage value: reading byte: EOF",
errMessage: "cannot decode storage value: decoding uint: reading byte: EOF",
},
"empty_storage_value_data": {
reader: bytes.NewBuffer(concatByteSlices([][]byte{
Expand Down
2 changes: 1 addition & 1 deletion lib/runtime/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func Test_DecodeVersion(t *testing.T) {
{255, 255}, // error
}),
errWrapped: ErrDecodingVersionField,
errMessage: "decoding version field impl name: unknown prefix for compact uint: 255",
errMessage: "decoding version field impl name: decoding uint: unknown prefix for compact uint: 255",
},
// TODO add transaction version decode error once
// https://github.com/ChainSafe/gossamer/pull/2683
Expand Down
18 changes: 17 additions & 1 deletion pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"math"
"math/big"
"reflect"
)
Expand Down Expand Up @@ -475,6 +476,7 @@ func (ds *decodeState) decodeBool(dstv reflect.Value) (err error) {
return
}

// TODO: Should this be renamed to decodeCompactInt?
// decodeUint will decode unsigned integer
func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) {
const maxUint32 = ^uint32(0)
Expand All @@ -491,8 +493,12 @@ func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) {
var value uint64
switch mode {
case 0:
// 0b00: single-byte mode; upper six bits are the LE encoding of the value (valid only for
// values of 0-63).
value = uint64(prefix >> 2)
case 1:
// 0b01: two-byte mode: upper six bits and the following byte is the LE encoding of the
// value (valid only for values 64-(2**14-1))
buf, err := ds.ReadByte()
if err != nil {
return fmt.Errorf("reading byte: %w", err)
Expand All @@ -502,6 +508,8 @@ func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) {
return fmt.Errorf("%w: %d (%b)", ErrU16OutOfRange, value, value)
}
case 2:
// 0b10: four-byte mode: upper six bits and the following three bytes are the LE encoding
// of the value (valid only for values (2**14)-(2**30-1)).
buf := make([]byte, 3)
_, err = ds.Read(buf)
if err != nil {
Expand All @@ -512,6 +520,9 @@ func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) {
return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value)
}
case 3:
// 0b11: Big-integer mode: The upper six bits are the number of bytes following, plus four.
// The value is contained, LE encoded, in the bytes following. The final (most significant)
// byte must be non-zero. Valid only for values (2**30)-(2**536-1).
byteLen := (prefix >> 2) + 4
buf := make([]byte, byteLen)
_, err = ds.Read(buf)
Expand Down Expand Up @@ -557,7 +568,7 @@ func (ds *decodeState) decodeLength() (l uint, err error) {
dstv := reflect.New(reflect.TypeOf(l))
err = ds.decodeUint(dstv.Elem())
if err != nil {
return
return 0, fmt.Errorf("decoding uint: %w", err)
}
l = dstv.Elem().Interface().(uint)
return
Expand All @@ -570,6 +581,11 @@ func (ds *decodeState) decodeBytes(dstv reflect.Value) (err error) {
return
}

// bytes length in encoded as Compact<u32>, so it can't be more than math.MaxUint32
if length > math.MaxUint32 {
return fmt.Errorf("byte array length %d exceeds max value of uint32", length)
}

b := make([]byte, length)

if length > 0 {
Expand Down

0 comments on commit 5edbf89

Please sign in to comment.