diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go new file mode 100644 index 00000000..5cc4c34c --- /dev/null +++ b/cmd/outline-ss-server/config.go @@ -0,0 +1,116 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "errors" + "fmt" + "net" + + "gopkg.in/yaml.v3" +) + +type ServiceConfig struct { + Listeners []ListenerConfig + Keys []KeyConfig +} + +type ListenerType string + +const ( + listenerTypeTCP ListenerType = "tcp" + listenerTypeUDP ListenerType = "udp" + listenerTypeProxy ListenerType = "proxy_protocol" +) + +type ListenerConfig struct { + Type ListenerType + Address string + Listeners []ListenerConfig +} + +// Validate checks that the config is valid. +func (lc *ListenerConfig) Validate() error { + if lc.Type != listenerTypeTCP && lc.Type != listenerTypeUDP && lc.Type != listenerTypeProxy { + return fmt.Errorf("unsupported listener type: %s", lc.Type) + } + if lc.Address != "" && len(lc.Listeners) > 0 { + return errors.New("cannot specify both `listeners` and `address` on a listener type") + } + if len(lc.Listeners) > 0 { + for _, childLnConfig := range lc.Listeners { + if err := childLnConfig.Validate(); err != nil { + return err + } + } + return nil + } + + host, _, err := net.SplitHostPort(lc.Address) + if err != nil { + return fmt.Errorf("invalid listener address `%s`: %v", lc.Address, err) + } + if ip := net.ParseIP(host); ip == nil { + return fmt.Errorf("address must be IP, found: %s", host) + } + return nil +} + +type KeyConfig struct { + ID string + Cipher string + Secret string +} + +type LegacyKeyServiceConfig struct { + KeyConfig `yaml:",inline"` + Port int +} + +type Config struct { + Services []ServiceConfig + + // Deprecated: `keys` exists for backward compatibility. Prefer to configure + // using the newer `services` format. + Keys []LegacyKeyServiceConfig +} + +// Validate checks that the config is valid. +func (c *Config) Validate() error { + existingListeners := make(map[string]bool) + for _, serviceConfig := range c.Services { + for _, lnConfig := range serviceConfig.Listeners { + key := string(lnConfig.Type) + "/" + lnConfig.Address + if _, exists := existingListeners[key]; exists { + return fmt.Errorf("listener of type %s with address %s already exists.", lnConfig.Type, lnConfig.Address) + } + existingListeners[key] = true + + if err := lnConfig.Validate(); err != nil { + return err + } + } + } + return nil +} + +// readConfig attempts to read a config from a filename and parses it as a [Config]. +func readConfig(configData []byte) (*Config, error) { + config := Config{} + if err := yaml.Unmarshal(configData, &config); err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + return &config, nil +} diff --git a/cmd/outline-ss-server/config_example.deprecated.yml b/cmd/outline-ss-server/config_example.deprecated.yml new file mode 100644 index 00000000..8895b86d --- /dev/null +++ b/cmd/outline-ss-server/config_example.deprecated.yml @@ -0,0 +1,15 @@ +keys: + - id: user-0 + port: 9000 + cipher: chacha20-ietf-poly1305 + secret: Secret0 + + - id: user-1 + port: 9000 + cipher: chacha20-ietf-poly1305 + secret: Secret1 + + - id: user-2 + port: 9001 + cipher: chacha20-ietf-poly1305 + secret: Secret2 diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index 8895b86d..b0bba3ed 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -1,15 +1,31 @@ -keys: - - id: user-0 - port: 9000 - cipher: chacha20-ietf-poly1305 - secret: Secret0 +services: + - listeners: + # TODO(sbruens): Allow a string-based listener config, as a convenient short-form + # to create a direct listener, e.g. `- tcp/[::]:9000`. + - type: tcp + address: "[::]:9000" + - type: udp + address: "[::]:9000" + - type: proxy_protocol + listeners: + - type: tcp + address: "[::]:9010" + - type: udp + address: "[::]:9010" + keys: + - id: user-0 + cipher: chacha20-ietf-poly1305 + secret: Secret0 + - id: user-1 + cipher: chacha20-ietf-poly1305 + secret: Secret1 - - id: user-1 - port: 9000 - cipher: chacha20-ietf-poly1305 - secret: Secret1 - - - id: user-2 - port: 9001 - cipher: chacha20-ietf-poly1305 - secret: Secret2 + - listeners: + - type: tcp + address: "[::]:9001" + - type: udp + address: "[::]:9001" + keys: + - id: user-2 + cipher: chacha20-ietf-poly1305 + secret: Secret2 diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go new file mode 100644 index 00000000..b2d60c2a --- /dev/null +++ b/cmd/outline-ss-server/config_test.go @@ -0,0 +1,174 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateConfigFails(t *testing.T) { + tests := []struct { + name string + cfg *Config + }{ + { + name: "WithUnknownListenerType", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: "foo", Address: "[::]:9000"}, + }, + }, + }, + }, + }, + { + name: "WithInvalidListenerAddress", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "tcp/[::]:9000"}, + }, + }, + }, + }, + }, + { + name: "WithHostnameAddress", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "example.com:9000"}, + }, + }, + }, + }, + }, + { + name: "WithDuplicateListeners", + cfg: &Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, + }, + }, + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, + }, + }, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.cfg.Validate() + require.Error(t, err) + }) + } +} + +func TestReadConfig(t *testing.T) { + config, err := readConfigFile("./config_example.yml") + + require.NoError(t, err) + expected := Config{ + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, + ListenerConfig{Type: listenerTypeUDP, Address: "[::]:9000"}, + ListenerConfig{ + Type: listenerTypeProxy, + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9010"}, + ListenerConfig{Type: listenerTypeUDP, Address: "[::]:9010"}, + }, + }, + }, + Keys: []KeyConfig{ + KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + KeyConfig{"user-1", "chacha20-ietf-poly1305", "Secret1"}, + }, + }, + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9001"}, + ListenerConfig{Type: listenerTypeUDP, Address: "[::]:9001"}, + }, + Keys: []KeyConfig{ + KeyConfig{"user-2", "chacha20-ietf-poly1305", "Secret2"}, + }, + }, + }, + } + require.Equal(t, expected, *config) +} + +func TestReadConfigParsesDeprecatedFormat(t *testing.T) { + config, err := readConfigFile("./config_example.deprecated.yml") + + require.NoError(t, err) + expected := Config{ + Keys: []LegacyKeyServiceConfig{ + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-0", Cipher: "chacha20-ietf-poly1305", Secret: "Secret0"}, + Port: 9000, + }, + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-1", Cipher: "chacha20-ietf-poly1305", Secret: "Secret1"}, + Port: 9000, + }, + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-2", Cipher: "chacha20-ietf-poly1305", Secret: "Secret2"}, + Port: 9001, + }, + }, + } + require.Equal(t, expected, *config) +} + +func TestReadConfigFromEmptyFile(t *testing.T) { + file, _ := os.CreateTemp("", "empty.yaml") + + config, err := readConfigFile(file.Name()) + + require.NoError(t, err) + require.ElementsMatch(t, Config{}, config) +} + +func TestReadConfigFromIncorrectFormatFails(t *testing.T) { + file, _ := os.CreateTemp("", "empty.yaml") + file.WriteString("foo") + + config, err := readConfigFile(file.Name()) + + require.Error(t, err) + require.ElementsMatch(t, Config{}, config) +} + +func readConfigFile(filename string) (*Config, error) { + configData, _ := os.ReadFile(filename) + return readConfig(configData) +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index e73506a8..75e31db9 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -22,7 +22,9 @@ import ( "net/http" "os" "os/signal" + "strconv" "strings" + "sync" "syscall" "time" @@ -33,7 +35,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/term" - "gopkg.in/yaml.v2" ) var logger *logging.Logger @@ -58,129 +59,320 @@ func init() { logger = logging.MustGetLogger("") } -type ssPort struct { - tcpListener *net.TCPListener - packetConn net.PacketConn - cipherList service.CipherList -} - type SSServer struct { + stopConfig func() error + lnManager service.ListenerManager natTimeout time.Duration m *outlineMetrics replayCache service.ReplayCache - ports map[int]*ssPort } -func (s *SSServer) startPort(portNum int) error { - listener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: portNum}) +func (s *SSServer) loadConfig(filename string) error { + configData, err := os.ReadFile(filename) if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks TCP service failed to start on port %v: %w", portNum, err) + return fmt.Errorf("failed to read config file %s: %w", filename, err) } - logger.Infof("Shadowsocks TCP service listening on %v", listener.Addr().String()) - packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum}) + config, err := readConfig(configData) if err != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks UDP service failed to start on port %v: %w", portNum, err) + return fmt.Errorf("failed to load config (%v): %w", filename, err) } - logger.Infof("Shadowsocks UDP service listening on %v", packetConn.LocalAddr().String()) - port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()} - authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m) - // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout) - packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m) - s.ports[portNum] = port - go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle) - go packetHandler.Handle(port.packetConn) - return nil -} - -func (s *SSServer) removePort(portNum int) error { - port, ok := s.ports[portNum] - if !ok { - return fmt.Errorf("port %v doesn't exist", portNum) + if err := config.Validate(); err != nil { + return fmt.Errorf("failed to validate config: %w", err) } - tcpErr := port.tcpListener.Close() - udpErr := port.packetConn.Close() - delete(s.ports, portNum) - if tcpErr != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks TCP service on port %v failed to stop: %w", portNum, tcpErr) + + // We hot swap the config by having the old and new listeners both live at + // the same time. This means we create listeners for the new config first, + // and then close the old ones after. + sopConfig, err := s.runConfig(*config) + if err != nil { + return err } - logger.Infof("Shadowsocks TCP service on port %v stopped", portNum) - if udpErr != nil { - //lint:ignore ST1005 Shadowsocks is capitalized. - return fmt.Errorf("Shadowsocks UDP service on port %v failed to stop: %w", portNum, udpErr) + if err := s.Stop(); err != nil { + return fmt.Errorf("unable to stop old config: %v", err) } - logger.Infof("Shadowsocks UDP service on port %v stopped", portNum) + s.stopConfig = sopConfig return nil } -func (s *SSServer) loadConfig(filename string) error { - config, err := readConfig(filename) - if err != nil { - return fmt.Errorf("failed to load config (%v): %w", filename, err) +func newCipherListFromConfig(config ServiceConfig) (service.CipherList, error) { + type cipherKey struct { + cipher string + secret string } - - portChanges := make(map[int]int) - portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. + cipherList := list.New() + existingCiphers := make(map[cipherKey]bool) for _, keyConfig := range config.Keys { - portChanges[keyConfig.Port] = 1 - cipherList, ok := portCiphers[keyConfig.Port] - if !ok { - cipherList = list.New() - portCiphers[keyConfig.Port] = cipherList + key := cipherKey{keyConfig.Cipher, keyConfig.Secret} + if _, exists := existingCiphers[key]; exists { + logger.Debugf("encryption key already exists for ID=`%v`. Skipping.", keyConfig.ID) + continue } cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) if err != nil { - return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + return nil, fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) } entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) cipherList.PushBack(&entry) + existingCiphers[key] = true + } + ciphers := service.NewCipherList() + ciphers.Update(cipherList) + + return ciphers, nil +} + +func (s *SSServer) NewShadowsocksStreamHandler(ciphers service.CipherList) service.StreamHandler { + authFunc := service.NewShadowsocksStreamAuthenticator(ciphers, &s.replayCache, s.m) + // TODO: Register initial data metrics at zero. + return service.NewStreamHandler(authFunc, s.m, tcpReadTimeout) +} + +func (s *SSServer) NewShadowsocksPacketHandler(ciphers service.CipherList) service.PacketHandler { + return service.NewPacketHandler(s.natTimeout, ciphers, s.m) +} + +func (s *SSServer) NewShadowsocksStreamHandlerFromConfig(config ServiceConfig) (service.StreamHandler, error) { + ciphers, err := newCipherListFromConfig(config) + if err != nil { + return nil, err + } + return s.NewShadowsocksStreamHandler(ciphers), nil +} + +func (s *SSServer) NewShadowsocksPacketHandlerFromConfig(config ServiceConfig) (service.PacketHandler, error) { + ciphers, err := newCipherListFromConfig(config) + if err != nil { + return nil, err + } + return s.NewShadowsocksPacketHandler(ciphers), nil +} + +type listenerSet struct { + manager service.ListenerManager + listenerCloseFuncs map[string]func() error + listenersMu sync.Mutex +} + +// ListenStream announces on a given network address. Trying to listen for stream connections +// on the same address twice will result in an error. +func (ls *listenerSet) ListenStream(addr string, proxy bool) (service.StreamListener, error) { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + + lnKey := "stream/" + addr + if _, exists := ls.listenerCloseFuncs[lnKey]; exists { + return nil, fmt.Errorf("stream listener for %s already exists", addr) + } + ln, err := ls.manager.ListenStream(addr, proxy) + if err != nil { + return nil, err + } + ls.listenerCloseFuncs[lnKey] = ln.Close + return ln, nil +} + +// ListenPacket announces on a given network address. Trying to listen for packet connections +// on the same address twice will result in an error. +func (ls *listenerSet) ListenPacket(addr string, proxy bool) (net.PacketConn, error) { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + + lnKey := "packet/" + addr + if _, exists := ls.listenerCloseFuncs[lnKey]; exists { + return nil, fmt.Errorf("packet listener for %s already exists", addr) } - for port := range s.ports { - portChanges[port] = portChanges[port] - 1 + ln, err := ls.manager.ListenPacket(addr, proxy) + if err != nil { + return nil, err } - for portNum, count := range portChanges { - if count == -1 { - if err := s.removePort(portNum); err != nil { - return fmt.Errorf("failed to remove port %v: %w", portNum, err) + ls.listenerCloseFuncs[lnKey] = ln.Close + return ln, nil +} + +// Close closes all the listeners in the set, after which the set can't be used again. +func (ls *listenerSet) Close() error { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + + for addr, listenerCloseFunc := range ls.listenerCloseFuncs { + if err := listenerCloseFunc(); err != nil { + return fmt.Errorf("listener on address %s failed to stop: %w", addr, err) + } + } + ls.listenerCloseFuncs = nil + return nil +} + +// Len returns the number of listeners in the set. +func (ls *listenerSet) Len() int { + return len(ls.listenerCloseFuncs) +} + +func (s *SSServer) runConfig(config Config) (func() error, error) { + startErrCh := make(chan error) + stopErrCh := make(chan error) + stopCh := make(chan struct{}) + + go func() { + lnSet := &listenerSet{ + manager: s.lnManager, + listenerCloseFuncs: make(map[string]func() error), + } + defer func() { + stopErrCh <- lnSet.Close() + }() + + startErrCh <- func() error { + totalCipherCount := len(config.Keys) + portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. + for _, keyConfig := range config.Keys { + cipherList, ok := portCiphers[keyConfig.Port] + if !ok { + cipherList = list.New() + portCiphers[keyConfig.Port] = cipherList + } + cryptoKey, err := shadowsocks.NewEncryptionKey(keyConfig.Cipher, keyConfig.Secret) + if err != nil { + return fmt.Errorf("failed to create encyption key for key %v: %w", keyConfig.ID, err) + } + entry := service.MakeCipherEntry(keyConfig.ID, cryptoKey, keyConfig.Secret) + cipherList.PushBack(&entry) } - } else if count == +1 { - if err := s.startPort(portNum); err != nil { - return err + for portNum, cipherList := range portCiphers { + addr := net.JoinHostPort("::", strconv.Itoa(portNum)) + + ciphers := service.NewCipherList() + ciphers.Update(cipherList) + + sh := s.NewShadowsocksStreamHandler(ciphers) + ln, err := lnSet.ListenStream(addr, false) + if err != nil { + return err + } + logger.Infof("Shadowsocks service listening on tcp/%s", ln.Addr().String()) + go service.StreamServe(ln.AcceptStream, sh.Handle) + + pc, err := lnSet.ListenPacket(addr, false) + if err != nil { + return err + } + logger.Infof("Shadowsocks service listening on udp/%s", pc.LocalAddr().String()) + ph := s.NewShadowsocksPacketHandler(ciphers) + go ph.Handle(pc) + } + + for _, serviceConfig := range config.Services { + var ( + sh service.StreamHandler + ph service.PacketHandler + ) + for _, lnConfig := range serviceConfig.Listeners { + err := s.startListenerFromConfig(lnSet, serviceConfig, lnConfig, false, sh, ph) + if err != nil { + return err + } + } + totalCipherCount += len(serviceConfig.Keys) } + + logger.Infof("Loaded %d access keys over %d listeners", totalCipherCount, lnSet.Len()) + s.m.SetNumAccessKeys(totalCipherCount, lnSet.Len()) + return nil + }() + + <-stopCh + }() + + err := <-startErrCh + if err != nil { + return nil, err + } + return func() error { + logger.Infof("Stopping running config.") + // TODO(sbruens): Actually wait for all handlers to be stopped, e.g. by + // using a https://pkg.go.dev/sync#WaitGroup. + stopCh <- struct{}{} + stopErr := <-stopErrCh + return stopErr + }, nil +} + +func (s *SSServer) startListenerFromConfig(lnSet *listenerSet, serviceConfig ServiceConfig, lnConfig ListenerConfig, proxy bool, sh service.StreamHandler, ph service.PacketHandler) error { + lnLogFunc := func(key string) { + var serviceToLog string + if proxy { + serviceToLog = "Proxy" + } else { + serviceToLog = "Shadowsocks" } + logger.Infof("%s service listening on %s", serviceToLog, key) } - for portNum, cipherList := range portCiphers { - s.ports[portNum].cipherList.Update(cipherList) + switch lnConfig.Type { + case listenerTypeTCP: + ln, err := lnSet.ListenStream(lnConfig.Address, proxy) + if err != nil { + return err + } + lnLogFunc("tcp/" + ln.Addr().String()) + if sh == nil { + sh, err = s.NewShadowsocksStreamHandlerFromConfig(serviceConfig) + if err != nil { + return err + } + } + go service.StreamServe(ln.AcceptStream, sh.Handle) + + case listenerTypeUDP: + pc, err := lnSet.ListenPacket(lnConfig.Address, proxy) + if err != nil { + return err + } + lnLogFunc("udp/" + pc.LocalAddr().String()) + if ph == nil { + ph, err = s.NewShadowsocksPacketHandlerFromConfig(serviceConfig) + if err != nil { + return err + } + } + go ph.Handle(pc) + + case listenerTypeProxy: + for _, proxyLnConfig := range lnConfig.Listeners { + err := s.startListenerFromConfig(lnSet, serviceConfig, proxyLnConfig, true, sh, ph) + if err != nil { + return err + } + } } - logger.Infof("Loaded %v access keys over %v ports", len(config.Keys), len(s.ports)) - s.m.SetNumAccessKeys(len(config.Keys), len(portCiphers)) + return nil } -// Stop serving on all ports. +// Stop stops serving the current config. func (s *SSServer) Stop() error { - for portNum := range s.ports { - if err := s.removePort(portNum); err != nil { - return err - } + stopFunc := s.stopConfig + if stopFunc == nil { + return nil + } + if err := stopFunc(); err != nil { + logger.Errorf("Error stopping config: %v", err) + return err } + logger.Info("Stopped all listeners for running config") return nil } // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ + lnManager: service.NewListenerManager(), natTimeout: natTimeout, m: sm, replayCache: service.NewReplayCache(replayHistory), - ports: make(map[int]*ssPort), } err := server.loadConfig(filename) if err != nil { - return nil, fmt.Errorf("failed configure server: %w", err) + return nil, fmt.Errorf("failed to configure server: %w", err) } sigHup := make(chan os.Signal, 1) signal.Notify(sigHup, syscall.SIGHUP) @@ -195,28 +387,6 @@ func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, return server, nil } -type Config struct { - Keys []struct { - ID string - Port int - Cipher string - Secret string - } -} - -func readConfig(filename string) (*Config, error) { - config := Config{} - configData, err := os.ReadFile(filename) - if err != nil { - return nil, fmt.Errorf("failed to read config: %w", err) - } - err = yaml.Unmarshal(configData, &config) - if err != nil { - return nil, fmt.Errorf("failed to parse config: %w", err) - } - return &config, nil -} - func main() { var flags struct { ConfigFile string diff --git a/go.mod b/go.mod index 04a9ddab..b245aace 100644 --- a/go.mod +++ b/go.mod @@ -7,12 +7,13 @@ require ( github.com/goreleaser/goreleaser v1.18.2 github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 github.com/oschwald/geoip2-golang v1.8.0 + github.com/pires/go-proxyproto v0.7.0 github.com/prometheus/client_golang v1.15.0 github.com/shadowsocks/go-shadowsocks2 v0.1.5 github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 golang.org/x/term v0.16.0 - gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -263,7 +264,7 @@ require ( gopkg.in/src-d/go-billy.v4 v4.3.2 // indirect gopkg.in/src-d/go-git.v4 v4.13.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect k8s.io/klog/v2 v2.90.0 // indirect mvdan.cc/sh/v3 v3.7.0 // indirect sigs.k8s.io/kind v0.17.0 // indirect diff --git a/go.sum b/go.sum index 213e3f6d..0bf21fcf 100644 --- a/go.sum +++ b/go.sum @@ -1906,6 +1906,8 @@ github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha github.com/performancecopilot/speed/v4 v4.0.0/go.mod h1:qxrSyuDGrTOWfV+uKRFhfxw6h/4HXRGUiZiufxo49BM= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index 43109b7a..c2ef215f 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -133,11 +133,11 @@ func TestTCPEcho(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(func() (transport.StreamConn, error) { return proxyListener.AcceptTCP() }, handler.Handle) + service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -202,10 +202,10 @@ func TestRestrictedAddresses(t *testing.T) { const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -384,11 +384,11 @@ func BenchmarkTCPThroughput(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -448,11 +448,11 @@ func BenchmarkTCPMultiplexing(b *testing.B) { const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPMetrics{} authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := service.NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := service.NewStreamHandler(authFunc, testMetrics, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) go func() { - service.StreamServe(service.WrapStreamListener(proxyListener.AcceptTCP), handler.Handle) + service.StreamServe(service.WrapStreamAcceptFunc(proxyListener.AcceptTCP), handler.Handle) done <- struct{}{} }() diff --git a/service/listeners.go b/service/listeners.go new file mode 100644 index 00000000..25974a94 --- /dev/null +++ b/service/listeners.go @@ -0,0 +1,442 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + + "github.com/Jigsaw-Code/outline-sdk/transport" +) + +// The implementations of listeners for different network types are not +// interchangeable. The type of listener depends on the network type. +type Listener = io.Closer + +// ClientStreamConn wraps a [transport.StreamConn] and sets the client source of the connection. +// This is useful for handling the PROXY protocol where the RemoteAddr() points to the +// server/load balancer address and we need the perceived source of the connection. +type ClientStreamConn interface { + transport.StreamConn + ClientAddr() net.Addr +} + +type clientStreamConn struct { + transport.StreamConn + clientAddr net.Addr +} + +func (c *clientStreamConn) ClientAddr() net.Addr { + return c.clientAddr +} + +// StreamListener is a network listener for stream-oriented protocols that +// accepts [transport.StreamConn] connections. +type StreamListener interface { + // Accept waits for and returns the next connection to the listener. + AcceptStream() (ClientStreamConn, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. This + // stops the current listener from accepting new connections without closing + // the underlying socket. Only when the last user of the underlying socket + // closes it, do we actually close it. + Close() error + + // Addr returns the listener's network address. + Addr() net.Addr +} + +type TCPListener struct { + ln *net.TCPListener +} + +var _ StreamListener = (*TCPListener)(nil) + +func (t *TCPListener) AcceptStream() (ClientStreamConn, error) { + conn, err := t.ln.AcceptTCP() + if err != nil { + return nil, err + } + return &clientStreamConn{StreamConn: conn, clientAddr: conn.RemoteAddr()}, err +} + +func (t *TCPListener) Close() error { + return t.ln.Close() +} + +func (t *TCPListener) Addr() net.Addr { + return t.ln.Addr() +} + +type OnCloseFunc func() error + +type acceptResponse struct { + conn ClientStreamConn + err error +} + +type virtualStreamListener struct { + mu sync.Mutex // Mutex to protect access to the channels + addr net.Addr + acceptCh <-chan acceptResponse + closeCh chan struct{} + closed bool + onCloseFunc OnCloseFunc +} + +var _ StreamListener = (*virtualStreamListener)(nil) + +func (sl *virtualStreamListener) AcceptStream() (ClientStreamConn, error) { + sl.mu.Lock() + acceptCh := sl.acceptCh + sl.mu.Unlock() + + select { + case acceptResponse, ok := <-acceptCh: + if !ok { + return nil, net.ErrClosed + } + return acceptResponse.conn, acceptResponse.err + case <-sl.closeCh: + return nil, net.ErrClosed + } +} + +func (sl *virtualStreamListener) Close() error { + sl.mu.Lock() + if sl.closed { + sl.mu.Unlock() + return nil + } + sl.closed = true + sl.acceptCh = nil + close(sl.closeCh) + sl.mu.Unlock() + + if sl.onCloseFunc != nil { + return sl.onCloseFunc() + } + return nil +} + +func (sl *virtualStreamListener) Addr() net.Addr { + return sl.addr +} + +type packetResponse struct { + n int + addr net.Addr + err error + data []byte +} + +type virtualPacketConn struct { + net.PacketConn + mu sync.Mutex // Mutex to protect access to the channels + readCh <-chan packetResponse + closeCh chan struct{} + closed bool + onCloseFunc OnCloseFunc +} + +func (pc *virtualPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + pc.mu.Lock() + readCh := pc.readCh + pc.mu.Unlock() + + select { + case packetResponse, ok := <-readCh: + if !ok { + return 0, nil, net.ErrClosed + } + copy(p, packetResponse.data) + return packetResponse.n, packetResponse.addr, packetResponse.err + case <-pc.closeCh: + return 0, nil, net.ErrClosed + } +} + +func (pc *virtualPacketConn) Close() error { + pc.mu.Lock() + if pc.closed { + pc.mu.Unlock() + return nil + } + pc.closed = true + pc.readCh = nil + close(pc.closeCh) + pc.mu.Unlock() + + if pc.onCloseFunc != nil { + return pc.onCloseFunc() + } + return nil +} + +// MultiListener manages shared listeners. +type MultiListener[T Listener] interface { + // Acquire creates a new listener from the shared listener. Listeners can overlap + // one another (e.g. during config changes the new config is started before the + // old config is destroyed), which is done by creating virtual listeners that wrap + // the shared listener. These virtual listeners do not actually close the + // underlying socket until all uses of the shared listener have been closed. + Acquire() (T, error) +} + +type multiStreamListener struct { + mu sync.Mutex + addr string + proxy bool + ln RefCount[StreamListener] + acceptCh chan acceptResponse + onCloseFunc OnCloseFunc +} + +// NewMultiStreamListener creates a new stream-based [MultiListener]. +func NewMultiStreamListener(addr string, proxy bool, onCloseFunc OnCloseFunc) MultiListener[StreamListener] { + return &multiStreamListener{ + addr: addr, + proxy: proxy, + onCloseFunc: onCloseFunc, + } +} + +func (m *multiStreamListener) Acquire() (StreamListener, error) { + refCount, err := func() (RefCount[StreamListener], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.ln == nil { + tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + var sl StreamListener + sl = &TCPListener{ln} + if m.proxy { + sl = &ProxyStreamListener{StreamListener: sl} + } + m.ln = NewRefCount[StreamListener](sl, m.onCloseFunc) + m.acceptCh = make(chan acceptResponse) + go func() { + for { + conn, err := sl.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(m.acceptCh) + return + } + m.acceptCh <- acceptResponse{conn, err} + } + }() + } + return m.ln, nil + }() + if err != nil { + return nil, err + } + + sl := refCount.Acquire() + return &virtualStreamListener{ + addr: sl.Addr(), + acceptCh: m.acceptCh, + closeCh: make(chan struct{}), + onCloseFunc: refCount.Close, + }, nil +} + +type multiPacketListener struct { + mu sync.Mutex + addr string + pc RefCount[net.PacketConn] + readCh chan packetResponse + onCloseFunc OnCloseFunc +} + +// NewMultiPacketListener creates a new packet-based [MultiListener]. +func NewMultiPacketListener(addr string, onCloseFunc OnCloseFunc) MultiListener[net.PacketConn] { + return &multiPacketListener{ + addr: addr, + onCloseFunc: onCloseFunc, + } +} + +func (m *multiPacketListener) Acquire() (net.PacketConn, error) { + refCount, err := func() (RefCount[net.PacketConn], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.pc == nil { + pc, err := net.ListenPacket("udp", m.addr) + if err != nil { + return nil, err + } + m.pc = NewRefCount(pc, m.onCloseFunc) + m.readCh = make(chan packetResponse) + go func() { + for { + buffer := make([]byte, serverUDPBufferSize) + n, addr, err := pc.ReadFrom(buffer) + if err != nil { + close(m.readCh) + return + } + m.readCh <- packetResponse{n: n, addr: addr, err: err, data: buffer[:n]} + } + }() + } + return m.pc, nil + }() + if err != nil { + return nil, err + } + + pc := refCount.Acquire() + return &virtualPacketConn{ + PacketConn: pc, + readCh: m.readCh, + closeCh: make(chan struct{}), + onCloseFunc: refCount.Close, + }, nil +} + +// ListenerManager holds the state of shared listeners. +type ListenerManager interface { + // ListenStream creates a new stream listener for a given address. + ListenStream(addr string, proxy bool) (StreamListener, error) + + // ListenPacket creates a new packet listener for a given address. + ListenPacket(addr string, proxy bool) (net.PacketConn, error) +} + +type listenerManager struct { + streamListeners map[string]MultiListener[StreamListener] + packetListeners map[string]MultiListener[net.PacketConn] + mu sync.Mutex +} + +// NewListenerManager creates a new [ListenerManger]. +func NewListenerManager() ListenerManager { + return &listenerManager{ + streamListeners: make(map[string]MultiListener[StreamListener]), + packetListeners: make(map[string]MultiListener[net.PacketConn]), + } +} + +func (m *listenerManager) ListenStream(addr string, proxy bool) (StreamListener, error) { + m.mu.Lock() + defer m.mu.Unlock() + + streamLn, exists := m.streamListeners[addr] + if !exists { + streamLn = NewMultiStreamListener( + addr, + proxy, + func() error { + m.mu.Lock() + delete(m.streamListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.streamListeners[addr] = streamLn + } + ln, err := streamLn.Acquire() + if err != nil { + return nil, fmt.Errorf("unable to create stream listener for %s: %v", addr, err) + } + return ln, nil +} + +func (m *listenerManager) ListenPacket(addr string, proxy bool) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + packetLn, exists := m.packetListeners[addr] + if !exists { + packetLn = NewMultiPacketListener( + addr, + func() error { + m.mu.Lock() + delete(m.packetListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.packetListeners[addr] = packetLn + } + + ln, err := packetLn.Acquire() + if err != nil { + return nil, fmt.Errorf("unable to create packet listener for %s: %v", addr, err) + } + return ln, nil +} + +// RefCount is an atomic reference counter that can be used to track a shared +// [io.Closer] resource. +type RefCount[T io.Closer] interface { + io.Closer + + // Acquire increases the ref count and returns the wrapped object. + Acquire() T +} + +type refCount[T io.Closer] struct { + mu sync.Mutex + count *atomic.Int32 + value T + onCloseFunc OnCloseFunc +} + +func NewRefCount[T io.Closer](value T, onCloseFunc OnCloseFunc) RefCount[T] { + r := &refCount[T]{ + count: &atomic.Int32{}, + value: value, + onCloseFunc: onCloseFunc, + } + return r +} + +func (r refCount[T]) Acquire() T { + r.count.Add(1) + return r.value +} + +func (r refCount[T]) Close() error { + // Lock to prevent someone from acquiring while we close the value. + r.mu.Lock() + defer r.mu.Unlock() + + if count := r.count.Add(-1); count == 0 { + err := r.value.Close() + if err != nil { + return err + } + if r.onCloseFunc != nil { + return r.onCloseFunc() + } + return nil + } + return nil +} diff --git a/service/listeners_test.go b/service/listeners_test.go new file mode 100644 index 00000000..0bdb3df8 --- /dev/null +++ b/service/listeners_test.go @@ -0,0 +1,255 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "fmt" + "net" + "testing" + + "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/require" +) + +func TestDirectListenerSetsRemoteAddrAsClientAddr(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &TCPListener{listener} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.Equal(t, conn.RemoteAddr(), conn.ClientAddr()) +} + +func TestProxyProtocolListenerParsesSourceAddressAsClientAddr(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + sourceAddr := &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + } + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: sourceAddr, + DestinationAddr: conn.RemoteAddr(), + } + header.WriteTo(conn) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.True(t, sourceAddr.IP.Equal(conn.ClientAddr().(*net.TCPAddr).IP)) + require.Equal(t, sourceAddr.Port, conn.ClientAddr().(*net.TCPAddr).Port) +} + +func TestProxyProtocolListenerUsesRemoteAddrAsClientAddrIfLocalHeader(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.LOCAL, + TransportProtocol: proxyproto.UNSPEC, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: conn.RemoteAddr(), + } + header.WriteTo(conn) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.Equal(t, conn.RemoteAddr(), conn.ClientAddr()) +} + +func TestListenerManagerStreamListenerEarlyClose(t *testing.T) { + m := NewListenerManager() + ln, err := m.ListenStream("127.0.0.1:0", false) + require.NoError(t, err) + + ln.Close() + _, err = ln.AcceptStream() + + require.ErrorIs(t, err, net.ErrClosed) +} + +func writeTestPayload(ln StreamListener) error { + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + return fmt.Errorf("Failed to dial %v: %v", ln.Addr().String(), err) + } + if _, err = conn.Write(makeTestPayload(50)); err != nil { + return fmt.Errorf("Failed to write to connection: %v", err) + } + conn.Close() + return nil +} + +func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { + m := NewListenerManager() + ln, err := m.ListenStream("127.0.0.1:0", false) + require.NoError(t, err) + ln2, err := m.ListenStream("127.0.0.1:0", false) + require.NoError(t, err) + // Close only the first listener. + ln.Close() + + done := make(chan struct{}) + go func() { + ln2.AcceptStream() + done <- struct{}{} + }() + err = writeTestPayload(ln2) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { + m := NewListenerManager() + // Create a listener and immediately close it. + ln, err := m.ListenStream("127.0.0.1:0", false) + require.NoError(t, err) + ln.Close() + // Now create another listener on the same address. + ln2, err := m.ListenStream("127.0.0.1:0", false) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + ln2.AcceptStream() + done <- struct{}{} + }() + err = writeTestPayload(ln2) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerPacketListenerEarlyClose(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0", false) + require.NoError(t, err) + + pc.Close() + _, _, readErr := pc.ReadFrom(nil) + _, writeErr := pc.WriteTo(nil, &net.UDPAddr{}) + + require.ErrorIs(t, readErr, net.ErrClosed) + require.ErrorIs(t, writeErr, net.ErrClosed) +} + +func TestListenerManagerPacketListenerNotClosedIfStillInUse(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0", false) + require.NoError(t, err) + pc2, err := m.ListenPacket("127.0.0.1:0", false) + require.NoError(t, err) + // Close only the first listener. + pc.Close() + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc.WriteTo(nil, pc2.LocalAddr()) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerPacketListenerCreatesListenerOnDemand(t *testing.T) { + m := NewListenerManager() + // Create a listener and immediately close it. + pc, err := m.ListenPacket("127.0.0.1:0", false) + require.NoError(t, err) + pc.Close() + // Now create another listener on the same address. + pc2, err := m.ListenPacket("127.0.0.1:0", false) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc2.WriteTo(nil, pc2.LocalAddr()) + + require.NoError(t, err) + <-done +} + +type testRefCount struct { + onCloseFunc func() +} + +func (t *testRefCount) Close() error { + t.onCloseFunc() + return nil +} + +func TestRefCount(t *testing.T) { + var objectCloseDone bool + var onCloseFuncDone bool + rc := NewRefCount[*testRefCount]( + &testRefCount{ + onCloseFunc: func() { + objectCloseDone = true + }, + }, + func() error { + onCloseFuncDone = true + return nil + }, + ) + rc.Acquire() + rc.Acquire() + + require.NoError(t, rc.Close()) + require.False(t, objectCloseDone) + require.False(t, onCloseFuncDone) + + require.NoError(t, rc.Close()) + require.True(t, objectCloseDone) + require.True(t, onCloseFuncDone) +} diff --git a/service/proxyproto.go b/service/proxyproto.go new file mode 100644 index 00000000..3978daf9 --- /dev/null +++ b/service/proxyproto.go @@ -0,0 +1,51 @@ +// Copyright 2024 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "bufio" + "errors" + "fmt" + + "github.com/pires/go-proxyproto" +) + +// ProxyStreamListener wraps a [StreamListener] and fetches the source of the connection from the PROXY +// protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. +type ProxyStreamListener struct { + StreamListener +} + +// AcceptStream waits for the next incoming connection, parses the client IP from the PROXY protocol +// header, and adds it to the connection. +func (l *ProxyStreamListener) AcceptStream() (ClientStreamConn, error) { + conn, err := l.StreamListener.AcceptStream() + if err != nil { + return nil, err + } + r := bufio.NewReader(conn) + header, err := proxyproto.Read(r) + if errors.Is(err, proxyproto.ErrNoProxyProtocol) { + logger.Warningf("Received connection from %v without proxy header.", conn.RemoteAddr()) + return conn, nil + } + if header == nil || err != nil { + return nil, fmt.Errorf("error parsing proxy header: %v", err) + } + if header.Command.IsLocal() { + return conn, nil + } + return &clientStreamConn{StreamConn: conn, clientAddr: header.SourceAddr}, nil +} diff --git a/service/tcp.go b/service/tcp.go index ab74ce6a..02da5f3c 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -161,17 +161,16 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa } } -type tcpHandler struct { - listenerId string +type streamHandler struct { m TCPMetrics readTimeout time.Duration authenticate StreamAuthenticateFunc dialer transport.StreamDialer } -// NewTCPService creates a TCPService -func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) TCPHandler { - return &tcpHandler{ +// NewStreamHandler creates a StreamHandler +func NewStreamHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration) StreamHandler { + return &streamHandler{ m: m, readTimeout: timeout, authenticate: authenticate, @@ -188,14 +187,14 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra }}} } -// TCPService is a Shadowsocks TCP service that can be started and stopped. -type TCPHandler interface { - Handle(ctx context.Context, conn transport.StreamConn) +// StreamHandler is a handler that handles stream connections. +type StreamHandler interface { + Handle(ctx context.Context, conn ClientStreamConn) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) } -func (s *tcpHandler) SetTargetDialer(dialer transport.StreamDialer) { +func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) { s.dialer = dialer } @@ -211,20 +210,24 @@ func ensureConnectionError(err error, fallbackStatus string, fallbackMsg string) } } -type StreamListener func() (transport.StreamConn, error) +type StreamAcceptFunc func() (ClientStreamConn, error) -func WrapStreamListener[T transport.StreamConn](f func() (T, error)) StreamListener { - return func() (transport.StreamConn, error) { - return f() +func WrapStreamAcceptFunc[T transport.StreamConn](f func() (T, error)) StreamAcceptFunc { + return func() (ClientStreamConn, error) { + c, err := f() + if err != nil { + return nil, err + } + return &clientStreamConn{StreamConn: c, clientAddr: c.RemoteAddr()}, err } } -type StreamHandler func(ctx context.Context, conn transport.StreamConn) +type StreamHandleFunc func(ctx context.Context, conn ClientStreamConn) // StreamServe repeatedly calls `accept` to obtain connections and `handle` to handle them until // accept() returns [ErrClosed]. When that happens, all connection handlers will be notified // via their [context.Context]. StreamServe will return after all pending handlers return. -func StreamServe(accept StreamListener, handle StreamHandler) { +func StreamServe(accept StreamAcceptFunc, handle StreamHandleFunc) { var running sync.WaitGroup defer running.Wait() ctx, contextCancel := context.WithCancel(context.Background()) @@ -235,7 +238,7 @@ func StreamServe(accept StreamListener, handle StreamHandler) { if errors.Is(err, net.ErrClosed) { break } - logger.Warningf("AcceptTCP failed: %v. Continuing to listen.", err) + logger.Warningf("Accept failed: %v. Continuing to listen.", err) continue } @@ -253,12 +256,12 @@ func StreamServe(accept StreamListener, handle StreamHandler) { } } -func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn) { - clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.RemoteAddr()) +func (h *streamHandler) Handle(ctx context.Context, clientConn ClientStreamConn) { + clientInfo, err := ipinfo.GetIPInfoFromAddr(h.m, clientConn.ClientAddr()) if err != nil { logger.Warningf("Failed client info lookup: %v", err) } - logger.Debugf("Got info \"%#v\" for IP %v", clientInfo, clientConn.RemoteAddr().String()) + logger.Debugf("Got info \"%#v\" for IP %v", clientInfo, clientConn.ClientAddr().String()) h.m.AddOpenTCPConnection(clientInfo) var proxyMetrics metrics.ProxyMetrics measuredClientConn := metrics.MeasureConn(clientConn, &proxyMetrics.ProxyClient, &proxyMetrics.ClientProxy) @@ -272,7 +275,7 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn status = connError.Status logger.Debugf("TCP Error: %v: %v", connError.Message, connError.Cause) } - h.m.AddClosedTCPConnection(clientInfo, clientConn.RemoteAddr(), id, status, proxyMetrics, connDuration) + h.m.AddClosedTCPConnection(clientInfo, clientConn.ClientAddr(), id, status, proxyMetrics, connDuration) measuredClientConn.Close() // Closing after the metrics are added aids integration testing. logger.Debugf("Done with status %v, duration %v", status, connDuration) } @@ -327,7 +330,7 @@ func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr return nil } -func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { +func (h *streamHandler) handleConnection(ctx context.Context, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { // Set a deadline to receive the address to the target. readDeadline := time.Now().Add(h.readTimeout) if deadline, ok := ctx.Deadline(); ok { @@ -369,7 +372,7 @@ func (h *tcpHandler) handleConnection(ctx context.Context, outerConn transport.S // Keep the connection open until we hit the authentication deadline to protect against probing attacks // `proxyMetrics` is a pointer because its value is being mutated by `clientConn`. -func (h *tcpHandler) absorbProbe(clientConn io.ReadCloser, addr, status string, proxyMetrics *metrics.ProxyMetrics) { +func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, addr, status string, proxyMetrics *metrics.ProxyMetrics) { // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) diff --git a/service/tcp_test.go b/service/tcp_test.go index fbe80f7c..428f70e0 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -281,10 +281,10 @@ func TestProbeRandom(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -358,11 +358,11 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -393,11 +393,11 @@ func TestProbeClientBytesBasicModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -429,11 +429,11 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -472,10 +472,10 @@ func TestProbeServerBytesModified(t *testing.T) { cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, 200*time.Millisecond) + handler := NewStreamHandler(authFunc, testMetrics, 200*time.Millisecond) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -503,7 +503,7 @@ func TestReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := NewStreamHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -528,7 +528,7 @@ func TestReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -582,7 +582,7 @@ func TestReverseReplayDefense(t *testing.T) { testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := NewStreamHandler(authFunc, testMetrics, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) cipher := cipherEntry.CryptoKey @@ -598,7 +598,7 @@ func TestReverseReplayDefense(t *testing.T) { done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -653,11 +653,11 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, testMetrics) - handler := NewTCPHandler(authFunc, testMetrics, testTimeout) + handler := NewStreamHandler(authFunc, testMetrics, testTimeout) done := make(chan struct{}) go func() { - StreamServe(WrapStreamListener(listener.AcceptTCP), handler.Handle) + StreamServe(WrapStreamAcceptFunc(listener.AcceptTCP), handler.Handle) done <- struct{}{} }() @@ -717,14 +717,14 @@ func TestStreamServeEarlyClose(t *testing.T) { err = tcpListener.Close() require.NoError(t, err) // This should return quickly, without timing out or calling the handler. - StreamServe(WrapStreamListener(tcpListener.AcceptTCP), nil) + StreamServe(WrapStreamAcceptFunc(tcpListener.AcceptTCP), nil) } // Makes sure the TCP listener returns [io.ErrClosed] on Close(). func TestClosedTCPListenerError(t *testing.T) { tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) require.NoError(t, err) - accept := WrapStreamListener(tcpListener.AcceptTCP) + accept := WrapStreamAcceptFunc(tcpListener.AcceptTCP) err = tcpListener.Close() require.NoError(t, err) _, err = accept()