Skip to content

Commit

Permalink
Merge branch 'master' into staging-client
Browse files Browse the repository at this point in the history
  • Loading branch information
rod-hynes committed Aug 22, 2023
2 parents a86d3bd + ae17ec8 commit 3f91b1b
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 75 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ require (
github.com/Psiphon-Labs/bolt v0.0.0-20200624191537-23cedaef7ad7
github.com/Psiphon-Labs/goptlib v0.0.0-20200406165125-c0e32a7a3464
github.com/Psiphon-Labs/quic-go v0.0.0-20230626192210-73f29effc9da
github.com/Psiphon-Labs/tls-tris v0.0.0-20210713133851-676a693d51ad
github.com/Psiphon-Labs/tls-tris v0.0.0-20230821160547-c948ccd6c156
github.com/armon/go-proxyproto v0.0.0-20180202201750-5b7edb60ff5f
github.com/bifurcation/mint v0.0.0-20180306135233-198357931e61
github.com/cheekybits/genny v0.0.0-20170328200008-9127e812e1e9
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ github.com/Psiphon-Labs/quic-go v0.0.0-20230626192210-73f29effc9da h1:TI2+ExyFR3
github.com/Psiphon-Labs/quic-go v0.0.0-20230626192210-73f29effc9da/go.mod h1:wTIxqsKVrEQIxVIIYOEHuscY+PM3h6Wz79u5aF60fo0=
github.com/Psiphon-Labs/tls-tris v0.0.0-20210713133851-676a693d51ad h1:m6HS84+b5xDPLj7D/ya1CeixyaHOCZoMbBilJ48y+Ts=
github.com/Psiphon-Labs/tls-tris v0.0.0-20210713133851-676a693d51ad/go.mod h1:v3y9GXFo9Sf2mO6auD2ExGG7oDgrK8TI7eb49ZnUxrE=
github.com/Psiphon-Labs/tls-tris v0.0.0-20230821160547-c948ccd6c156 h1:TlKg/9XkSlo5AqSJRVkTKIkwy/JXrQD6ybK3PZuAOwE=
github.com/Psiphon-Labs/tls-tris v0.0.0-20230821160547-c948ccd6c156/go.mod h1:v3y9GXFo9Sf2mO6auD2ExGG7oDgrK8TI7eb49ZnUxrE=
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 h1:w1UutsfOrms1J05zt7ISrnJIXKzwaspym5BTKGx93EI=
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412/go.mod h1:WPjqKcmVOxf0XSf3YxCJs6N6AOSrOx3obionmG7T0y0=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
Expand Down
167 changes: 108 additions & 59 deletions psiphon/server/demux.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,23 @@ package server
import (
"bytes"
"context"
std_errors "errors"
"net"
"time"

"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
"github.com/sirupsen/logrus"
)

// protocolDemux enables a single listener to support multiple protocols
// by demultiplexing each conn it accepts into the corresponding protocol
// handler.
type protocolDemux struct {
ctx context.Context
cancelFunc context.CancelFunc
innerListener net.Listener
classifiers []protocolClassifier
accept chan struct{}
ctx context.Context
cancelFunc context.CancelFunc
innerListener net.Listener
classifiers []protocolClassifier
connClassificationTimeout time.Duration

conns []chan net.Conn
}
Expand All @@ -58,8 +61,13 @@ type protocolClassifier struct {
// newProtocolDemux returns a newly initialized ProtocolDemux and an
// array of protocol listeners. For each protocol classifier in classifiers
// there will be a corresponding protocol listener at the same index in the
// array of returned protocol listeners.
func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier) (*protocolDemux, []protoListener) {
// array of returned protocol listeners. If connClassificationTimeout is >0,
// then any conn not classified in this amount of time will be closed.
//
// Limitation: the conn is also closed after reading maxBytesToMatch and
// failing to find a match, which can be a fingerprint for a raw conn with no
// preceding anti-probing measure, such as TLS passthrough.
func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []protocolClassifier, connClassificationTimeout time.Duration) (*protocolDemux, []protoListener) {

ctx, cancelFunc := context.WithCancel(ctx)

Expand All @@ -69,12 +77,12 @@ func newProtocolDemux(ctx context.Context, listener net.Listener, classifiers []
}

p := protocolDemux{
ctx: ctx,
cancelFunc: cancelFunc,
innerListener: listener,
conns: conns,
classifiers: classifiers,
accept: make(chan struct{}, 1),
ctx: ctx,
cancelFunc: cancelFunc,
innerListener: listener,
conns: conns,
classifiers: classifiers,
connClassificationTimeout: connClassificationTimeout,
}

protoListeners := make([]protoListener, len(classifiers))
Expand Down Expand Up @@ -115,78 +123,124 @@ func (mux *protocolDemux) run() error {

for mux.ctx.Err() == nil {

// Accept first conn immediately and then wait for downstream listeners
// to request new conns.
// Accept new conn and spawn a goroutine where it is read until
// either:
// - It matches one of the configured protocols and is sent downstream
// to the corresponding protocol listener
// - It does not match any of the configured protocols, an error
// occurs, or mux.connClassificationTimeout elapses before the conn
// is classified and the conn is closed
// New conns are accepted, and classified, continuously even if the
// downstream consumers are not ready to process them, which could
// result in spawning many goroutines that become blocked until the
// downstream consumers manage to catch up. Although, this scenario
// should be unlikely because the producer - accepting new conns - is
// bounded by network I/O and the consumer is not. Generally, the
// consumer continuously loops accepting new conns, from its
// corresponding protocol listener, and immediately spawns a goroutine
// to handle each new conn after it is accepted.

conn, err := mux.innerListener.Accept()
if err != nil {
if mux.ctx.Err() == nil {
log.WithTraceFields(LogFields{"error": err}).Debug("accept failed")
// TODO: add backoff before continue?
}
continue
}

go func() {

var acc bytes.Buffer
b := make([]byte, readBufferSize)
type classifiedConnResult struct {
index int
acc bytes.Buffer
err error
errLogLevel logrus.Level
}

for mux.ctx.Err() == nil {
resultChannel := make(chan *classifiedConnResult, 2)

n, err := conn.Read(b)
if err != nil {
log.WithTraceFields(LogFields{"error": err}).Debug("read conn failed")
break // conn will be closed
}
var connClassifiedAfterFunc *time.Timer

acc.Write(b[:n])
if mux.connClassificationTimeout > 0 {
connClassifiedAfterFunc = time.AfterFunc(mux.connClassificationTimeout, func() {
resultChannel <- &classifiedConnResult{
err: std_errors.New("conn classification timeout"),
errLogLevel: logrus.DebugLevel,
}
})
}

for i, detector := range mux.classifiers {
go func() {
var acc bytes.Buffer
b := make([]byte, readBufferSize)

if acc.Len() >= detector.minBytesToMatch {
for mux.ctx.Err() == nil {

if detector.match(acc.Bytes()) {
n, err := conn.Read(b)
if err != nil {
resultChannel <- &classifiedConnResult{
err: errors.TraceMsg(err, "read conn failed"),
errLogLevel: logrus.DebugLevel,
}
return
}

// Found a match, replay buffered bytes in new conn
// and downstream.
go func() {
bConn := newBufferedConn(conn, acc)
select {
case mux.conns[i] <- bConn:
case <-mux.ctx.Done():
bConn.Close()
}
}()
acc.Write(b[:n])

for i, classifier := range mux.classifiers {
if acc.Len() >= classifier.minBytesToMatch && classifier.match(acc.Bytes()) {
resultChannel <- &classifiedConnResult{
index: i,
acc: acc,
}
return
}
}

if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
// No match. Sample does not match any classifier and is
// longer than required by each.
resultChannel <- &classifiedConnResult{
err: std_errors.New("no classifier match for conn"),
errLogLevel: logrus.WarnLevel,
}
return
}
}

if maxBytesToMatch != 0 && acc.Len() > maxBytesToMatch {
resultChannel <- &classifiedConnResult{
err: mux.ctx.Err(),
errLogLevel: logrus.DebugLevel,
}
}()

// No match. Sample does not match any detector and is
// longer than required by each.
log.WithTrace().Warning("no detector match for conn")
result := <-resultChannel

break // conn will be closed
if connClassifiedAfterFunc != nil {
connClassifiedAfterFunc.Stop()
}

if result.err != nil {
log.WithTraceFields(LogFields{"error": result.err}).Log(result.errLogLevel, "conn classification failed")

err := conn.Close()
if err != nil {
log.WithTraceFields(LogFields{"error": err}).Debug("close failed")
}
return
}

// cleanup conn
err := conn.Close()
if err != nil {
log.WithTraceFields(LogFields{"error": err}).Debug("close conn failed")
// Found a match, replay buffered bytes in new conn and send
// downstream.
// TODO: subtract the time it took to classify the conn from the
// subsequent SSH handshake timeout (sshHandshakeTimeout).
bConn := newBufferedConn(conn, result.acc)
select {
case mux.conns[result.index] <- bConn:
case <-mux.ctx.Done():
bConn.Close()
}
}()

// Wait for one of the downstream listeners to request another conn.
select {
case <-mux.accept:
case <-mux.ctx.Done():
return mux.ctx.Err()
}
}

return mux.ctx.Err()
Expand All @@ -199,11 +253,6 @@ func (mux *protocolDemux) acceptForIndex(index int) (net.Conn, error) {
for mux.ctx.Err() == nil {
select {
case conn := <-mux.conns[index]:
// trigger another accept
select {
case mux.accept <- struct{}{}:
default: // don't block when a signal is already buffered
}
return conn, nil
case <-mux.ctx.Done():
return nil, errors.Trace(mux.ctx.Err())
Expand Down
2 changes: 1 addition & 1 deletion psiphon/server/demux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func runProtocolDemuxTest(tt *protocolDemuxTest) error {
}
}()

mux, protoListeners := newProtocolDemux(context.Background(), l, tt.classifiers)
mux, protoListeners := newProtocolDemux(context.Background(), l, tt.classifiers, 0)

errs := make([]chan error, len(protoListeners))
for i := range errs {
Expand Down
7 changes: 6 additions & 1 deletion psiphon/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1652,7 +1652,12 @@ func checkExpectedServerTunnelLogFields(
return fmt.Errorf("unexpected host_id '%s'", fields["host_id"])
}

if fields["relay_protocol"].(string) != runConfig.tunnelProtocol {
expectedRelayProtocol := runConfig.tunnelProtocol
if runConfig.clientTunnelProtocol != "" {
expectedRelayProtocol = runConfig.clientTunnelProtocol
}

if fields["relay_protocol"].(string) != expectedRelayProtocol {
return fmt.Errorf("unexpected relay_protocol '%s'", fields["relay_protocol"])
}

Expand Down
3 changes: 3 additions & 0 deletions psiphon/server/tlsTunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type TLSTunnelServer struct {
obfuscatorSeedHistory *obfuscator.SeedHistory
}

// ListenTLSTunnel returns the listener of a new TLSTunnelServer.
// Note: the first Read or Write call on a connection returned by the listener
// will trigger the underlying TLS handshake.
func ListenTLSTunnel(
support *SupportServices,
listener net.Listener,
Expand Down
33 changes: 22 additions & 11 deletions psiphon/server/tunnelServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,7 @@ func (sshServer *sshServer) runListener(sshListener *sshListener, listenerError

if protocol.TunnelProtocolUsesMeekHTTP(sshListener.tunnelProtocol) || protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol) {

_, passthroughEnabled := sshServer.support.Config.TunnelProtocolPassthroughAddresses[sshListener.tunnelProtocol]

// Only use meek/TLS-OSSH demux if unfronted meek HTTPS with non-legacy passthrough.
useTLSDemux := protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol) && !protocol.TunnelProtocolUsesFrontedMeek(sshListener.tunnelProtocol) && passthroughEnabled && !sshServer.support.Config.LegacyPassthrough

if useTLSDemux {
if sshServer.tunnelProtocolUsesTLSDemux(sshListener.tunnelProtocol) {

sshServer.runMeekTLSOSSHDemuxListener(sshListener, listenerError, handleClient)

Expand Down Expand Up @@ -567,7 +562,7 @@ func (sshServer *sshServer) runListener(sshListener *sshListener, listenerError

} else {

runListener(sshListener.Listener, sshServer.shutdownBroadcast, listenerError, handleClient)
runListener(sshListener.Listener, sshServer.shutdownBroadcast, listenerError, "", handleClient)
}
}

Expand Down Expand Up @@ -606,7 +601,7 @@ func (sshServer *sshServer) runMeekTLSOSSHDemuxListener(sshListener *sshListener
return
}

mux, listeners := newProtocolDemux(context.Background(), listener, []protocolClassifier{meekClassifier, tlsClassifier})
mux, listeners := newProtocolDemux(context.Background(), listener, []protocolClassifier{meekClassifier, tlsClassifier}, sshServer.support.Config.sshHandshakeTimeout)

var wg sync.WaitGroup

Expand Down Expand Up @@ -651,7 +646,8 @@ func (sshServer *sshServer) runMeekTLSOSSHDemuxListener(sshListener *sshListener

defer wg.Done()

runListener(listeners[1], sshServer.shutdownBroadcast, listenerError, handleClient)
// Override the listener tunnel protocol to report TLS-OSSH instead.
runListener(listeners[1], sshServer.shutdownBroadcast, listenerError, protocol.TUNNEL_PROTOCOL_TLS_OBFUSCATED_SSH, handleClient)
}()

wg.Add(1)
Expand Down Expand Up @@ -691,7 +687,7 @@ func (sshServer *sshServer) runMeekTLSOSSHDemuxListener(sshListener *sshListener
wg.Wait()
}

func runListener(listener net.Listener, shutdownBroadcast <-chan struct{}, listenerError chan<- error, handleClient func(clientTunnelProtocol string, clientConn net.Conn)) {
func runListener(listener net.Listener, shutdownBroadcast <-chan struct{}, listenerError chan<- error, overrideTunnelProtocol string, handleClient func(clientTunnelProtocol string, clientConn net.Conn)) {
for {
conn, err := listener.Accept()

Expand All @@ -718,7 +714,7 @@ func runListener(listener net.Listener, shutdownBroadcast <-chan struct{}, liste
return
}

handleClient("", conn)
handleClient(overrideTunnelProtocol, conn)
}
}

Expand Down Expand Up @@ -921,6 +917,10 @@ func (sshServer *sshServer) getLoadStats() (
stats["ALL"] = zeroClientStats()
for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
stats[tunnelProtocol] = zeroClientStats()

if sshServer.tunnelProtocolUsesTLSDemux(tunnelProtocol) {
stats[protocol.TUNNEL_PROTOCOL_TLS_OBFUSCATED_SSH] = zeroClientStats()
}
}
return stats
}
Expand Down Expand Up @@ -1542,6 +1542,17 @@ func (sshServer *sshServer) monitorPortForwardDialError(err error) {
}
}

// tunnelProtocolUsesTLSDemux returns true if the server demultiplexes the given
// protocol and TLS-OSSH over the same port.
func (sshServer *sshServer) tunnelProtocolUsesTLSDemux(tunnelProtocol string) bool {
// Only use meek/TLS-OSSH demux if unfronted meek HTTPS with non-legacy passthrough.
if protocol.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) && !protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
_, passthroughEnabled := sshServer.support.Config.TunnelProtocolPassthroughAddresses[tunnelProtocol]
return passthroughEnabled && !sshServer.support.Config.LegacyPassthrough
}
return false
}

type sshClient struct {
sync.Mutex
sshServer *sshServer
Expand Down
8 changes: 8 additions & 0 deletions psiphon/tlsTunnelConn.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,11 @@ func (conn *TLSTunnelConn) GetMetrics() common.LogFields {
}
return logFields
}

func (conn *TLSTunnelConn) IsClosed() bool {
closer, ok := conn.Conn.(common.Closer)
if !ok {
return false
}
return closer.IsClosed()
}
Loading

0 comments on commit 3f91b1b

Please sign in to comment.