Skip to content

Commit

Permalink
Merge pull request #25 from Jigsaw-Code/bemasc-ip
Browse files Browse the repository at this point in the history
Optimize trial decryption ordering
  • Loading branch information
Benjamin M. Schwartz authored Aug 2, 2019
2 parents 8972cd6 + 807b589 commit 2d8ba61
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 25 deletions.
37 changes: 29 additions & 8 deletions shadowsocks/cipher_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,26 @@ package shadowsocks

import (
"container/list"
"net"
"sync"

"github.com/shadowsocks/go-shadowsocks2/shadowaead"
)

// CipherEntry holds a Cipher with an identifier.
// The public fields are constant, but lastAddress is mutable under cipherList.mu.
type CipherEntry struct {
ID string
Cipher shadowaead.Cipher
ID string
Cipher shadowaead.Cipher
lastClientIP net.IP
}

// CipherList is a list of CipherEntry elements that allows for thread-safe snapshotting and
// moving to front.
type CipherList interface {
PushBack(id string, cipher shadowaead.Cipher) *list.Element
SafeSnapshot() []*list.Element
SafeMoveToFront(e *list.Element)
SafeSnapshotForClientIP(clientIP net.IP) []*list.Element
SafeMarkUsedByClientIP(e *list.Element, clientIP net.IP)
}

type cipherList struct {
Expand All @@ -50,20 +53,38 @@ func (cl *cipherList) PushBack(id string, cipher shadowaead.Cipher) *list.Elemen
return cl.list.PushBack(&CipherEntry{ID: id, Cipher: cipher})
}

func (cl *cipherList) SafeSnapshot() []*list.Element {
func matchesIP(e *list.Element, clientIP net.IP) bool {
c := e.Value.(*CipherEntry)
return clientIP != nil && clientIP.Equal(c.lastClientIP)
}

func (cl *cipherList) SafeSnapshotForClientIP(clientIP net.IP) []*list.Element {
cl.mu.RLock()
defer cl.mu.RUnlock()
cipherArray := make([]*list.Element, cl.list.Len())
i := 0
// First pass: put all ciphers with matching last known IP at the front.
for e := cl.list.Front(); e != nil; e = e.Next() {
cipherArray[i] = e
i++
if matchesIP(e, clientIP) {
cipherArray[i] = e
i++
}
}
// Second pass: include all remaining ciphers in recency order.
for e := cl.list.Front(); e != nil; e = e.Next() {
if !matchesIP(e, clientIP) {
cipherArray[i] = e
i++
}
}
return cipherArray
}

func (cl *cipherList) SafeMoveToFront(e *list.Element) {
func (cl *cipherList) SafeMarkUsedByClientIP(e *list.Element, clientIP net.IP) {
cl.mu.Lock()
defer cl.mu.Unlock()
cl.list.MoveToFront(e)

c := e.Value.(*CipherEntry)
c.lastClientIP = clientIP
}
31 changes: 23 additions & 8 deletions shadowsocks/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ import (
"net"
"time"

logging "github.com/op/go-logging"

"github.com/Jigsaw-Code/outline-ss-server/metrics"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
logging "github.com/op/go-logging"

"github.com/shadowsocks/go-shadowsocks2/socks"
)
Expand All @@ -47,6 +46,21 @@ func ensureBytes(reader io.Reader, buf []byte, bytesNeeded int) ([]byte, error)
return buf, err
}

func remoteIP(conn net.Conn) net.IP {
addr := conn.RemoteAddr()
if addr == nil {
return nil
}
if tcpaddr, ok := addr.(*net.TCPAddr); ok {
return tcpaddr.IP
}
ipstr, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil {
return net.ParseIP(ipstr)
}
return nil
}

func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, onet.DuplexConn, error) {
// This must have enough space to hold the salt + 2 bytes chunk length + AEAD tag (Oeverhead) for any cipher
replayBytes := make([]byte, 0, 32+2+16)
Expand All @@ -56,15 +70,16 @@ func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, o
chunkLenBuf := [2]byte{}
var err error

clientIP := remoteIP(clientConn)
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
// TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
for _, entry := range cipherList.SafeSnapshot() {
for ci, entry := range cipherList.SafeSnapshotForClientIP(clientIP) {
id, cipher := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).Cipher
replayBytes, err = ensureBytes(clientConn, replayBytes, cipher.SaltSize())
if err != nil {
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("Failed TCP ciper %v: %v", id, err)
logger.Debugf("TCP: Failed to read salt %v: %v", id, err)
}
continue
}
Expand All @@ -73,23 +88,23 @@ func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, o
replayBytes, err = ensureBytes(clientConn, replayBytes, cipher.SaltSize()+2+aead.Overhead())
if err != nil {
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("Failed TCP ciper %v: %v", id, err)
logger.Debugf("TCP: Failed to read length %v: %v", id, err)
}
continue
}
cipherText := replayBytes[cipher.SaltSize() : cipher.SaltSize()+2+aead.Overhead()]
_, err = aead.Open(chunkLenBuf[:0], zeroCountBuf[:aead.NonceSize()], cipherText, nil)
if err != nil {
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("Failed TCP ciper %v: %v", id, err)
logger.Debugf("TCP: Failed to decrypt length %v: %v", id, err)
}
continue
}
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("Selected TCP cipher %v", id)
logger.Debugf("TCP: Found cipher %v at index %d", id, ci)
}
// Move the active cipher to the front, so that the search is quicker next time.
cipherList.SafeMoveToFront(entry)
cipherList.SafeMarkUsedByClientIP(entry, clientIP)
ssr := NewShadowsocksReader(io.MultiReader(bytes.NewReader(replayBytes), clientConn), cipher)
ssw := NewShadowsocksWriter(clientConn, cipher)
return id, onet.WrapConn(clientConn, ssr, ssw).(onet.DuplexConn), nil
Expand Down
96 changes: 95 additions & 1 deletion shadowsocks/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
package shadowsocks

