diff --git a/zstd/blockdec.go b/zstd/blockdec.go index e30af505ca..8a98c4562e 100644 --- a/zstd/blockdec.go +++ b/zstd/blockdec.go @@ -168,10 +168,10 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { // Read block data. if cap(b.dataStorage) < cSize { - if b.lowMem { + if b.lowMem || cSize > maxCompressedBlockSize { b.dataStorage = make([]byte, 0, cSize) } else { - b.dataStorage = make([]byte, 0, maxBlockSize) + b.dataStorage = make([]byte, 0, maxCompressedBlockSize) } } if cap(b.dst) <= maxSize { diff --git a/zstd/decoder_options.go b/zstd/decoder_options.go index c0fd058c28..95cc9b8b81 100644 --- a/zstd/decoder_options.go +++ b/zstd/decoder_options.go @@ -17,14 +17,16 @@ type decoderOptions struct { lowMem bool concurrent int maxDecodedSize uint64 + maxWindowSize uint64 dicts []dict } func (o *decoderOptions) setDefault() { *o = decoderOptions{ // use less ram: true for now, but may change. - lowMem: true, - concurrent: runtime.GOMAXPROCS(0), + lowMem: true, + concurrent: runtime.GOMAXPROCS(0), + maxWindowSize: MaxWindowSize, } o.maxDecodedSize = 1 << 63 } @@ -52,7 +54,6 @@ func WithDecoderConcurrency(n int) DOption { // WithDecoderMaxMemory allows to set a maximum decoded size for in-memory // non-streaming operations or maximum window size for streaming operations. // This can be used to control memory usage of potentially hostile content. -// For streaming operations, the maximum window size is capped at 1<<30 bytes. // Maximum and default is 1 << 63 bytes. func WithDecoderMaxMemory(n uint64) DOption { return func(o *decoderOptions) error { @@ -81,3 +82,21 @@ func WithDecoderDicts(dicts ...[]byte) DOption { return nil } } + +// WithDecoderMaxWindow allows to set a maximum window size for decodes. +// This allows rejecting packets that will cause big memory usage. +// The Decoder will likely allocate more memory based on the WithDecoderLowmem setting. +// If WithDecoderMaxMemory is set to a lower value, that will be used. +// Default is 512MB, Maximum is ~3.75 TB as per zstandard spec. +func WithDecoderMaxWindow(size uint64) DOption { + return func(o *decoderOptions) error { + if size < MinWindowSize { + return errors.New("WithMaxWindowSize must be at least 1KB, 1024 bytes") + } + if size > (1<<41)+7*(1<<38) { + return errors.New("WithMaxWindowSize must be less than (1<<41) + 7*(1<<38) ~ 3.75TB") + } + o.maxWindowSize = size + return nil + } +} diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index fcc5dd98a8..7a6807159b 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -178,7 +178,7 @@ func TestNewDecoder(t *testing.T) { func TestNewDecoderMemory(t *testing.T) { defer timeout(60 * time.Second)() var testdata bytes.Buffer - enc, err := NewWriter(&testdata, WithWindowSize(64<<10), WithSingleSegment(false)) + enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false)) if err != nil { t.Fatal(err) } @@ -200,6 +200,9 @@ func TestNewDecoderMemory(t *testing.T) { n = 200 } + // 16K buffer + var tmp [16 << 10]byte + var before, after runtime.MemStats runtime.GC() runtime.ReadMemStats(&before) @@ -214,8 +217,6 @@ func TestNewDecoderMemory(t *testing.T) { } } - // 32K buffer - var tmp [128 << 10]byte for i := range decs { _, err := io.ReadFull(decs[i], tmp[:]) if err != nil { @@ -226,10 +227,12 @@ func TestNewDecoderMemory(t *testing.T) { runtime.GC() runtime.ReadMemStats(&after) size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024 + + const expect = 124 t.Log(size, "KiB per decoder") // This is not exact science, but fail if we suddenly get more than 2x what we expect. - if size > 221*2 && !testing.Short() { - t.Errorf("expected < 221KB per decoder, got %d", size) + if size > expect*2 && !testing.Short() { + t.Errorf("expected < %dKB per decoder, got %d", expect, size) } for _, dec := range decs { @@ -237,6 +240,115 @@ func TestNewDecoderMemory(t *testing.T) { } } +func TestNewDecoderMemoryHighMem(t *testing.T) { + defer timeout(60 * time.Second)() + var testdata bytes.Buffer + enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false)) + if err != nil { + t.Fatal(err) + } + // Write 256KB + for i := 0; i < 256; i++ { + tmp := strings.Repeat(string([]byte{byte(i)}), 1024) + _, err := enc.Write([]byte(tmp)) + if err != nil { + t.Fatal(err) + } + } + err = enc.Close() + if err != nil { + t.Fatal(err) + } + + var n = 50 + if testing.Short() { + n = 10 + } + + // 16K buffer + var tmp [16 << 10]byte + + var before, after runtime.MemStats + runtime.GC() + runtime.ReadMemStats(&before) + + var decs = make([]*Decoder, n) + for i := range decs { + // Wrap in NopCloser to avoid shortcut. + input := ioutil.NopCloser(bytes.NewBuffer(testdata.Bytes())) + decs[i], err = NewReader(input, WithDecoderConcurrency(1), WithDecoderLowmem(false)) + if err != nil { + t.Fatal(err) + } + } + + for i := range decs { + _, err := io.ReadFull(decs[i], tmp[:]) + if err != nil { + t.Fatal(err) + } + } + + runtime.GC() + runtime.ReadMemStats(&after) + size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024 + + const expect = 3915 + t.Log(size, "KiB per decoder") + // This is not exact science, but fail if we suddenly get more than 2x what we expect. + if size > expect*2 && !testing.Short() { + t.Errorf("expected < %dKB per decoder, got %d", expect, size) + } + + for _, dec := range decs { + dec.Close() + } +} + +func TestNewDecoderFrameSize(t *testing.T) { + defer timeout(60 * time.Second)() + var testdata bytes.Buffer + enc, err := NewWriter(&testdata, WithWindowSize(64<<10)) + if err != nil { + t.Fatal(err) + } + // Write 256KB + for i := 0; i < 256; i++ { + tmp := strings.Repeat(string([]byte{byte(i)}), 1024) + _, err := enc.Write([]byte(tmp)) + if err != nil { + t.Fatal(err) + } + } + err = enc.Close() + if err != nil { + t.Fatal(err) + } + // Must fail + dec, err := NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(32<<10)) + if err != nil { + t.Fatal(err) + } + _, err = io.Copy(ioutil.Discard, dec) + if err == nil { + dec.Close() + t.Fatal("Wanted error, got none") + } + dec.Close() + + // Must succeed. + dec, err = NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(64<<10)) + if err != nil { + t.Fatal(err) + } + _, err = io.Copy(ioutil.Discard, dec) + if err != nil { + dec.Close() + t.Fatalf("Wanted no error, got %+v", err) + } + dec.Close() +} + func TestNewDecoderGood(t *testing.T) { defer timeout(30 * time.Second)() testDecoderFile(t, "testdata/good.zip") diff --git a/zstd/framedec.go b/zstd/framedec.go index e8cc9a2c22..989c79f8c3 100644 --- a/zstd/framedec.go +++ b/zstd/framedec.go @@ -22,10 +22,6 @@ type frameDec struct { WindowSize uint64 - // maxWindowSize is the maximum windows size to support. - // should never be bigger than max-int. - maxWindowSize uint64 - // In order queue of blocks being decoded. decoding chan *blockDec @@ -50,8 +46,11 @@ type frameDec struct { } const ( - // The minimum Window_Size is 1 KB. + // MinWindowSize is the minimum Window Size, which is 1 KB. MinWindowSize = 1 << 10 + + // MaxWindowSize is the maximum encoder window size + // and the default decoder maximum window size. MaxWindowSize = 1 << 29 ) @@ -61,12 +60,11 @@ var ( ) func newFrameDec(o decoderOptions) *frameDec { - d := frameDec{ - o: o, - maxWindowSize: MaxWindowSize, + if o.maxWindowSize > o.maxDecodedSize { + o.maxWindowSize = o.maxDecodedSize } - if d.maxWindowSize > o.maxDecodedSize { - d.maxWindowSize = o.maxDecodedSize + d := frameDec{ + o: o, } return &d } @@ -251,13 +249,17 @@ func (d *frameDec) reset(br byteBuffer) error { } } - if d.WindowSize > d.maxWindowSize { - printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize) + if d.WindowSize > uint64(d.o.maxWindowSize) { + if debugDecoder { + printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize) + } return ErrWindowSizeExceeded } // The minimum Window_Size is 1 KB. if d.WindowSize < MinWindowSize { - println("got window size: ", d.WindowSize) + if debugDecoder { + println("got window size: ", d.WindowSize) + } return ErrWindowSizeTooSmall } d.history.windowSize = int(d.WindowSize) @@ -352,8 +354,8 @@ func (d *frameDec) checkCRC() error { func (d *frameDec) initAsync() { if !d.o.lowMem && !d.SingleSegment { - // set max extra size history to 10MB. - d.history.maxSize = d.history.windowSize + maxBlockSize*5 + // set max extra size history to 2MB. + d.history.maxSize = d.history.windowSize + maxBlockSize } // re-alloc if more than one extra block size. if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {