Skip to content

Commit

Permalink
[feat] - additional buffer pool (#2829)
Browse files Browse the repository at this point in the history
* move buffer pool logic into own pkg

* move

* fix test

* fix test

* fix test

* remove

* fix test

* whoops

* revert
  • Loading branch information
ahrav authored May 16, 2024
1 parent ac3de97 commit 5e3d660
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 33 deletions.
14 changes: 4 additions & 10 deletions pkg/buffers/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,19 @@ func (poolMetrics) recordBufferReturn(buf *buffer.Buffer) {
buf.RecordMetric()
}

// Opts is a function that configures a BufferPool.
type Opts func(pool *Pool)

// Pool of buffers.
type Pool struct {
*sync.Pool
bufferSize uint32
bufferSize int

metrics poolMetrics
}

const defaultBufferSize = 1 << 12 // 4KB
// NewBufferPool creates a new Pool of reusable buffers.
func NewBufferPool(opts ...Opts) *Pool {
pool := &Pool{bufferSize: defaultBufferSize}
func NewBufferPool(size int) *Pool {
pool := &Pool{bufferSize: size}

for _, opt := range opts {
opt(pool)
}
pool.Pool = &sync.Pool{
New: func() any {
return &buffer.Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, pool.bufferSize))}
Expand Down Expand Up @@ -72,7 +66,7 @@ func (p *Pool) Put(buf *buffer.Buffer) {
// If the Buffer is more than twice the default size, replace it with a new Buffer.
// This prevents us from returning very large buffers to the pool.
const maxAllowedCapacity = 2 * defaultBufferSize
if buf.Cap() > maxAllowedCapacity {
if buf.Cap() > int(maxAllowedCapacity) {
p.metrics.recordShrink(buf.Cap() - defaultBufferSize)
buf = &buffer.Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, p.bufferSize))}
} else {
Expand Down
20 changes: 10 additions & 10 deletions pkg/buffers/pool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ func TestNewBufferPool(t *testing.T) {
t.Parallel()
tests := []struct {
name string
opts []Opts
expectedBuffSize uint32
size int
expectedBuffSize int
}{
{name: "Default pool size", expectedBuffSize: defaultBufferSize},
{name: "Default pool size", size: defaultBufferSize, expectedBuffSize: defaultBufferSize},
{
name: "Custom pool size",
opts: []Opts{func(p *Pool) { p.bufferSize = 8 * 1024 }}, // 8KB
size: 8 * 1024,
expectedBuffSize: 8 * 1024,
},
}
Expand All @@ -28,7 +28,7 @@ func TestNewBufferPool(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
pool := NewBufferPool(tc.opts...)
pool := NewBufferPool(tc.size)
assert.Equal(t, tc.expectedBuffSize, pool.bufferSize)
})
}
Expand All @@ -47,25 +47,25 @@ func TestBufferPoolGetPut(t *testing.T) {
preparePool: func(_ *Pool) *buffer.Buffer {
return nil // No initial buffer to put
},
expectedCapBefore: int(defaultBufferSize),
expectedCapAfter: int(defaultBufferSize),
expectedCapBefore: defaultBufferSize,
expectedCapAfter: defaultBufferSize,
},
{
name: "Put oversized buffer, expect shrink",
preparePool: func(p *Pool) *buffer.Buffer {
buf := &buffer.Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, 3*defaultBufferSize))}
return buf
},
expectedCapBefore: int(defaultBufferSize),
expectedCapAfter: int(defaultBufferSize), // Should shrink back to default
expectedCapBefore: defaultBufferSize,
expectedCapAfter: defaultBufferSize, // Should shrink back to default
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
pool := NewBufferPool()
pool := NewBufferPool(defaultBufferSize)
initialBuf := tc.preparePool(pool)
if initialBuf != nil {
pool.Put(initialBuf)
Expand Down
8 changes: 5 additions & 3 deletions pkg/writers/buffer_writer/bufferwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ func (metrics) recordDataProcessed(size int64, dur time.Duration) {
totalWriteDuration.Add(float64(dur.Microseconds()))
}

func init() { bufferPool = pool.NewBufferPool() }
const defaultBufferSize = 1 << 12 // 4KB
func init() { bufferPool = pool.NewBufferPool(defaultBufferSize) }

// bufferPool is the shared Buffer pool used by all BufferWriters.
// This allows for efficient reuse of buffers across multiple writers.
Expand All @@ -44,14 +45,15 @@ type BufferWriter struct {
}

// New creates a new instance of BufferWriter.
func New() *BufferWriter { return &BufferWriter{state: writeOnly, bufPool: bufferPool} }
func New() *BufferWriter {
return &BufferWriter{state: writeOnly, bufPool: bufferPool}
}

// Write delegates the writing operation to the underlying bytes.Buffer.
func (b *BufferWriter) Write(data []byte) (int, error) {
if b.state != writeOnly {
return 0, fmt.Errorf("buffer must be in write-only mode to write data; current state: %d", b.state)
}

if b.buf == nil {
b.buf = b.bufPool.Get()
if b.buf == nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/writers/buffer_writer/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var (
Namespace: common.MetricsNamespace,
Subsystem: common.MetricsSubsystem,
Name: "buffer_writer_write_size_bytes",
Help: "Size of data written by the BufferWriter in bytes.",
Help: "Total size of data written by the BufferWriter in bytes.",
Buckets: prometheus.ExponentialBuckets(100, 10, 7),
})

Expand Down
60 changes: 52 additions & 8 deletions pkg/writers/buffered_file_writer/bufferedfilewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
)

// sharedBufferPool is the shared buffer pool used by all BufferedFileWriters.
// This allows for efficient reuse of buffers across multiple writers.
var sharedBufferPool *pool.Pool

func init() { sharedBufferPool = pool.NewBufferPool() }

type bufferedFileWriterMetrics struct{}

func (bufferedFileWriterMetrics) recordDataProcessed(size uint64, dur time.Duration) {
Expand All @@ -33,6 +27,30 @@ func (bufferedFileWriterMetrics) recordDiskWrite(size int64) {
fileSizeHistogram.Observe(float64(size))
}

// PoolSize identifies which of the package's shared buffer pools a
// BufferedFileWriter should draw its buffer from (see WithBufferSize).
type PoolSize int

const (
// Default selects the pool created with defaultBufferSize (4KB).
Default PoolSize = iota
// Large selects the pool created with largeBufferSize (64KB).
Large
)

const (
defaultBufferSize = 1 << 12 // 4KB
largeBufferSize = 1 << 16 // 64KB
)

// init creates the package-level buffer pools, one per supported buffer size,
// so they are available before any BufferedFileWriter is constructed.
func init() {
defaultBufferPool = pool.NewBufferPool(defaultBufferSize)
largeBufferPool = pool.NewBufferPool(largeBufferSize)
}

// Different buffer pools for different buffer sizes.
// This allows for more efficient memory management based on the size of the data being written.
var (
defaultBufferPool *pool.Pool
largeBufferPool *pool.Pool
)

// state represents the current mode of BufferedFileWriter.
type state uint8

Expand Down Expand Up @@ -67,28 +85,51 @@ func WithThreshold(threshold uint64) Option {
return func(w *BufferedFileWriter) { w.threshold = threshold }
}

// WithBufferSize returns an Option that picks the shared buffer pool the
// BufferedFileWriter uses. Large selects the large pool; Default and any
// unrecognized PoolSize value select the default pool.
func WithBufferSize(size PoolSize) Option {
	return func(w *BufferedFileWriter) {
		if size == Large {
			w.bufPool = largeBufferPool
			return
		}
		// Default and every other value share the default pool.
		w.bufPool = defaultBufferPool
	}
}

const defaultThreshold = 10 * 1024 * 1024 // 10MB
// New creates a new BufferedFileWriter with the given options.
func New(opts ...Option) *BufferedFileWriter {
w := &BufferedFileWriter{
threshold: defaultThreshold,
state: writeOnly,
bufPool: sharedBufferPool,
}

for _, opt := range opts {
opt(w)
}

if w.bufPool == nil {
w.bufPool = defaultBufferPool
}

return w
}

// NewFromReader creates a new BufferedFileWriter and copies everything read
// from r into it.
//
// The caller's options are applied first, followed by WithBufferSize(Large),
// so writers built from a reader always use the large buffer pool even if the
// caller passed a different WithBufferSize.
//
// It returns an error if copying from r fails, or if the writer still has no
// buffer after the copy.
func NewFromReader(r io.Reader, opts ...Option) (*BufferedFileWriter, error) {
	// Build a fresh slice instead of appending to opts directly: when the
	// caller passes a sub-slice with spare capacity (s[:k]...), appending in
	// place would overwrite the caller's element at index k.
	allOpts := make([]Option, 0, len(opts)+1)
	allOpts = append(allOpts, opts...)
	allOpts = append(allOpts, WithBufferSize(Large))

	writer := New(allOpts...)
	if _, err := io.Copy(writer, r); err != nil && !errors.Is(err, io.EOF) {
		return nil, fmt.Errorf("error writing to buffered file writer: %w", err)
	}

	if writer.buf == nil {
		return nil, fmt.Errorf("buffer is empty, no reader created")
	}

	return writer, nil
}

Expand Down Expand Up @@ -163,9 +204,12 @@ func (w *BufferedFileWriter) Write(data []byte) (int, error) {
// This ensures all the data is in one place - either entirely in the buffer or the file.
if bufferLength > 0 {
if _, err := w.buf.WriteTo(w.file); err != nil {
if err := os.RemoveAll(w.filename); err != nil {
return 0, fmt.Errorf("failed to remove file: %w", err)
}
return 0, err
}
w.bufPool.Put(w.buf)
w.buf.Reset()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ func TestNewFromReaderThresholdExceeded(t *testing.T) {
}

func TestBufferWriterCloseForWritingWithFile(t *testing.T) {
bufPool := pool.NewBufferPool()
bufPool := pool.NewBufferPool(defaultBufferSize)

buf := bufPool.Get()
writer := &BufferedFileWriter{
Expand Down

0 comments on commit 5e3d660

Please sign in to comment.