Skip to content

Commit

Permalink
Merge pull request #45 from Jigsaw-Code/bemasc-threading
Browse files Browse the repository at this point in the history
Ensure thread safety on cipher updates
  • Loading branch information
Benjamin M. Schwartz authored Jan 17, 2020
2 parents 39bccfd + 2859196 commit fdd445a
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 37 deletions.
13 changes: 7 additions & 6 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package main

import (
"container/list"
"flag"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -79,8 +80,8 @@ func (s *SSServer) startPort(portNum int) error {
logger.Infof("Listening TCP and UDP on port %v", portNum)
port := &SSPort{cipherList: shadowsocks.NewCipherList()}
// TODO: Register initial data metrics at zero.
port.tcpService = shadowsocks.NewTCPService(listener, &port.cipherList, &s.replayCache, s.m, tcpReadTimeout)
port.udpService = shadowsocks.NewUDPService(packetConn, s.natTimeout, &port.cipherList, s.m)
port.tcpService = shadowsocks.NewTCPService(listener, port.cipherList, &s.replayCache, s.m, tcpReadTimeout)
port.udpService = shadowsocks.NewUDPService(packetConn, s.natTimeout, port.cipherList, s.m)
s.ports[portNum] = port
go port.udpService.Start()
go port.tcpService.Start()
Expand Down Expand Up @@ -112,12 +113,12 @@ func (s *SSServer) loadConfig(filename string) error {
}

portChanges := make(map[int]int)
portCiphers := make(map[int]shadowsocks.CipherList)
portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry.
for _, keyConfig := range config.Keys {
portChanges[keyConfig.Port] = 1
cipherList, ok := portCiphers[keyConfig.Port]
if !ok {
cipherList = shadowsocks.NewCipherList()
cipherList = list.New()
portCiphers[keyConfig.Port] = cipherList
}
cipher, err := core.PickCipher(keyConfig.Cipher, nil, keyConfig.Secret)
Expand All @@ -131,7 +132,7 @@ func (s *SSServer) loadConfig(filename string) error {
if !ok {
return fmt.Errorf("Only AEAD ciphers are supported. Found %v", keyConfig.Cipher)
}
cipherList.PushBack(keyConfig.ID, aead)
cipherList.PushBack(&shadowsocks.CipherEntry{ID: keyConfig.ID, Cipher: aead})
}
for port := range s.ports {
portChanges[port] = portChanges[port] - 1
Expand All @@ -148,7 +149,7 @@ func (s *SSServer) loadConfig(filename string) error {
}
}
for portNum, cipherList := range portCiphers {
s.ports[portNum].cipherList = cipherList
s.ports[portNum].cipherList.Update(cipherList)
}
logger.Infof("Loaded %v access keys", len(config.Keys))
s.m.SetNumAccessKeys(len(config.Keys), len(portCiphers))
Expand Down
27 changes: 16 additions & 11 deletions shadowsocks/cipher_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ type CipherEntry struct {
lastClientIP net.IP
}

// CipherList is a list of CipherEntry elements that allows for thread-safe snapshotting and
// moving to front.
// CipherList is a thread-safe collection of CipherEntry elements that allows for
// snapshotting and moving to front.
type CipherList interface {
PushBack(id string, cipher shadowaead.Cipher) *list.Element
SafeSnapshotForClientIP(clientIP net.IP) []*list.Element
SafeMarkUsedByClientIP(e *list.Element, clientIP net.IP)
SnapshotForClientIP(clientIP net.IP) []*list.Element
MarkUsedByClientIP(e *list.Element, clientIP net.IP)
// Update replaces the current contents of the CipherList with `contents`,
// which is a List of *CipherEntry. Update takes ownership of `contents`,
// which must not be read or written after this call.
Update(contents *list.List)
}

type cipherList struct {
Expand All @@ -49,16 +52,12 @@ func NewCipherList() CipherList {
return &cipherList{list: list.New()}
}

func (cl *cipherList) PushBack(id string, cipher shadowaead.Cipher) *list.Element {
return cl.list.PushBack(&CipherEntry{ID: id, Cipher: cipher})
}

func matchesIP(e *list.Element, clientIP net.IP) bool {
c := e.Value.(*CipherEntry)
return clientIP != nil && clientIP.Equal(c.lastClientIP)
}

func (cl *cipherList) SafeSnapshotForClientIP(clientIP net.IP) []*list.Element {
func (cl *cipherList) SnapshotForClientIP(clientIP net.IP) []*list.Element {
cl.mu.RLock()
defer cl.mu.RUnlock()
cipherArray := make([]*list.Element, cl.list.Len())
Expand All @@ -80,11 +79,17 @@ func (cl *cipherList) SafeSnapshotForClientIP(clientIP net.IP) []*list.Element {
return cipherArray
}

func (cl *cipherList) SafeMarkUsedByClientIP(e *list.Element, clientIP net.IP) {
func (cl *cipherList) MarkUsedByClientIP(e *list.Element, clientIP net.IP) {
cl.mu.Lock()
defer cl.mu.Unlock()
cl.list.MoveToFront(e)

c := e.Value.(*CipherEntry)
c.lastClientIP = clientIP
}

func (cl *cipherList) Update(src *list.List) {
cl.mu.Lock()
cl.list = src
cl.mu.Unlock()
}
7 changes: 5 additions & 2 deletions shadowsocks/cipher_testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,26 @@
package shadowsocks

import (
"container/list"
"fmt"

"github.com/shadowsocks/go-shadowsocks2/core"
"github.com/shadowsocks/go-shadowsocks2/shadowaead"
)

func MakeTestCiphers(numCiphers int) (CipherList, error) {
cipherList := NewCipherList()
l := list.New()
for i := 0; i < numCiphers; i++ {
cipherID := fmt.Sprintf("id-%v", i)
secret := fmt.Sprintf("secret-%v", i)
cipher, err := core.PickCipher("chacha20-ietf-poly1305", nil, secret)
if err != nil {
return nil, fmt.Errorf("Failed to create cipher %v: %v", i, err)
}
cipherList.PushBack(cipherID, cipher.(shadowaead.Cipher))
l.PushBack(&CipherEntry{ID: cipherID, Cipher: cipher.(shadowaead.Cipher)})
}
cipherList := NewCipherList()
cipherList.Update(l)
return cipherList, nil
}

Expand Down
13 changes: 6 additions & 7 deletions shadowsocks/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, o
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
// TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
for ci, entry := range cipherList.SafeSnapshotForClientIP(clientIP) {
for ci, entry := range cipherList.SnapshotForClientIP(clientIP) {
id, cipher := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).Cipher
firstBytes, err = ensureBytes(clientConn, firstBytes, cipher.SaltSize())
if err != nil {
Expand Down Expand Up @@ -105,7 +105,7 @@ func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, o
logger.Debugf("TCP: Found cipher %v at index %d", id, ci)
}
// Move the active cipher to the front, so that the search is quicker next time.
cipherList.SafeMarkUsedByClientIP(entry, clientIP)
cipherList.MarkUsedByClientIP(entry, clientIP)
ssr := NewShadowsocksReader(io.MultiReader(bytes.NewReader(firstBytes), clientConn), cipher)
ssw := NewShadowsocksWriter(clientConn, cipher)
return id, onet.WrapConn(clientConn, ssr, ssw).(onet.DuplexConn), salt, nil
Expand All @@ -114,9 +114,8 @@ func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, o
}

type tcpService struct {
listener *net.TCPListener
// `ciphers` is a pointer to SSPort.cipherList, which can be updated by SSServer.loadConfig.
ciphers *CipherList
listener *net.TCPListener
ciphers CipherList
m metrics.ShadowsocksMetrics
isRunning bool
readTimeout time.Duration
Expand All @@ -125,7 +124,7 @@ type tcpService struct {
}

// NewTCPService creates a TCPService
func NewTCPService(listener *net.TCPListener, ciphers *CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, timeout time.Duration) TCPService {
func NewTCPService(listener *net.TCPListener, ciphers CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, timeout time.Duration) TCPService {
return &tcpService{
listener: listener,
ciphers: ciphers,
Expand Down Expand Up @@ -219,7 +218,7 @@ func (s *tcpService) Start() {
}()

findStartTime := time.Now()
keyID, clientConn, salt, err := findAccessKey(clientConn, *s.ciphers)
keyID, clientConn, salt, err := findAccessKey(clientConn, s.ciphers)
timeToCipher = time.Now().Sub(findStartTime)

if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions shadowsocks/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
b.Fatal(err)
}
cipherEntries := [numCiphers]*CipherEntry{}
for cipherNumber, element := range cipherList.SafeSnapshotForClientIP(nil) {
for cipherNumber, element := range cipherList.SnapshotForClientIP(nil) {
cipherEntries[cipherNumber] = element.Value.(*CipherEntry)
}
for n := 0; n < b.N; n++ {
Expand Down Expand Up @@ -194,8 +194,8 @@ func TestReplayDefense(t *testing.T) {
replayCache := NewReplayCache(5)
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
s := NewTCPService(listener, &cipherList, &replayCache, testMetrics, testTimeout)
cipherEntry := cipherList.SafeSnapshotForClientIP(nil)[0].Value.(*CipherEntry)
s := NewTCPService(listener, cipherList, &replayCache, testMetrics, testTimeout)
cipherEntry := cipherList.SnapshotForClientIP(nil)[0].Value.(*CipherEntry)
cipher := cipherEntry.Cipher
reader, writer := io.Pipe()
go NewShadowsocksWriter(writer, cipher).Write([]byte{0})
Expand Down Expand Up @@ -277,7 +277,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) {
t.Fatal(err)
}
testMetrics := &probeTestMetrics{}
s := NewTCPService(listener, &cipherList, nil, testMetrics, testTimeout)
s := NewTCPService(listener, cipherList, nil, testMetrics, testTimeout)

testPayload := MakeTestPayload(payloadSize)
done := make(chan bool)
Expand Down
10 changes: 5 additions & 5 deletions shadowsocks/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const udpBufSize = 64 * 1024
func unpack(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, string, shadowaead.Cipher, error) {
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
for ci, entry := range cipherList.SafeSnapshotForClientIP(clientIP) {
for ci, entry := range cipherList.SnapshotForClientIP(clientIP) {
id, cipher := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).Cipher
buf, err := shadowaead.Unpack(dst, src, cipher)
if err != nil {
Expand All @@ -51,7 +51,7 @@ func unpack(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, st
logger.Debugf("UDP: Found cipher %v at index %d", id, ci)
}
// Move the active cipher to the front, so that the search is quicker next time.
cipherList.SafeMarkUsedByClientIP(entry, clientIP)
cipherList.MarkUsedByClientIP(entry, clientIP)
return buf, id, cipher, nil
}
return nil, "", nil, errors.New("could not find valid cipher")
Expand All @@ -60,13 +60,13 @@ func unpack(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, st
type udpService struct {
clientConn net.PacketConn
natTimeout time.Duration
ciphers *CipherList
ciphers CipherList
m metrics.ShadowsocksMetrics
isRunning bool
}

// NewUDPService creates a UDPService
func NewUDPService(clientConn net.PacketConn, natTimeout time.Duration, cipherList *CipherList, m metrics.ShadowsocksMetrics) UDPService {
func NewUDPService(clientConn net.PacketConn, natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics) UDPService {
return &udpService{clientConn: clientConn, natTimeout: natTimeout, ciphers: cipherList, m: m}
}

Expand Down Expand Up @@ -122,7 +122,7 @@ func (s *udpService) Start() {
logger.Debugf("UDP Request from %v with %v bytes", clientAddr, clientProxyBytes)
unpackStart := time.Now()
ip := clientAddr.(*net.UDPAddr).IP
buf, keyID, cipher, err := unpack(ip, textBuf, cipherBuf[:clientProxyBytes], *s.ciphers)
buf, keyID, cipher, err := unpack(ip, textBuf, cipherBuf[:clientProxyBytes], s.ciphers)
timeToCipher = time.Now().Sub(unpackStart)

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions shadowsocks/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) {
testBuf := make([]byte, udpBufSize)
packets := [numCiphers][]byte{}
ips := [numCiphers]net.IP{}
for i, element := range cipherList.SafeSnapshotForClientIP(nil) {
for i, element := range cipherList.SnapshotForClientIP(nil) {
packets[i] = make([]byte, 0, udpBufSize)
plaintext := MakeTestPayload(50)
packets[i], err = shadowaead.Pack(make([]byte, udpBufSize), plaintext, element.Value.(*CipherEntry).Cipher)
Expand Down Expand Up @@ -84,7 +84,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) {
}
testBuf := make([]byte, udpBufSize)
plaintext := MakeTestPayload(50)
cipher := cipherList.SafeSnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
cipher := cipherList.SnapshotForClientIP(nil)[0].Value.(*CipherEntry).Cipher
packet, err := shadowaead.Pack(make([]byte, udpBufSize), plaintext, cipher)

const numIPs = 100 // Must be <256
Expand Down

0 comments on commit fdd445a

Please sign in to comment.