Skip to content

Commit

Permalink
Ensure thread safety on cipher updates
Browse files Browse the repository at this point in the history
This change makes cipher updates a thread-safe operation, and
avoids the unusual pointer-to-interface construction.
  • Loading branch information
Ben Schwartz committed Jan 14, 2020
1 parent 3c2b0f1 commit 6fbc31c
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 12 deletions.
6 changes: 3 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ func (s *SSServer) startPort(portNum int) error {
logger.Infof("Listening TCP and UDP on port %v", portNum)
port := &SSPort{cipherList: shadowsocks.NewCipherList()}
// TODO: Register initial data metrics at zero.
port.tcpService = shadowsocks.NewTCPService(listener, &port.cipherList, &s.replayCache, s.m, tcpReadTimeout)
port.udpService = shadowsocks.NewUDPService(packetConn, s.natTimeout, &port.cipherList, s.m)
port.tcpService = shadowsocks.NewTCPService(listener, port.cipherList, &s.replayCache, s.m, tcpReadTimeout)
port.udpService = shadowsocks.NewUDPService(packetConn, s.natTimeout, port.cipherList, s.m)
s.ports[portNum] = port
go port.udpService.Start()
go port.tcpService.Start()
Expand Down Expand Up @@ -148,7 +148,7 @@ func (s *SSServer) loadConfig(filename string) error {
}
}
for portNum, cipherList := range portCiphers {
s.ports[portNum].cipherList = cipherList
s.ports[portNum].cipherList.SafeSwap(cipherList)
}
logger.Infof("Loaded %v access keys", len(config.Keys))
s.m.SetNumAccessKeys(len(config.Keys), len(portCiphers))
Expand Down
10 changes: 10 additions & 0 deletions shadowsocks/cipher_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type CipherList interface {
PushBack(id string, cipher shadowaead.Cipher) *list.Element
SafeSnapshotForClientIP(clientIP net.IP) []*list.Element
SafeMarkUsedByClientIP(e *list.Element, clientIP net.IP)
SafeSwap(other CipherList)
}

type cipherList struct {
Expand Down Expand Up @@ -88,3 +89,12 @@ func (cl *cipherList) SafeMarkUsedByClientIP(e *list.Element, clientIP net.IP) {
c := e.Value.(*CipherEntry)
c.lastClientIP = clientIP
}

func (cl *cipherList) SafeSwap(other CipherList) {
cl2 := other.(*cipherList)
cl.mu.Lock()
cl2.mu.Lock()
cl.list, cl2.list = cl2.list, cl.list
cl2.mu.Unlock()
cl.mu.Unlock()
}
7 changes: 3 additions & 4 deletions shadowsocks/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ func findAccessKey(clientConn onet.DuplexConn, cipherList CipherList) (string, o

type tcpService struct {
listener *net.TCPListener
// `ciphers` is a pointer to SSPort.cipherList, which can be updated by SSServer.loadConfig.
ciphers *CipherList
ciphers CipherList
m metrics.ShadowsocksMetrics
isRunning bool
readTimeout time.Duration
Expand All @@ -125,7 +124,7 @@ type tcpService struct {
}

// NewTCPService creates a TCPService
func NewTCPService(listener *net.TCPListener, ciphers *CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, timeout time.Duration) TCPService {
func NewTCPService(listener *net.TCPListener, ciphers CipherList, replayCache *ReplayCache, m metrics.ShadowsocksMetrics, timeout time.Duration) TCPService {
return &tcpService{
listener: listener,
ciphers: ciphers,
Expand Down Expand Up @@ -219,7 +218,7 @@ func (s *tcpService) Start() {
}()

findStartTime := time.Now()
keyID, clientConn, salt, err := findAccessKey(clientConn, *s.ciphers)
keyID, clientConn, salt, err := findAccessKey(clientConn, s.ciphers)
timeToCipher = time.Now().Sub(findStartTime)

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions shadowsocks/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func TestReplayDefense(t *testing.T) {
replayCache := NewReplayCache(5)
testMetrics := &probeTestMetrics{}
const testTimeout = 200 * time.Millisecond
s := NewTCPService(listener, &cipherList, &replayCache, testMetrics, testTimeout)
s := NewTCPService(listener, cipherList, &replayCache, testMetrics, testTimeout)
cipherEntry := cipherList.SafeSnapshotForClientIP(nil)[0].Value.(*CipherEntry)
cipher := cipherEntry.Cipher
reader, writer := io.Pipe()
Expand Down Expand Up @@ -277,7 +277,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) {
t.Fatal(err)
}
testMetrics := &probeTestMetrics{}
s := NewTCPService(listener, &cipherList, nil, testMetrics, testTimeout)
s := NewTCPService(listener, cipherList, nil, testMetrics, testTimeout)

testPayload := MakeTestPayload(payloadSize)
done := make(chan bool)
Expand Down
6 changes: 3 additions & 3 deletions shadowsocks/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ func unpack(clientIP net.IP, dst, src []byte, cipherList CipherList) ([]byte, st
type udpService struct {
clientConn net.PacketConn
natTimeout time.Duration
ciphers *CipherList
ciphers CipherList
m metrics.ShadowsocksMetrics
isRunning bool
}

// NewUDPService creates a UDPService
func NewUDPService(clientConn net.PacketConn, natTimeout time.Duration, cipherList *CipherList, m metrics.ShadowsocksMetrics) UDPService {
func NewUDPService(clientConn net.PacketConn, natTimeout time.Duration, cipherList CipherList, m metrics.ShadowsocksMetrics) UDPService {
return &udpService{clientConn: clientConn, natTimeout: natTimeout, ciphers: cipherList, m: m}
}

Expand Down Expand Up @@ -122,7 +122,7 @@ func (s *udpService) Start() {
logger.Debugf("UDP Request from %v with %v bytes", clientAddr, clientProxyBytes)
unpackStart := time.Now()
ip := clientAddr.(*net.UDPAddr).IP
buf, keyID, cipher, err := unpack(ip, textBuf, cipherBuf[:clientProxyBytes], *s.ciphers)
buf, keyID, cipher, err := unpack(ip, textBuf, cipherBuf[:clientProxyBytes], s.ciphers)
timeToCipher = time.Now().Sub(unpackStart)

if err != nil {
Expand Down

0 comments on commit 6fbc31c

Please sign in to comment.