import (
"errors"
"io"
"net"
"testing"
"time"

logging "github.com/op/go-logging"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
)

func BenchmarkTCPFindCipher(b *testing.B) {
// Simulates receiving invalid TCP connection attempts on a server with 100 ciphers.
func BenchmarkTCPFindCipherFail(b *testing.B) {
b.StopTimer()
b.ResetTimer()

Expand Down Expand Up @@ -54,3 +59,92 @@ func BenchmarkTCPFindCipher(b *testing.B) {
b.StopTimer()
}
}

// Fake DuplexConn
// 1-way pipe, representing the upstream flow as seen by the server.
type conn struct {
onet.DuplexConn
clientAddr net.Addr
reader io.ReadCloser
writer io.WriteCloser
}

func (c *conn) Read(b []byte) (int, error) {
return c.reader.Read(b)
}

func (c *conn) Write(b []byte) (int, error) {
// Any downstream data is ignored.
return len(b), nil
}

func (c *conn) Close() error {
e1 := c.reader.Close()
e2 := c.writer.Close()
if e1 != nil {
return e1
}
return e2
}

func (c *conn) LocalAddr() net.Addr {
return nil
}

func (c *conn) RemoteAddr() net.Addr {
return c.clientAddr
}

func (c *conn) SetDeadline(t time.Time) error {
return errors.New("SetDeadline is not supported")
}

func (c *conn) SetReadDeadline(t time.Time) error {
return errors.New("SetDeadline is not supported")
}

func (c *conn) SetWriteDeadline(t time.Time) error {
return errors.New("SetDeadline is not supported")
}

func (c *conn) CloseRead() error {
return c.reader.Close()
}

func (c *conn) CloseWrite() error {
return nil
}

// Simulates receiving valid TCP connection attempts from 100 different users,
// each with their own cipher and their own IP address.
func BenchmarkTCPFindCipherRepeat(b *testing.B) {
b.StopTimer()
b.ResetTimer()

logging.SetLevel(logging.INFO, "")

const numCiphers = 100 // Must be <256
cipherList, err := MakeTestCiphers(numCiphers)
if err != nil {
b.Fatal(err)
}
cipherEntries := [numCiphers]*CipherEntry{}
for cipherNumber, element := range cipherList.SafeSnapshotForClientIP(nil) {
cipherEntries[cipherNumber] = element.Value.(*CipherEntry)
}
for n := 0; n < b.N; n++ {
cipherNumber := byte(n % numCiphers)
reader, writer := io.Pipe()
addr := &net.TCPAddr{IP: net.IPv4(192, 0, 2, cipherNumber), Port: 54321}
c := conn{clientAddr: addr, reader: reader, writer: writer}
cipher := cipherEntries[cipherNumber].Cipher
go NewShadowsocksWriter(writer, cipher).Write(MakeTestPayload(50))
b.StartTimer()
_, _, err := findAccessKey(&c, cipherList)
b.StopTimer()
if err != nil {
b.Error(err)
}
c.Close()
}
}
13 changes: 7 additions & 6 deletions shadowsocks/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,23 @@ const udpBufSize = 64 * 1024

// upack decrypts src into dst. It tries each cipher until it finds one that authenticates
// correctly. dst and src must not overlap.
func unpack(dst, src []byte, cipherList CipherList) ([]byte, string, shadowaead.Cipher, error) {
func unpack(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, string, shadowaead.Cipher, error) {
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
for _, entry := range cipherList.SafeSnapshot() {
for ci, entry := range cipherList.SafeSnapshotForClientIP(clientIP) {
id, cipher := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).Cipher
buf, err := shadowaead.Unpack(dst, src, cipher)
if err != nil {
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("Failed UDP cipher %v: %v", id, err)
logger.Debugf("UDP: Failed to unpack with cipher %v: %v", id, err)
}
continue
}
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("Selected UDP cipher %v", id)
logger.Debugf("UDP: Found cipher %v at index %d", id, ci)
}
// Move the active cipher to the front, so that the search is quicker next time.
cipherList.SafeMoveToFront(entry)
cipherList.SafeMarkUsedByClientIP(entry, clientIP)
return buf, id, cipher, nil
}
return nil, "", nil, errors.New("could not find valid cipher")
Expand Down Expand Up @@ -121,7 +121,8 @@ func (s *udpService) Start() {
defer logger.Debugf("UDP done with %v", clientAddr.String())
logger.Debugf("UDP Request from %v with %v bytes", clientAddr, clientProxyBytes)
unpackStart := time.Now()
buf, keyID, cipher, err := unpack(textBuf, cipherBuf[:clientProxyBytes], *s.ciphers)
ip := clientAddr.(*net.IPAddr).IP
buf, keyID, cipher, err := unpack(ip, textBuf, cipherBuf[:clientProxyBytes], *s.ciphers)
timeToCipher = time.Now().Sub(unpackStart)

if err != nil {
Expand Down
71 changes: 69 additions & 2 deletions shadowsocks/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
package shadowsocks

import (
"net"
"testing"

logging "github.com/op/go-logging"
"github.com/shadowsocks/go-shadowsocks2/shadowaead"
)

func BenchmarkUDPUnpack(b *testing.B) {
// Simulates receiving invalid UDP packets on a server with 100 ciphers.
func BenchmarkUDPUnpackFail(b *testing.B) {
logging.SetLevel(logging.INFO, "")

cipherList, err := MakeTestCiphers(100)
Expand All @@ -29,8 +32,72 @@ func BenchmarkUDPUnpack(b *testing.B) {
}
testPayload := MakeTestPayload(50)
textBuf := make([]byte, udpBufSize)
testIP := net.ParseIP("192.0.2.1")
b.ResetTimer()
for n := 0; n < b.N; n++ {
unpack(textBuf, testPayload, cipherList)
unpack(testIP, textBuf, testPayload, cipherList)
}
}

// Simulates receiving valid UDP packets from 100 different users, each with
// their own cipher and IP address.
func BenchmarkUDPUnpackRepeat(b *testing.B) {
logging.SetLevel(logging.INFO, "")

const numCiphers = 100 // Must be <256
cipherList, err := MakeTestCiphers(numCiphers)
if err != nil {
b.Fatal(err)
}
testBuf := make([]byte, udpBufSize)
packets := [numCiphers][]byte{}
ips := [numCiphers]net.IP{}
for i, element := range cipherList.SafeSnapshotForClientIP(nil) {
packets[i] = make([]byte, 0, udpBufSize)
plaintext := MakeTestPayload(50)
packets[i], err = shadowaead.Pack(make([]byte, udpBufSize), plaintext, element.Value.(*CipherEntry).Cipher)
if err != nil {
b.Error(err)
}
ips[i] = net.IPv4(192, 0, 2, byte(i))
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
cipherNumber := n % numCiphers
ip := ips[cipherNumber]
packet := packets[cipherNumber]
_, _, _, err := unpack(ip, testBuf, packet, cipherList)
if err != nil {
b.Error(err)
}
}
}

// Simulates receiving valid UDP packets from 100 different IP addresses,
// all using the same cipher.
func BenchmarkUDPUnpackSharedKey(b *testing.B) {
logging.SetLevel(logging.INFO, "")

cipherList, err := MakeTestCiphers(1) // One widely shared key
if err != nil {
b.Fatal(err)
}
testBuf := make([]byte, udpBufSize)
plaintext := MakeTestPayload(50)
cipher := cipherList.SafeSnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
packet, err := shadowaead.Pack(make([]byte, udpBufSize), plaintext, cipher)

const numIPs = 100 // Must be <256
ips := [numIPs]net.IP{}
for i := 0; i < numIPs; i++ {
ips[i] = net.IPv4(192, 0, 2, byte(i))
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
ip := ips[n%numIPs]
_, _, _, err := unpack(ip, testBuf, packet, cipherList)
if err != nil {
b.Error(err)
}
}
}

0 comments on commit 2d8ba61

Please sign in to comment.