Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coalesce the salt and the first message #69

Merged
merged 2 commits into from
Jul 7, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions shadowsocks/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,28 @@ func (sw *shadowsocksWriter) init() (err error) {
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return fmt.Errorf("failed to generate salt: %v", err)
}
_, err := sw.writer.Write(salt)
if err != nil {
return fmt.Errorf("failed to write salt: %v", err)
}
sw.aead, err = sw.ssCipher.Encrypter(salt)
if err != nil {
return fmt.Errorf("failed to create AEAD: %v", err)
}
sw.counter = make([]byte, sw.aead.NonceSize())
sw.buf = make([]byte, 2+sw.aead.Overhead()+payloadSizeMask+sw.aead.Overhead())
// The maximum length message is the salt (first message only), length, length tag,
// payload, and payload tag.
sizeBufSize := 2 + sw.aead.Overhead()
maxPayloadBufSize := payloadSizeMask + sw.aead.Overhead()
sw.buf = make([]byte, len(salt)+sizeBufSize+maxPayloadBufSize)
fortuna marked this conversation as resolved.
Show resolved Hide resolved
// Store the salt at the start of sw.buf.
copy(sw.buf, salt)
}
return nil
}

// WriteBlock encrypts and writes the input buffer as one signed block.
func (sw *shadowsocksWriter) encryptBlock(ciphertext []byte, plaintext []byte) ([]byte, error) {
out := sw.aead.Seal(ciphertext, sw.counter, plaintext, nil)
// encryptBlock encrypts `plaintext` in-place. The slice must have enough capacity
// for the tag. Returns the total ciphertext length.
func (sw *shadowsocksWriter) encryptBlock(plaintext []byte) int {
out := sw.aead.Seal(plaintext[:0], sw.counter, plaintext, nil)
increment(sw.counter)
return out, nil
return len(out)
}

func (sw *shadowsocksWriter) Write(p []byte) (int, error) {
Expand All @@ -88,28 +91,43 @@ func (sw *shadowsocksWriter) Write(p []byte) (int, error) {
return int(n), err
}

func isZero(b []byte) bool {
for _, v := range b {
if v != 0 {
return false
}
}
return true
}

func (sw *shadowsocksWriter) ReadFrom(r io.Reader) (int64, error) {
if err := sw.init(); err != nil {
return 0, err
}
var written int64
sizeBuf := sw.buf[:2+sw.aead.Overhead()]
payloadBuf := sw.buf[len(sizeBuf):]

// sw.buf starts with the salt.
saltSize := sw.ssCipher.SaltSize()
// Normally we ignore the salt at the beginning of sw.buf.
fortuna marked this conversation as resolved.
Show resolved Hide resolved
start := saltSize
if isZero(sw.counter) {
// For the first message, include the salt.
start = 0
}

// Each Shadowsocks-TCP message consists of a fixed-length size block, followed by
// a variable-length payload block.
sizeBuf := sw.buf[saltSize : saltSize+2+sw.aead.Overhead()]
payloadBuf := sw.buf[saltSize+len(sizeBuf):]
for {
plaintextSize, err := r.Read(payloadBuf[:payloadSizeMask])
if plaintextSize > 0 {
binary.BigEndian.PutUint16(sizeBuf, uint16(plaintextSize))
_, err = sw.encryptBlock(sizeBuf[:0], sizeBuf[:2])
if err != nil {
return written, fmt.Errorf("failed to encypt payload size: %v", err)
}
_, err := sw.encryptBlock(payloadBuf[:0], payloadBuf[:plaintextSize])
if err != nil {
return written, fmt.Errorf("failed to encrypt payload: %v", err)
}
payloadSize := plaintextSize + sw.aead.Overhead()
_, err = sw.writer.Write(sw.buf[:len(sizeBuf)+payloadSize])
sw.encryptBlock(sizeBuf[:2])
payloadSize := sw.encryptBlock(payloadBuf[:plaintextSize])
_, err = sw.writer.Write(sw.buf[start : saltSize+len(sizeBuf)+payloadSize])
written += int64(plaintextSize)
start = saltSize // Skip the salt for all writes except the first.
}
if err != nil {
if err == io.EOF { // ignore EOF as per io.ReaderFrom contract
Expand Down