Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rlp: improve decoder stream implementation #22858

Merged
merged 9 commits into from
May 18, 2021
Merged
175 changes: 86 additions & 89 deletions rlp/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func decodeDecoder(s *Stream, val reflect.Value) error {
}

// Kind represents the kind of value contained in an RLP stream.
type Kind int
type Kind int8

const (
Byte Kind = iota
Expand Down Expand Up @@ -561,22 +561,16 @@ type ByteReader interface {
type Stream struct {
r ByteReader

// number of bytes remaining to be read from r.
remaining uint64
limited bool

// auxiliary buffer for integer decoding
uintbuf []byte

kind Kind // kind of value ahead
size uint64 // size of value ahead
byteval byte // value of single byte in type tag
kinderr error // error from last readKind
stack []listpos
remaining uint64 // number of bytes remaining to be read from r
size uint64 // size of value ahead
kinderr error // error from last readKind
stack []uint64 // list sizes
uintbuf [8]byte // auxiliary buffer for integer decoding
kind Kind // kind of value ahead
byteval byte // value of single byte in type tag
limited bool // true if input limit is in effect
}

type listpos struct{ pos, size uint64 }

// NewStream creates a new decoding stream reading from r.
//
// If r implements the ByteReader interface, Stream will
Expand Down Expand Up @@ -646,8 +640,8 @@ func (s *Stream) Raw() ([]byte, error) {
s.kind = -1 // rearm Kind
return []byte{s.byteval}, nil
}
// the original header has already been read and is no longer
// available. read content and put a new header in front of it.
// The original header has already been read and is no longer
// available. Read content and put a new header in front of it.
start := headsize(size)
buf := make([]byte, uint64(start)+size)
if err := s.readFull(buf[start:]); err != nil {
Expand Down Expand Up @@ -730,7 +724,14 @@ func (s *Stream) List() (size uint64, err error) {
if kind != List {
return 0, ErrExpectedList
}
s.stack = append(s.stack, listpos{0, size})

// Remove size of inner list from outer list before pushing the new size
// onto the stack. This ensures that the remaining outer list size will
// be correct after the matching call to ListEnd.
if inList, limit := s.listLimit(); inList {
s.stack[len(s.stack)-1] = limit - size
}
s.stack = append(s.stack, size)
s.kind = -1
s.size = 0
return size, nil
Expand All @@ -739,17 +740,13 @@ func (s *Stream) List() (size uint64, err error) {
// ListEnd returns to the enclosing list.
// The input reader must be positioned at the end of a list.
func (s *Stream) ListEnd() error {
if len(s.stack) == 0 {
// Ensure that no more data is remaining in the current list.
if inList, listLimit := s.listLimit(); !inList {
return errNotInList
}
tos := s.stack[len(s.stack)-1]
if tos.pos != tos.size {
} else if listLimit > 0 {
return errNotAtEOL
}
s.stack = s.stack[:len(s.stack)-1] // pop
if len(s.stack) > 0 {
s.stack[len(s.stack)-1].pos += tos.size
}
s.kind = -1
s.size = 0
return nil
Expand Down Expand Up @@ -777,7 +774,7 @@ func (s *Stream) Decode(val interface{}) error {

err = decoder(s, rval.Elem())
if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 {
// add decode target type to error so context has more meaning
// Add decode target type to error so context has more meaning.
decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))
}
return err
Expand All @@ -800,6 +797,9 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
case *bytes.Reader:
s.remaining = uint64(br.Len())
s.limited = true
case *bytes.Buffer:
s.remaining = uint64(br.Len())
s.limited = true
case *strings.Reader:
s.remaining = uint64(br.Len())
s.limited = true
Expand All @@ -818,10 +818,8 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
s.size = 0
s.kind = -1
s.kinderr = nil
if s.uintbuf == nil {
s.uintbuf = make([]byte, 8)
}
s.byteval = 0
s.uintbuf = [8]byte{}
}

// Kind returns the kind and size of the next value in the
Expand All @@ -836,35 +834,29 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
// the value. Subsequent calls to Kind (until the value is decoded)
// will not advance the input reader and return cached information.
func (s *Stream) Kind() (kind Kind, size uint64, err error) {
var tos *listpos
if len(s.stack) > 0 {
tos = &s.stack[len(s.stack)-1]
}
if s.kind < 0 {
s.kinderr = nil
// Don't read further if we're at the end of the
// innermost list.
if tos != nil && tos.pos == tos.size {
return 0, 0, EOL
}
s.kind, s.size, s.kinderr = s.readKind()
if s.kinderr == nil {
if tos == nil {
// At toplevel, check that the value is smaller
// than the remaining input length.
if s.limited && s.size > s.remaining {
s.kinderr = ErrValueTooLarge
}
} else {
// Inside a list, check that the value doesn't overflow the list.
if s.size > tos.size-tos.pos {
s.kinderr = ErrElemTooLarge
}
}
if s.kind >= 0 {
return s.kind, s.size, s.kinderr
}

// Check for end of list. This needs to be done here because readKind
// checks against the list size, and would return the wrong error.
inList, listLimit := s.listLimit()
if inList && listLimit == 0 {
return 0, 0, EOL
}
// Read the actual size tag.
s.kind, s.size, s.kinderr = s.readKind()
if s.kinderr == nil {
// Check the data size of the value ahead against input limits. This
// is done here because many decoders require allocating an input
// buffer matching the value size. Checking it here protects those
// decoders from inputs declaring very large value size.
if inList && s.size > listLimit {
s.kinderr = ErrElemTooLarge
} else if s.limited && s.size > s.remaining {
s.kinderr = ErrValueTooLarge
}
}
// Note: this might return a sticky error generated
// by an earlier call to readKind.
return s.kind, s.size, s.kinderr
}

Expand All @@ -891,37 +883,35 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) {
s.byteval = b
return Byte, 0, nil
case b < 0xB8:
// Otherwise, if a string is 0-55 bytes long,
// the RLP encoding consists of a single byte with value 0x80 plus the
// length of the string followed by the string. The range of the first
// byte is thus [0x80, 0xB7].
// Otherwise, if a string is 0-55 bytes long, the RLP encoding consists
// of a single byte with value 0x80 plus the length of the string
// followed by the string. The range of the first byte is thus [0x80, 0xB7].
return String, uint64(b - 0x80), nil
case b < 0xC0:
// If a string is more than 55 bytes long, the
// RLP encoding consists of a single byte with value 0xB7 plus the length
// of the length of the string in binary form, followed by the length of
// the string, followed by the string. For example, a length-1024 string
// would be encoded as 0xB90400 followed by the string. The range of
// the first byte is thus [0xB8, 0xBF].
// If a string is more than 55 bytes long, the RLP encoding consists of a
// single byte with value 0xB7 plus the length of the length of the
// string in binary form, followed by the length of the string, followed
// by the string. For example, a length-1024 string would be encoded as
// 0xB90400 followed by the string. The range of the first byte is thus
// [0xB8, 0xBF].
size, err = s.readUint(b - 0xB7)
if err == nil && size < 56 {
err = ErrCanonSize
}
return String, size, err
case b < 0xF8:
// If the total payload of a list
// (i.e. the combined length of all its items) is 0-55 bytes long, the
// RLP encoding consists of a single byte with value 0xC0 plus the length
// of the list followed by the concatenation of the RLP encodings of the
// items. The range of the first byte is thus [0xC0, 0xF7].
// If the total payload of a list (i.e. the combined length of all its
// items) is 0-55 bytes long, the RLP encoding consists of a single byte
// with value 0xC0 plus the length of the list followed by the
// concatenation of the RLP encodings of the items. The range of the
// first byte is thus [0xC0, 0xF7].
return List, uint64(b - 0xC0), nil
default:
// If the total payload of a list is more than 55 bytes long,
// the RLP encoding consists of a single byte with value 0xF7
// plus the length of the length of the payload in binary
// form, followed by the length of the payload, followed by
// the concatenation of the RLP encodings of the items. The
// range of the first byte is thus [0xF8, 0xFF].
// If the total payload of a list is more than 55 bytes long, the RLP
// encoding consists of a single byte with value 0xF7 plus the length of
// the length of the payload in binary form, followed by the length of
// the payload, followed by the concatenation of the RLP encodings of
// the items. The range of the first byte is thus [0xF8, 0xFF].
size, err = s.readUint(b - 0xF7)
if err == nil && size < 56 {
err = ErrCanonSize
Expand All @@ -940,22 +930,20 @@ func (s *Stream) readUint(size byte) (uint64, error) {
return uint64(b), err
default:
start := int(8 - size)
for i := 0; i < start; i++ {
s.uintbuf[i] = 0
}
s.uintbuf = [8]byte{}
if err := s.readFull(s.uintbuf[start:]); err != nil {
return 0, err
}
if s.uintbuf[start] == 0 {
// Note: readUint is also used to decode integer
// values. The error needs to be adjusted to become
// ErrCanonInt in this case.
// Note: readUint is also used to decode integer values.
// The error needs to be adjusted to become ErrCanonInt in this case.
return 0, ErrCanonSize
}
return binary.BigEndian.Uint64(s.uintbuf), nil
return binary.BigEndian.Uint64(s.uintbuf[:]), nil
}
}

// readFull reads into buf from the underlying stream.
func (s *Stream) readFull(buf []byte) (err error) {
if err := s.willRead(uint64(len(buf))); err != nil {
return err
Expand All @@ -977,6 +965,7 @@ func (s *Stream) readFull(buf []byte) (err error) {
return err
}

// readByte reads a single byte from the underlying stream.
func (s *Stream) readByte() (byte, error) {
if err := s.willRead(1); err != nil {
return 0, err
Expand All @@ -988,16 +977,16 @@ func (s *Stream) readByte() (byte, error) {
return b, err
}

// willRead is called before any read from the underlying stream. It checks
// n against size limits, and updates the limits if n doesn't overflow them.
func (s *Stream) willRead(n uint64) error {
s.kind = -1 // rearm Kind

if len(s.stack) > 0 {
// check list overflow
tos := s.stack[len(s.stack)-1]
if n > tos.size-tos.pos {
if inList, limit := s.listLimit(); inList {
if n > limit {
return ErrElemTooLarge
}
s.stack[len(s.stack)-1].pos += n
s.stack[len(s.stack)-1] = limit - n
}
if s.limited {
if n > s.remaining {
Expand All @@ -1007,3 +996,11 @@ func (s *Stream) willRead(n uint64) error {
}
return nil
}

// listLimit returns the amount of data remaining in the innermost list.
func (s *Stream) listLimit() (inList bool, limit uint64) {
if len(s.stack) == 0 {
return false, 0
}
return true, s.stack[len(s.stack)-1]
}