From 19f781bf55927a4757406213c3faa4d4bd9c4945 Mon Sep 17 00:00:00 2001 From: Alfonso Acosta Date: Tue, 21 Nov 2023 18:35:35 +0100 Subject: [PATCH] Add LenReader to cover readers which don't provide Len() --- xdr3/decode.go | 97 ++++++++++++++++++++++++++++++++++----------- xdr3/decode_test.go | 37 ++++++++++++----- 2 files changed, 102 insertions(+), 32 deletions(-) diff --git a/xdr3/decode.go b/xdr3/decode.go index 6a58ec1..1c8f383 100644 --- a/xdr3/decode.go +++ b/xdr3/decode.go @@ -33,6 +33,28 @@ var errIODecode = "%s while decoding %d bytes" // DecodeDefaultMaxDepth is the default maximum decoding depth const DecodeDefaultMaxDepth = 200 +// DecodeOptions configures how Decoding is done. +type DecodeOptions struct { + // MaxDepth is the maximum decoding depth (i.e. maximum nesting of data structures). + // It prevents infinite recursions in cyclic datastructures and determines the maximum callstack growth. + // If set to 0, DecodeDefaultMaxDepth will be used. + MaxDepth uint + + // MaxInputLen sets the maximum input size. It is used by the decoder to sanity-check + // allocation sizes and avoid heap explosions from doctored inputs. + // + // If set to 0, the decoder will try to figure out the input size by checking whether + // the provided io.Reader implements Len() (e.g. strings.Reader, bytes.Reader and bytes.Buffer do). + // Otherwise, no sanity checks will be done. + MaxInputLen int +} + +// DefaultDecodeOptions are the default decoding options. +var DefaultDecodeOptions = DecodeOptions{ + MaxDepth: DecodeDefaultMaxDepth, + MaxInputLen: 0, +} + /* Unmarshal parses XDR-encoded data into the value pointed to by v reading from reader r and returning the total number of bytes read. An addressable pointer @@ -66,7 +88,7 @@ by v and performs a mapping of underlying XDR types to Go types as follows: Notes and Limitations: - Automatic unmarshalling of variable and fixed-length arrays of uint8s - requires a special struct tag `xdropaque:"false"` since byte slices + requires a special struct tag xdropaque:"false" since byte slices and byte arrays are assumed to be opaque data and byte is a Go alias for uint8 thus indistinguishable under reflection - Cyclic data structures are not supported and will result in ErrMaxDecodingDepth errors @@ -78,12 +100,15 @@ potential issues are unsupported Go types, attempting to decode a value which is too large to fit into a specified Go type, and exceeding max slice limitations. */ func Unmarshal(r io.Reader, v interface{}) (int, error) { - d := newDecoder(r) + d := NewDecoder(r) return d.Decode(v) } -// lenLeft tells you how many bytes are left to read. -// It is satisfied by io.Readers like bytes.Buffer, bytes.Reader +// UnmarshalWithOptions works like Unmarshal but accepts decoding options. +func UnmarshalWithOptions(r io.Reader, v interface{}, options DecodeOptions) (int, error) { + d := NewDecoderWithOptions(r, options) + return d.Decode(v) +} type lenLeft interface { Len() int @@ -95,7 +120,7 @@ type lenLeft interface { // used to get a new Decoder directly. // // Typically, Unmarshal should be used instead of manual decoding. A Decoder -// is exposed so it is possible to perform manual decoding should it be +// is exposed, so it is possible to perform manual decoding should it be // necessary in complex scenarios where automatic reflection-based decoding // won't work. type Decoder struct { @@ -103,14 +128,53 @@ type Decoder struct { scratchBuf [8]byte r io.Reader l lenLeft + maxDepth uint +} + +// readerLenWrapper wraps a reader an initial length and provides a Len() method indicating +// how much input is left +type readerLenWrapper struct { + inner io.Reader + readCount int + initialLen int +} + +func (l *readerLenWrapper) Len() int { + return l.initialLen - l.readCount +} + +func (l *readerLenWrapper) Read(p []byte) (int, error) { + n, err := l.inner.Read(p) + if n > 0 { + l.readCount += n + } + return n, err +} + +// NewDecoder returns a Decoder that can be used to manually decode XDR data +// from a provided reader. Typically, Unmarshal should be used instead of +// manually creating a Decoder. +func NewDecoder(r io.Reader) *Decoder { + return NewDecoderWithOptions(r, DefaultDecodeOptions) } -func newDecoder(r io.Reader) *Decoder { - d := &Decoder{r: r} +// NewDecoderWithOptions works like NewDecoder but allows supplying decoding options. +func NewDecoderWithOptions(r io.Reader, options DecodeOptions) *Decoder { + maxDepth := options.MaxDepth + if maxDepth < 1 { + maxDepth = DecodeDefaultMaxDepth + } if l, ok := r.(lenLeft); ok { - d.l = l + return &Decoder{r: r, l: l, maxDepth: maxDepth} + } + if options.MaxInputLen > 0 { + rlw := &readerLenWrapper{ + inner: r, + initialLen: options.MaxInputLen, + } + return &Decoder{r: rlw, l: rlw, maxDepth: maxDepth} } - return d + return &Decoder{r: r, l: nil, maxDepth: options.MaxDepth} } // DecodeInt treats the next 4 bytes as an XDR encoded integer and returns the @@ -753,7 +817,6 @@ func (d *Decoder) decodeMap(v reflect.Value, maxDepth uint) (int, error) { // Allocate storage for the underlying map if needed. vt := v.Type() if v.IsNil() { - // We assume that the map allocation won't exceed d.maxAllocSize v.Set(reflect.MakeMap(vt)) } @@ -1089,11 +1152,6 @@ func (d *Decoder) indirectIfPtr(v reflect.Value) (reflect.Value, error) { // data instead of a user-supplied reader. See the Unmarhsal documentation for // specifics. Decode(v) is equivalent to DecodeWithMaxDepth(v, DecodeDefaultMaxDepth) func (d *Decoder) Decode(v interface{}) (int, error) { - return d.DecodeWithMaxDepth(v, DecodeDefaultMaxDepth) -} - -// DecodeWithMaxDepth behaves like Decode, except an explicit maximum decoding depth is used -func (d *Decoder) DecodeWithMaxDepth(v interface{}, maxDepth uint) (int, error) { if v == nil { msg := "can't unmarshal to nil interface" return 0, unmarshalError("Unmarshal", ErrNilInterface, msg, nil, @@ -1114,7 +1172,7 @@ func (d *Decoder) DecodeWithMaxDepth(v interface{}, maxDepth uint) (int, error) return 0, err } - return d.decode(vv.Elem(), 0, maxDepth) + return d.decode(vv.Elem(), 0, d.maxDepth) } // InputLen returns the size left to read from the decoder's input if available @@ -1124,10 +1182,3 @@ func (d *Decoder) InputLen() (int, bool) { } return d.l.Len(), true } - -// NewDecoder returns a Decoder that can be used to manually decode XDR data -// from a provided reader. Typically, Unmarshal should be used instead of -// manually creating a Decoder. -func NewDecoder(r io.Reader) *Decoder { - return newDecoder(r) -} diff --git a/xdr3/decode_test.go b/xdr3/decode_test.go index 526b491..747ebf9 100644 --- a/xdr3/decode_test.go +++ b/xdr3/decode_test.go @@ -18,6 +18,7 @@ package xdr_test import ( "bytes" + "encoding/base64" "fmt" "math" "reflect" @@ -1143,20 +1144,20 @@ func TestDecodeMaxDepth(t *testing.T) { } bufCopy := buf - decoder := NewDecoder(&bufCopy) + decoder := NewDecoderWithOptions(&bufCopy, DecodeOptions{MaxDepth: 3}) var s structWithPointer - _, err = decoder.DecodeWithMaxDepth(&s, 3) + _, err = decoder.Decode(&s) if err != nil { t.Error("unexpected error") } bufCopy = buf - decoder = NewDecoder(&bufCopy) - _, err = decoder.DecodeWithMaxDepth(&s, 2) + decoder = NewDecoderWithOptions(&bufCopy, DecodeOptions{MaxDepth: 2}) + _, err = decoder.Decode(&s) assertError(t, "", err, &UnmarshalError{ErrorCode: ErrMaxDecodingDepth}) } -func TestDecodeMaxAllocationCheck(t *testing.T) { +func TestDecodeMaxAllocationCheck_ImplicitLenReader(t *testing.T) { var buf bytes.Buffer _, err := Marshal(&buf, "thisstringis23charslong") if err != nil { @@ -1164,12 +1165,30 @@ func TestDecodeMaxAllocationCheck(t *testing.T) { } // Reduce the buffer size so that the length of the buffer - // is shorter than the encoded string length + // is shorter than the encoded XDR length buf.Truncate(buf.Len() - 4) - bufCopy := buf - decoder := NewDecoder(&bufCopy) + decoder := NewDecoder(&buf) + var s string + _, err = decoder.Decode(&s) + assertError(t, "", err, &UnmarshalError{ErrorCode: ErrOverflow}) +} + +func TestDecodeMaxAllocationCheck_ExplicitLenReader(t *testing.T) { + var buf bytes.Buffer + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + _, err := Marshal(encoder, "thisstringis23charslong") + if err != nil { + t.Error("unexpected error") + } + + xdrLen := base64.StdEncoding.DecodedLen(buf.Len()) + // Reduce the buffer size so that the length of the buffer + // is shorter than the encoded XDR length + reducedLen := xdrLen - 4 + + decoder := NewDecoderWithOptions(&buf, DecodeOptions{MaxInputLen: reducedLen}) var s string - _, err = decoder.DecodeWithMaxDepth(&s, DecodeDefaultMaxDepth) + _, err = decoder.Decode(&s) assertError(t, "", err, &UnmarshalError{ErrorCode: ErrOverflow}) }