From a41ed882e56f2ea744680fc0e3acd82d321b502e Mon Sep 17 00:00:00 2001 From: Ben Schwartz Date: Tue, 24 Nov 2020 13:13:48 -0500 Subject: [PATCH] Make encryption buffer size adjustable This will allow saving memory in cases where we are willing to limit the encryption chunk size. --- shadowsocks/stream.go | 23 +++++++++++++--- shadowsocks/stream_test.go | 54 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/shadowsocks/stream.go b/shadowsocks/stream.go index a1a478c2..1db4dfca 100644 --- a/shadowsocks/stream.go +++ b/shadowsocks/stream.go @@ -51,6 +51,8 @@ type Writer struct { byteWrapper bytes.Reader // Number of plaintext bytes that are currently buffered. pending int + // Maximum number of plaintext bytes to write in each chunk. + maxChunkSize int // These are populated by init(): buf []byte aead cipher.AEAD @@ -61,7 +63,12 @@ type Writer struct { // NewShadowsocksWriter creates a Writer that encrypts the given Writer using // the shadowsocks protocol with the given shadowsocks cipher. func NewShadowsocksWriter(writer io.Writer, ssCipher *Cipher) *Writer { - return &Writer{writer: writer, ssCipher: ssCipher, saltGenerator: RandomSaltGenerator} + return &Writer{ + writer: writer, + ssCipher: ssCipher, + saltGenerator: RandomSaltGenerator, + maxChunkSize: payloadSizeMask, + } } // SetSaltGenerator sets the salt generator to be used. Must be called before the first write. @@ -69,6 +76,16 @@ func (sw *Writer) SetSaltGenerator(saltGenerator SaltGenerator) { sw.saltGenerator = saltGenerator } +// SetMaxChunkSize sets the maximum number of bytes to encrypt as a chunk. +// Defaults to 2^16-1. Smaller values save memory but increase framing overhead. +// Must be called before the first write. +func (sw *Writer) SetMaxChunkSize(maxChunkSize int) { + if maxChunkSize > payloadSizeMask || maxChunkSize <= 0 { + maxChunkSize = payloadSizeMask + } + sw.maxChunkSize = maxChunkSize +} + // init generates a random salt, sets up the AEAD object and writes // the salt to the inner Writer. func (sw *Writer) init() (err error) { @@ -86,7 +103,7 @@ func (sw *Writer) init() (err error) { // The maximum length message is the salt (first message only), length, length tag, // payload, and payload tag. sizeBufSize := 2 + sw.aead.Overhead() - maxPayloadBufSize := payloadSizeMask + sw.aead.Overhead() + maxPayloadBufSize := sw.maxChunkSize + sw.aead.Overhead() sw.buf = make([]byte, len(salt)+sizeBufSize+maxPayloadBufSize) // Store the salt at the start of sw.buf. copy(sw.buf, salt) @@ -165,7 +182,7 @@ func (sw *Writer) buffers() (sizeBuf, payloadBuf []byte) { // followed by a variable-length payload block. sizeBuf = sw.buf[saltSize : saltSize+2] payloadStart := saltSize + 2 + sw.aead.Overhead() - payloadBuf = sw.buf[payloadStart : payloadStart+payloadSizeMask] + payloadBuf = sw.buf[payloadStart : payloadStart+sw.maxChunkSize] return } diff --git a/shadowsocks/stream_test.go b/shadowsocks/stream_test.go index 78eebfce..80a17b9a 100644 --- a/shadowsocks/stream_test.go +++ b/shadowsocks/stream_test.go @@ -411,3 +411,57 @@ func TestLazyWriteConcurrentFlush(t *testing.T) { t.Errorf("Wrong final content: %v", decrypted) } } + +func TestChunkSizeIntegrity(t *testing.T) { + t.Parallel() + cipher := newTestCipher(t) + + // Test extreme and reasonable values. + testChunkSizes := []int{ + 1, 2, 3, 4, 250, 256, 257, 1000, 4000, 8192, 16383, + } + + const numWrites = 5 + + input := make([]byte, numWrites*16383) + for i := range input { + input[i] = byte(i) // Arbitrary test contents + } + + for _, maxChunkSize := range testChunkSizes { + maxChunkSize := maxChunkSize + t.Run(fmt.Sprintf("maxChunkSize=%d", maxChunkSize), func(t *testing.T) { + t.Parallel() + connReader, connWriter := io.Pipe() + writer := NewShadowsocksWriter(connWriter, cipher) + writer.SetMaxChunkSize(maxChunkSize) + reader := NewShadowsocksReader(connReader, cipher) + go func() { + defer connWriter.Close() + if _, err := writer.Write(input[:numWrites*maxChunkSize]); err != nil { + t.Errorf("Failed Write: %v", err) + } + }() + + // Check that all writes have the expected size and contents. + buf := make([]byte, 2*maxChunkSize) + for i := 0; i < numWrites; i++ { + n, err := reader.Read(buf) + if err != nil { + t.Errorf("Read failed at chunk %d: %v", i, err) + } else if n != maxChunkSize { + t.Errorf("Chunk %d has wrong size: %d", i, n) + } else if !bytes.Equal(input[i*maxChunkSize:][:maxChunkSize], buf[:n]) { + t.Errorf("Data mismatch at chunk %d", i) + } + } + // Check that the stream is closed cleanly. + n, err := reader.Read(buf) + if n != 0 { + t.Errorf("Got %d extra bytes", n) + } else if err != io.EOF { + t.Errorf("Wanted EOF, got %v", err) + } + }) + } +}