diff --git a/s2/decode.go b/s2/decode.go index 039ddb911b..605b85dd91 100644 --- a/s2/decode.go +++ b/s2/decode.go @@ -77,10 +77,38 @@ func Decode(dst, src []byte) ([]byte, error) { // NewReader returns a new Reader that decompresses from r, using the framing // format described at // https://github.com/google/snappy/blob/master/framing_format.txt with S2 changes. -func NewReader(r io.Reader) *Reader { - return &Reader{ - r: r, - buf: make([]byte, MaxEncodedLen(maxBlockSize)+checksumSize), +func NewReader(r io.Reader, opts ...ReaderOption) *Reader { + nr := Reader{ + r: r, + maxBlock: maxBlockSize, + } + for _, opt := range opts { + if err := opt(&nr); err != nil { + nr.err = err + return &nr + } + } + nr.buf = make([]byte, MaxEncodedLen(nr.maxBlock)+checksumSize) + nr.paramsOK = true + return &nr +} + +// ReaderOption is an option for creating a decoder. +type ReaderOption func(*Reader) error + +// ReaderMaxBlockSize allows to control allocations if the stream +// has been compressed with a smaller WriterBlockSize, or with the default 1MB. +// Blocks must be this size or smaller to decompress, +// otherwise the decoder will return ErrUnsupported. +// +// Default is the maximum limit of 4MB. +func ReaderMaxBlockSize(n int) ReaderOption { + return func(r *Reader) error { + if n > maxBlockSize || n <= 0 { + return errors.New("s2: block size too large. Must be <= 4MB and > 0") + } + r.maxBlock = n + return nil } } @@ -92,13 +120,18 @@ type Reader struct { buf []byte // decoded[i:j] contains decoded bytes that have not yet been passed on. i, j int + maxBlock int readHeader bool + paramsOK bool } // Reset discards any buffered data, resets all state, and switches the Snappy // reader to read from r. This permits reusing a Reader rather than allocating // a new one. func (r *Reader) Reset(reader io.Reader) { + if !r.paramsOK { + return + } r.r = reader r.err = nil r.i = 0 @@ -116,6 +149,36 @@ func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) { return true } +// skipN will skip n bytes. +// If the supplied reader supports seeking that is used. +// tmp is used as a temporary buffer for reading. +// The supplied slice does not need to be the size of the read. +func (r *Reader) skipN(tmp []byte, n int, allowEOF bool) (ok bool) { + if rs, ok := r.r.(io.ReadSeeker); ok { + _, err := rs.Seek(int64(n), io.SeekCurrent) + if err == nil { + return true + } + if err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) { + r.err = ErrCorrupt + return false + } + } + for n > 0 { + if n < len(tmp) { + tmp = tmp[:n] + } + if _, r.err = io.ReadFull(r.r, tmp); r.err != nil { + if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) { + r.err = ErrCorrupt + } + return false + } + n -= len(tmp) + } + return true +} + // Read satisfies the io.Reader interface. func (r *Reader) Read(p []byte) (int, error) { if r.err != nil { @@ -139,10 +202,6 @@ func (r *Reader) Read(p []byte) (int, error) { r.readHeader = true } chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16 - if chunkLen > len(r.buf) { - r.err = ErrUnsupported - return 0, r.err - } // The chunk types are specified at // https://github.com/google/snappy/blob/master/framing_format.txt @@ -153,6 +212,10 @@ func (r *Reader) Read(p []byte) (int, error) { r.err = ErrCorrupt return 0, r.err } + if chunkLen > len(r.buf) { + r.err = ErrUnsupported + return 0, r.err + } buf := r.buf[:chunkLen] if !r.readFull(buf, false) { return 0, r.err @@ -166,7 +229,7 @@ func (r *Reader) Read(p []byte) (int, error) { return 0, r.err } if n > len(r.decoded) { - if n > maxBlockSize { + if n > r.maxBlock { r.err = ErrCorrupt return 0, r.err } @@ -189,6 +252,10 @@ func (r *Reader) Read(p []byte) (int, error) { r.err = ErrCorrupt return 0, r.err } + if chunkLen > len(r.buf) { + r.err = ErrUnsupported + return 0, r.err + } buf := r.buf[:checksumSize] if !r.readFull(buf, false) { return 0, r.err @@ -197,7 +264,7 @@ func (r *Reader) Read(p []byte) (int, error) { // Read directly into r.decoded instead of via r.buf. n := chunkLen - checksumSize if n > len(r.decoded) { - if n > maxBlockSize { + if n > r.maxBlock { r.err = ErrCorrupt return 0, r.err } @@ -238,7 +305,12 @@ func (r *Reader) Read(p []byte) (int, error) { } // Section 4.4 Padding (chunk type 0xfe). // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd). - if !r.readFull(r.buf[:chunkLen], false) { + if chunkLen > maxBlockSize { + r.err = ErrUnsupported + return 0, r.err + } + + if !r.skipN(r.buf, chunkLen, false) { return 0, r.err } } @@ -286,10 +358,6 @@ func (r *Reader) Skip(n int64) error { r.readHeader = true } chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16 - if chunkLen > len(r.buf) { - r.err = ErrUnsupported - return r.err - } // The chunk types are specified at // https://github.com/google/snappy/blob/master/framing_format.txt @@ -300,6 +368,10 @@ func (r *Reader) Skip(n int64) error { r.err = ErrCorrupt return r.err } + if chunkLen > len(r.buf) { + r.err = ErrUnsupported + return r.err + } buf := r.buf[:chunkLen] if !r.readFull(buf, false) { return r.err @@ -312,7 +384,7 @@ func (r *Reader) Skip(n int64) error { r.err = err return r.err } - if dLen > maxBlockSize { + if dLen > r.maxBlock { r.err = ErrCorrupt return r.err } @@ -342,6 +414,10 @@ func (r *Reader) Skip(n int64) error { r.err = ErrCorrupt return r.err } + if chunkLen > len(r.buf) { + r.err = ErrUnsupported + return r.err + } buf := r.buf[:checksumSize] if !r.readFull(buf, false) { return r.err @@ -350,7 +426,7 @@ func (r *Reader) Skip(n int64) error { // Read directly into r.decoded instead of via r.buf. n2 := chunkLen - checksumSize if n2 > len(r.decoded) { - if n2 > maxBlockSize { + if n2 > r.maxBlock { r.err = ErrCorrupt return r.err } @@ -391,13 +467,15 @@ func (r *Reader) Skip(n int64) error { r.err = ErrUnsupported return r.err } + if chunkLen > maxBlockSize { + r.err = ErrUnsupported + return r.err + } // Section 4.4 Padding (chunk type 0xfe). // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd). - if !r.readFull(r.buf[:chunkLen], false) { + if !r.skipN(r.buf, chunkLen, false) { return r.err } - - return io.ErrUnexpectedEOF } return nil } diff --git a/s2/decode_test.go b/s2/decode_test.go index 0a6a8125d4..bc22c6baf0 100644 --- a/s2/decode_test.go +++ b/s2/decode_test.go @@ -6,6 +6,7 @@ package s2 import ( "bytes" + "fmt" "io/ioutil" "strings" "testing" @@ -41,3 +42,120 @@ func TestDecodeRegression(t *testing.T) { }) } } + +func TestDecoderMaxBlockSize(t *testing.T) { + data, err := ioutil.ReadFile("testdata/enc_regressions.zip") + if err != nil { + t.Fatal(err) + } + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + t.Fatal(err) + } + sizes := []int{4 << 10, 10 << 10, 1 << 20, 4 << 20} + test := func(t *testing.T, data []byte) { + for _, size := range sizes { + t.Run(fmt.Sprintf("%d", size), func(t *testing.T) { + var buf bytes.Buffer + dec := NewReader(nil, ReaderMaxBlockSize(size)) + enc := NewWriter(&buf, WriterBlockSize(size), WriterPadding(16<<10), WriterPaddingSrc(zeroReader{})) + + // Test writer. + n, err := enc.Write(data) + if err != nil { + t.Error(err) + return + } + if n != len(data) { + t.Error(fmt.Errorf("Write: Short write, want %d, got %d", len(data), n)) + return + } + err = enc.Close() + if err != nil { + t.Error(err) + return + } + // Calling close twice should not affect anything. + err = enc.Close() + if err != nil { + t.Error(err) + return + } + + dec.Reset(&buf) + got, err := ioutil.ReadAll(dec) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(data, got) { + t.Error("block (reset) decoder mismatch") + return + } + + // Test Reset on both and use ReadFrom instead. + buf.Reset() + enc.Reset(&buf) + n2, err := enc.ReadFrom(bytes.NewBuffer(data)) + if err != nil { + t.Error(err) + return + } + if n2 != int64(len(data)) { + t.Error(fmt.Errorf("ReadFrom: Short read, want %d, got %d", len(data), n2)) + return + } + // Encode twice... + n2, err = enc.ReadFrom(bytes.NewBuffer(data)) + if err != nil { + t.Error(err) + return + } + if n2 != int64(len(data)) { + t.Error(fmt.Errorf("ReadFrom: Short read, want %d, got %d", len(data), n2)) + return + } + + err = enc.Close() + if err != nil { + t.Error(err) + return + } + if enc.pad > 0 && buf.Len()%enc.pad != 0 { + t.Error(fmt.Errorf("wanted size to be mutiple of %d, got size %d with remainder %d", enc.pad, buf.Len(), buf.Len()%enc.pad)) + return + } + dec.Reset(&buf) + // Skip first... + dec.Skip(int64(len(data))) + got, err = ioutil.ReadAll(dec) + if err != nil { + t.Error(err) + return + } + if !bytes.Equal(data, got) { + t.Error("frame (reset) decoder mismatch") + return + } + }) + } + } + for _, tt := range zr.File { + if !strings.HasSuffix(t.Name(), "") { + continue + } + t.Run(tt.Name, func(t *testing.T) { + r, err := tt.Open() + if err != nil { + t.Error(err) + return + } + b, err := ioutil.ReadAll(r) + if err != nil { + t.Error(err) + return + } + test(t, b[:len(b):len(b)]) + }) + } +} diff --git a/s2/encode.go b/s2/encode.go index cdb3ab45b2..9dc1ece2ca 100644 --- a/s2/encode.go +++ b/s2/encode.go @@ -907,7 +907,7 @@ func WriterUncompressed() WriterOption { // Default block size is 1MB. func WriterBlockSize(n int) WriterOption { return func(w *Writer) error { - if w.blockSize > maxBlockSize || w.blockSize < minBlockSize { + if n > maxBlockSize || n < minBlockSize { return errors.New("s2: block size too large. Must be <= 4MB and >=4KB") } w.blockSize = n diff --git a/s2/encode_test.go b/s2/encode_test.go index b4490b0411..4c15bf4d69 100644 --- a/s2/encode_test.go +++ b/s2/encode_test.go @@ -39,7 +39,7 @@ func testOptions(t testing.TB) map[string][]WriterOption { for name, opt := range testOptions { x[name] = opt if !testing.Short() { - x[name+"-1k-win"] = cloneAdd(opt, WriterBlockSize(1<<10)) + x[name+"-4k-win"] = cloneAdd(opt, WriterBlockSize(4<<10)) x[name+"-4M-win"] = cloneAdd(opt, WriterBlockSize(4<<20)) } } @@ -79,8 +79,9 @@ func TestEncoderRegression(t *testing.T) { test := func(t *testing.T, data []byte) { for name, opts := range testOptions(t) { t.Run(name, func(t *testing.T) { + var buf bytes.Buffer dec := NewReader(nil) - enc := NewWriter(nil, opts...) + enc := NewWriter(&buf, opts...) comp := Encode(make([]byte, MaxEncodedLen(len(data))), data) decoded, err := Decode(nil, comp) @@ -112,8 +113,6 @@ func TestEncoderRegression(t *testing.T) { } // Test writer. - var buf bytes.Buffer - enc.Reset(&buf) n, err := enc.Write(data) if err != nil { t.Error(err) @@ -151,10 +150,9 @@ func TestEncoderRegression(t *testing.T) { } // Test Reset on both and use ReadFrom instead. - input := bytes.NewBuffer(data) - buf = bytes.Buffer{} + buf.Reset() enc.Reset(&buf) - n2, err := enc.ReadFrom(input) + n2, err := enc.ReadFrom(bytes.NewBuffer(data)) if err != nil { t.Error(err) return