Skip to content

Commit

Permalink
Avoid using zstd.Encoder.EncodeAll, to reduce memory usage
Browse files Browse the repository at this point in the history
Followup to #25.
  • Loading branch information
mostynb committed Dec 19, 2024
1 parent 9adaaed commit ce6ebe9
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions internal/zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package zstd

import (
"bytes"
"errors"
"io"
"runtime"
Expand Down Expand Up @@ -48,21 +47,22 @@ type decoderWrapper struct {
*zstd.Decoder
}

type encoderWrapper struct {
*zstd.Encoder
pool *sync.Pool
}

type compressor struct {
encoder *zstd.Encoder
decoderPool sync.Pool // To hold *zstd.Decoder's.
encoderPool sync.Pool
decoderPool sync.Pool
}

func PretendInit(clobbering bool) {
if !clobbering && encoding.GetCompressor(Name) != nil {
return
}

enc, _ := zstd.NewWriter(nil, encoderOptions...)
c := &compressor{
encoder: enc,
}
encoding.RegisterCompressor(c)
encoding.RegisterCompressor(&compressor{})
}

var ErrNotInUse = errors.New("SetLevel ineffective because another zstd compressor has been registered")
Expand All @@ -71,40 +71,42 @@ var ErrNotInUse = errors.New("SetLevel ineffective because another zstd compress
// level. NOTE: this function must only be called from an init function, and
// is not threadsafe.
func SetLevel(level zstd.EncoderLevel) error {
c, ok := encoding.GetCompressor(Name).(*compressor)
_, ok := encoding.GetCompressor(Name).(*compressor)
if !ok {
return ErrNotInUse
}

enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
if err != nil {
return err
}

c.encoder = enc
encoderOptions = append(encoderOptions, zstd.WithEncoderLevel(level))
return nil
}

func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
return &zstdWriteCloser{
enc: c.encoder,
writer: w,
}, nil
}
var err error
var found bool
var encoder *zstd.Encoder

type zstdWriteCloser struct {
enc *zstd.Encoder
writer io.Writer // Compressed data will be written here.
buf bytes.Buffer // Buffer uncompressed data here, compress on Close.
}
encoder, found = c.encoderPool.Get().(*zstd.Encoder)
if !found {
encoder, err = zstd.NewWriter(w, encoderOptions...)
if err != nil {
return nil, err
}
} else {
encoder.Reset(w)
}

wrapper := &encoderWrapper{Encoder: encoder, pool: &c.encoderPool}
runtime.SetFinalizer(wrapper, func(ew *encoderWrapper) {
ew.Reset(nil)
c.encoderPool.Put(ew.Encoder)
})

func (z *zstdWriteCloser) Write(p []byte) (int, error) {
return z.buf.Write(p)
return wrapper, nil
}

func (z *zstdWriteCloser) Close() error {
compressed := z.enc.EncodeAll(z.buf.Bytes(), nil)
_, err := io.Copy(z.writer, bytes.NewReader(compressed))
func (w *encoderWrapper) Close() error {
err := w.Encoder.Close()
w.pool.Put(w.Encoder)
return err
}

Expand Down

0 comments on commit ce6ebe9

Please sign in to comment.