Skip to content

Commit

Permalink
zstd: Add stream encoding without goroutines (#505)
Browse files Browse the repository at this point in the history
* zstd: Add stream encoding without goroutines

Does not use goroutines when encoder concurrency is 1.

Fixes #264

Can probably be cleaned up a bit.

* Reduce allocs for concurrent buffers when not used.
  • Loading branch information
klauspost authored Feb 27, 2022
1 parent 308a751 commit 15b48b6
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 17 deletions.
8 changes: 6 additions & 2 deletions zstd/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ func (d *Decoder) Reset(r io.Reader) error {
// drainOutput will drain the output until errEndOfStream is sent.
func (d *Decoder) drainOutput() {
if d.current.cancel != nil {
println("cancelling current")
if debugDecoder {
println("cancelling current")
}
d.current.cancel()
d.current.cancel = nil
}
Expand Down Expand Up @@ -816,7 +818,9 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
do.err = ErrFrameSizeMismatch
hasErr = true
} else {
println("fcs ok", block.Last, fcs, decodedFrame)
if debugDecoder {
println("fcs ok", block.Last, fcs, decodedFrame)
}
}
}
output <- do
Expand Down
66 changes: 54 additions & 12 deletions zstd/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,25 @@ func (e *Encoder) Reset(w io.Writer) {
if cap(s.filling) == 0 {
s.filling = make([]byte, 0, e.o.blockSize)
}
if cap(s.current) == 0 {
s.current = make([]byte, 0, e.o.blockSize)
}
if cap(s.previous) == 0 {
s.previous = make([]byte, 0, e.o.blockSize)
if e.o.concurrent > 1 {
if cap(s.current) == 0 {
s.current = make([]byte, 0, e.o.blockSize)
}
if cap(s.previous) == 0 {
s.previous = make([]byte, 0, e.o.blockSize)
}
s.current = s.current[:0]
s.previous = s.previous[:0]
if s.writing == nil {
s.writing = &blockEnc{lowMem: e.o.lowMem}
s.writing.init()
}
s.writing.initNewEncode()
}
if s.encoder == nil {
s.encoder = e.o.encoder()
}
if s.writing == nil {
s.writing = &blockEnc{lowMem: e.o.lowMem}
s.writing.init()
}
s.writing.initNewEncode()
s.filling = s.filling[:0]
s.current = s.current[:0]
s.previous = s.previous[:0]
s.encoder.Reset(e.o.dict, false)
s.headerWritten = false
s.eofWritten = false
Expand Down Expand Up @@ -258,6 +260,46 @@ func (e *Encoder) nextBlock(final bool) error {
return s.err
}

// SYNC:
if e.o.concurrent == 1 {
src := s.filling
s.nInput += int64(len(s.filling))
if debugEncoder {
println("Adding sync block,", len(src), "bytes, final:", final)
}
enc := s.encoder
blk := enc.Block()
blk.reset(nil)
enc.Encode(blk, src)
blk.last = final
if final {
s.eofWritten = true
}

err := errIncompressible
// If we got the exact same number of literals as input,
// assume the literals cannot be compressed.
if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
}
switch err {
case errIncompressible:
if debugEncoder {
println("Storing incompressible block as raw")
}
blk.encodeRaw(src)
// In fast mode, we do not transfer offsets, so we don't have to deal with changing them.
case nil:
default:
s.err = err
return err
}
_, s.err = s.w.Write(blk.output)
s.nWritten += int64(len(blk.output))
s.filling = s.filling[:0]
return s.err
}

// Move blocks forward.
s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
s.nInput += int64(len(s.current))
Expand Down
1 change: 1 addition & 0 deletions zstd/encoder_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func WithEncoderCRC(b bool) EOption {
// WithEncoderConcurrency will set the concurrency,
// meaning the maximum number of encoders to run concurrently.
// The value supplied must be at least 1.
// For streams, setting a value of 1 will disable async compression.
// By default this will be set to GOMAXPROCS.
func WithEncoderConcurrency(n int) EOption {
return func(o *encoderOptions) error {
Expand Down
55 changes: 52 additions & 3 deletions zstd/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ func testEncoderRoundtrip(t *testing.T, file string, wantCRC []byte) {
for _, opt := range getEncOpts(1) {
t.Run(opt.name, func(t *testing.T) {
opt := opt
t.Parallel()
//t.Parallel()
f, err := os.Open(file)
if err != nil {
if os.IsNotExist(err) {
Expand Down Expand Up @@ -851,7 +851,7 @@ func TestEncoder_EncodeAllEmpty(t *testing.T) {
}

func TestEncoder_EncodeAllEnwik9(t *testing.T) {
if false || testing.Short() {
if testing.Short() {
t.SkipNow()
}
file := "testdata/enwik9.zst"
Expand All @@ -873,8 +873,11 @@ func TestEncoder_EncodeAllEnwik9(t *testing.T) {
}

start := time.Now()
var e Encoder
e, err := NewWriter(nil)
dst := e.EncodeAll(in, nil)
if err != nil {
t.Fatal(err)
}
t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
Expand All @@ -889,6 +892,52 @@ func TestEncoder_EncodeAllEnwik9(t *testing.T) {
t.Log("Encoded content matched")
}

// TestEncoder_EncoderStreamEnwik9 streams the decoded enwik9 corpus through an
// Encoder created with NewWriter and logs the compressed size and throughput.
//
// The test is skipped in -short mode and when testdata/enwik9.zst is absent.
// Any other failure to open or decode the fixture, or to create, write to or
// close the encoder, fails the test instead of being ignored.
func TestEncoder_EncoderStreamEnwik9(t *testing.T) {
	if testing.Short() {
		t.SkipNow()
	}
	file := "testdata/enwik9.zst"
	f, err := os.Open(file)
	if err != nil {
		if os.IsNotExist(err) {
			t.Skip("To run extended tests, download http://mattmahoney.net/dc/enwik9.zip unzip it \n" +
				"compress it with 'zstd -15 -T0 enwik9' and place it in " + file)
		}
		// Any other open error (permissions, I/O) must not fall through
		// with a nil file handle.
		t.Fatal(err)
	}
	defer f.Close()

	dec, err := NewReader(f)
	if err != nil {
		t.Fatal(err)
	}
	defer dec.Close()
	in, err := ioutil.ReadAll(dec)
	if err != nil {
		t.Fatal(err)
	}

	start := time.Now()
	var dst bytes.Buffer
	e, err := NewWriter(&dst)
	if err != nil {
		t.Fatal(err)
	}
	if _, err = io.Copy(e, bytes.NewBuffer(in)); err != nil {
		t.Fatal(err)
	}
	// Close is called before reading dst.Len(); a failure here is fatal so
	// a bad final write cannot go unnoticed.
	if err = e.Close(); err != nil {
		t.Fatal(err)
	}
	t.Log("Full Encoder len", len(in), "-> zstd len", dst.Len())
	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)

	// NOTE(review): round-trip verification is deliberately compiled out here;
	// the reason is not visible in this file section - confirm before enabling.
	if false {
		decoded, err := dec.DecodeAll(dst.Bytes(), nil)
		if err != nil {
			t.Error(err, len(decoded))
		}
		if !bytes.Equal(decoded, in) {
			ioutil.WriteFile("testdata/"+t.Name()+"-enwik9.got", decoded, os.ModePerm)
			t.Fatal("Decoded does not match")
		}
		t.Log("Encoded content matched")
	}
}

func BenchmarkEncoder_EncodeAllXML(b *testing.B) {
f, err := os.Open("testdata/xml.zst")
if err != nil {
Expand Down

0 comments on commit 15b48b6

Please sign in to comment.