diff --git a/zstd/decoder.go b/zstd/decoder.go index f295951025..4d984c3b26 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -113,9 +113,6 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) { // Returns the number of bytes written and any error that occurred. // When the stream is done, io.EOF will be returned. func (d *Decoder) Read(p []byte) (int, error) { - if d.stream == nil { - return 0, ErrDecoderNilInput - } var n int for { if len(d.current.b) > 0 { @@ -167,18 +164,15 @@ func (d *Decoder) Reset(r io.Reader) error { if r == nil { d.current.err = ErrDecoderNilInput + if len(d.current.b) > 0 { + d.current.b = d.current.b[:0] + } d.current.flushed = true return nil } - if d.stream == nil { - d.stream = make(chan decodeStream, 1) - d.streamWg.Add(1) - go d.startStreamDecoder(d.stream) - } - - // If bytes buffer and < 1MB, do sync decoding anyway. - if bb, ok := r.(byter); ok && bb.Len() < 1<<20 { + // If bytes buffer and < 5MB, do sync decoding anyway. + if bb, ok := r.(byter); ok && bb.Len() < 5<<20 { bb2 := bb if debugDecoder { println("*bytes.Buffer detected, doing sync decode, len:", bb.Len()) @@ -202,6 +196,12 @@ func (d *Decoder) Reset(r io.Reader) error { return nil } + if d.stream == nil { + d.stream = make(chan decodeStream, 1) + d.streamWg.Add(1) + go d.startStreamDecoder(d.stream) + } + // Remove current block. d.current.decodeOutput = decodeOutput{} d.current.err = nil @@ -255,9 +255,6 @@ func (d *Decoder) drainOutput() { // The return value n is the number of bytes written. // Any error encountered during the write is also returned. func (d *Decoder) WriteTo(w io.Writer) (int64, error) { - if d.stream == nil { - return 0, ErrDecoderNilInput - } var n int64 for { if len(d.current.b) > 0 {