From 044176e1b9456027c8e3c3242b587740be33fbee Mon Sep 17 00:00:00 2001
From: Tim Ross
Date: Thu, 27 Jan 2022 12:45:06 -0500
Subject: [PATCH] Dynamically resolve reverse tunnel address

The reverse tunnel address is currently a static string that is
retrieved from config and passed around for the duration of a
service's lifetime. When the `tunnel_public_address` is changed on the
proxy and the proxy is then restarted, all reverse tunnels established
over the old address fail indefinitely. As a workaround, #8102
introduced a mechanism that causes nodes to restart if their connection
to the auth server is down for a period of time. While this did allow
nodes to pick up the new address after restarting, it was meant to be a
stopgap until a more robust solution could be applied.

Instead of using a static address, the reverse tunnel address is now
resolved via a `reversetunnel.Resolver`. Anywhere that previously
relied on the static proxy address now fetches the actual reverse
tunnel address from the webclient via the Resolver.

In addition, this builds on the refactoring done in #4290 to further
simplify the reversetunnel package. Since we no longer track multiple
proxies, all the leftover bits that did so have been removed to
accommodate a dynamic reverse tunnel address.
---
 integration/restart_test.go             | 111 -----------------------
 lib/reversetunnel/agent.go              |  25 +++--
 lib/reversetunnel/agentpool.go          |  80 ++++++++--------
 lib/reversetunnel/rc_manager.go         |  42 +++++----
 lib/reversetunnel/rc_manager_test.go    |  66 +++++++++-----
 lib/reversetunnel/resolver.go           |  67 ++++++++++++++
 lib/reversetunnel/resolver_test.go      | 116 ++++++++++++++++++++++++
 lib/reversetunnel/track/tracker.go      |  90 +++++++++---------
 lib/reversetunnel/track/tracker_test.go | 114 +++++++++++------------
 lib/reversetunnel/transport.go          |  30 +++---
 lib/service/connect.go                  |  51 +----------
 lib/service/db.go                       |  24 +++--
 lib/service/desktop.go                  |  15 +--
 lib/service/kubernetes.go               |   2 +-
 lib/service/service.go                  |  59 +++++++-----
 lib/utils/addr.go                       |   2 +-
 lib/utils/timed_counter.go              |   2 +-
 lib/utils/workpool/workpool.go          |  49 +++++-----
 lib/utils/workpool/workpool_test.go     |  91 +++++-------------
 tool/tctl/common/tctl.go                |  26 +-----
 20 files changed, 528 insertions(+), 534 deletions(-)
 delete mode 100644 integration/restart_test.go
 create mode 100644 lib/reversetunnel/resolver.go
 create mode 100644 lib/reversetunnel/resolver_test.go

diff --git a/integration/restart_test.go b/integration/restart_test.go
deleted file mode 100644
index bb28c4ad5b660..0000000000000
--- a/integration/restart_test.go
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
-Copyright 2021 Gravitational, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/ - -package integration - -import ( - "context" - "testing" - "time" - - "github.com/gravitational/teleport/lib" - "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/service" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" -) - -// TestLostConnectionToAuthCausesReload tests that a lost connection to the auth server -// will eventually restart a node -func TestLostConnectionToAuthCausesReload(t *testing.T) { - // Because testing that the node does a full restart is a bit tricky when - // running a cluster from inside a test runner (i.e. we don't want to - // SIGTERM the test runner), we will watch for the node emitting a - // `TeleportReload` even instead. In a proper Teleport instance, this - // event would be picked up at the Supervisor level and would eventually - // cause the instance to gracefully restart. - - require := require.New(t) - log := log.StandardLogger() - - log.Info(">>> Entering Test") - - // InsecureDevMode needed for SSH connections - // TODO(tcsc): surface this as per-server config (see also issue #8913) - lib.SetInsecureDevMode(true) - defer lib.SetInsecureDevMode(false) - - // GIVEN a cluster with a running auth+proxy instance.... - log.Info(">>> Creating cluster") - keygen := testauthority.New() - privateKey, publicKey, err := keygen.GenerateKeyPair("") - require.NoError(err) - auth := NewInstance(InstanceConfig{ - ClusterName: "test-tunnel-collapse", - HostID: "auth", - Priv: privateKey, - Pub: publicKey, - Ports: standardPortSetup(), - log: log, - }) - - log.Info(">>> Creating auth-proxy...") - authConf := service.MakeDefaultConfig() - authConf.Hostname = Host - authConf.Auth.Enabled = true - authConf.Proxy.Enabled = true - authConf.SSH.Enabled = false - authConf.Proxy.DisableWebInterface = true - authConf.Proxy.DisableDatabaseProxy = true - require.NoError(auth.CreateEx(t, nil, authConf)) - t.Cleanup(func() { require.NoError(auth.StopAll()) }) - - log.Info(">>> Start auth-proxy...") - require.NoError(auth.Start()) - - // ... and an SSH node connected via a reverse tunnel configured to - // reload after only a few failed connection attempts per minute - log.Info(">>> Creating and starting node...") - nodeCfg := service.MakeDefaultConfig() - nodeCfg.Hostname = Host - nodeCfg.SSH.Enabled = true - nodeCfg.RotationConnectionInterval = 1 * time.Second - nodeCfg.RestartThreshold = service.Rate{Amount: 3, Time: 1 * time.Minute} - node, err := auth.StartReverseTunnelNode(nodeCfg) - require.NoError(err) - - // WHEN I stop the auth node (and, by implication, disrupt the ssh node's - // connection to it) - log.Info(">>> Stopping auth node") - auth.StopAuth(false) - - // EXPECT THAT the ssh node will eventually issue a reload request - log.Info(">>> Waiting for node restart request.") - waitCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) - defer cancel() - - eventCh := make(chan service.Event) - node.WaitForEvent(waitCtx, service.TeleportReloadEvent, eventCh) - select { - case e := <-eventCh: - log.Infof(">>> Received Reload event: %v. 
Test passed.", e) - - case <-waitCtx.Done(): - require.FailNow("Timed out", "Timed out waiting for reload event") - } - - log.Info(">>> TEST COMPLETE") -} diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 6faf5840e592f..96eea026c9580 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -384,12 +384,19 @@ func (a *Agent) run() { a.log.Warningf("Failed to create remote tunnel: %v, conn: %v.", err, conn) return } - defer conn.Close() + + local := conn.LocalAddr().String() + remote := conn.RemoteAddr().String() + defer func() { + if err := conn.Close(); err != nil { + a.log.Warnf("Failed to close remote tunnel: %v, local addr: %s remote addr: %s", err, local, remote) + } + }() // Successfully connected to remote cluster. a.log.WithFields(log.Fields{ - "addr": conn.LocalAddr().String(), - "remote-addr": conn.RemoteAddr().String(), + "addr": local, + "remote-addr": remote, }).Info("Connected.") // wrap up remaining business logic in closure for easy @@ -414,14 +421,14 @@ func (a *Agent) run() { // or permanent loss of a proxy. err = a.processRequests(conn) if err != nil { - a.log.Warnf("Unable to continue processesing requests: %v.", err) + a.log.Warnf("Unable to continue processioning requests: %v.", err) return } } // if Tracker was provided, then the agent shouldn't continue unless // no other agents hold a claim. if a.Tracker != nil { - if !a.Tracker.WithProxy(doWork, a.Lease, a.getPrincipalsList()...) { + if !a.Tracker.WithProxy(doWork, a.getPrincipalsList()...) { a.log.Debugf("Proxy already held by other agent: %v, releasing.", a.getPrincipalsList()) } } else { @@ -518,7 +525,7 @@ func (a *Agent) processRequests(conn *ssh.Client) error { } } -// handleDisovery receives discovery requests from the reverse tunnel +// handleDiscovery receives discovery requests from the reverse tunnel // server, that informs agent about proxies registered in the remote // cluster and the reverse tunnels already established // @@ -526,7 +533,11 @@ func (a *Agent) processRequests(conn *ssh.Client) error { // reqC : request payload func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { a.log.Debugf("handleDiscovery requests channel.") - defer ch.Close() + defer func() { + if err := ch.Close(); err != nil { + a.log.Warnf("Failed to closed connection: %v", err) + } + }() for { var req *ssh.Request diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index cc6ca546cfacb..476271553aaf1 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -59,7 +59,7 @@ type AgentPool struct { spawnLimiter utils.Retry mu sync.Mutex - agents map[utils.NetAddr][]*Agent + agents []*Agent } // AgentPoolConfig holds configuration parameters for the agent pool @@ -89,8 +89,8 @@ type AgentPoolConfig struct { Component string // ReverseTunnelServer holds all reverse tunnel connections. ReverseTunnelServer Server - // ProxyAddr points to the address of the ssh proxy - ProxyAddr string + // Resolver retrieves the reverse tunnel address + Resolver Resolver // Cluster is a cluster name of the proxy. Cluster string // FIPS indicates if Teleport was started in FIPS mode. 
@@ -135,11 +135,6 @@ func NewAgentPool(ctx context.Context, cfg AgentPoolConfig) (*AgentPool, error) return nil, trace.Wrap(err) } - proxyAddr, err := utils.ParseAddr(cfg.ProxyAddr) - if err != nil { - return nil, trace.Wrap(err) - } - ctx, cancel := context.WithCancel(ctx) tr, err := track.New(ctx, track.Config{ClusterName: cfg.Cluster}) if err != nil { @@ -148,7 +143,7 @@ func NewAgentPool(ctx context.Context, cfg AgentPoolConfig) (*AgentPool, error) } pool := &AgentPool{ - agents: make(map[utils.NetAddr][]*Agent), + agents: nil, proxyTracker: tr, cfg: cfg, ctx: ctx, @@ -161,7 +156,7 @@ func NewAgentPool(ctx context.Context, cfg AgentPoolConfig) (*AgentPool, error) }, }), } - pool.proxyTracker.Start(*proxyAddr) + pool.proxyTracker.Start() return pool, nil } @@ -204,7 +199,6 @@ func (m *AgentPool) processSeekEvents() { // The proxy tracker has given us permission to act on a given // tunnel address case lease := <-m.proxyTracker.Acquire(): - m.log.Debugf("Seeking: %+v.", lease.Key()) m.withLock(func() { // Note that ownership of the lease is transferred to agent // pool for the lifetime of the connection @@ -232,11 +226,11 @@ func (m *AgentPool) withLock(f func()) { type matchAgentFn func(a *Agent) bool func (m *AgentPool) closeAgents() { - for key, agents := range m.agents { - m.agents[key] = filterAndClose(agents, func(*Agent) bool { return true }) - if len(m.agents[key]) == 0 { - delete(m.agents, key) - } + agents := filterAndClose(m.agents, func(*Agent) bool { return true }) + if len(agents) <= 0 { + m.agents = nil + } else { + m.agents = agents } } @@ -246,7 +240,9 @@ func filterAndClose(agents []*Agent, matchAgent matchAgentFn) []*Agent { agent := agents[i] if matchAgent(agent) { agent.log.Debugf("Pool is closing agent.") - agent.Close() + if err := agent.Close(); err != nil { + agent.log.WithError(err).Warnf("Failed to close agent") + } } else { filtered = append(filtered, agent) } @@ -271,21 +267,24 @@ func (m *AgentPool) pollAndSyncAgents() { // getReverseTunnelDetails gets the cached ReverseTunnelDetails obtained during the oldest cached agent.connect call. // This function should be called under a lock. -func (m *AgentPool) getReverseTunnelDetails(addr utils.NetAddr) *reverseTunnelDetails { - agents, ok := m.agents[addr] - if !ok || len(agents) == 0 { +func (m *AgentPool) getReverseTunnelDetails() *reverseTunnelDetails { + if len(m.agents) <= 0 { return nil } - return agents[0].reverseTunnelDetails + return m.agents[0].reverseTunnelDetails } // addAgent adds a new agent to the pool. Note that ownership of the lease // transfers into the AgentPool, and will be released when the AgentPool // is done with it. func (m *AgentPool) addAgent(lease track.Lease) error { - addr := lease.Key().(utils.NetAddr) + addr, err := m.cfg.Resolver() + if err != nil { + return trace.Wrap(err) + } + agent, err := NewAgent(AgentConfig{ - Addr: addr, + Addr: *addr, ClusterName: m.cfg.Cluster, Username: m.cfg.HostUUID, Signer: m.cfg.HostSigner, @@ -300,7 +299,7 @@ func (m *AgentPool) addAgent(lease track.Lease) error { Tracker: m.proxyTracker, Lease: lease, FIPS: m.cfg.FIPS, - reverseTunnelDetails: m.getReverseTunnelDetails(addr), + reverseTunnelDetails: m.getReverseTunnelDetails(), }) if err != nil { // ensure that lease has been released; OK to call multiple times. @@ -311,21 +310,19 @@ func (m *AgentPool) addAgent(lease track.Lease) error { // start the agent in a goroutine. 
no need to handle Start() errors: Start() will be // retrying itself until the agent is closed go agent.Start() - m.agents[addr] = append(m.agents[addr], agent) + m.agents = append(m.agents, agent) return nil } -// Counts returns a count of the number of proxies a outbound tunnel is +// Count returns a count of the number of proxies an outbound tunnel is // connected to. Used in tests to determine if a proxy has been found and/or // removed. func (m *AgentPool) Count() int { var out int m.withLock(func() { - for _, agents := range m.agents { - for _, agent := range agents { - if agent.getState() == agentStateConnected { - out++ - } + for _, agent := range m.agents { + if agent.getState() == agentStateConnected { + out++ } } }) @@ -336,18 +333,15 @@ func (m *AgentPool) Count() int { // removeDisconnected removes disconnected agents from the list of agents. // This function should be called under a lock. func (m *AgentPool) removeDisconnected() { - for agentKey, agentSlice := range m.agents { - // Filter and close all disconnected agents. - validAgents := filterAndClose(agentSlice, func(agent *Agent) bool { - return agent.getState() == agentStateDisconnected - }) - - // Update (or delete) agent key with filter applied. - if len(validAgents) > 0 { - m.agents[agentKey] = validAgents - } else { - delete(m.agents, agentKey) - } + // Filter and close all disconnected agents. + agents := filterAndClose(m.agents, func(agent *Agent) bool { + return agent.getState() == agentStateDisconnected + }) + + if len(agents) <= 0 { + m.agents = nil + } else { + m.agents = agents } } diff --git a/lib/reversetunnel/rc_manager.go b/lib/reversetunnel/rc_manager.go index ce3ceae8d1cec..d5572bd687825 100644 --- a/lib/reversetunnel/rc_manager.go +++ b/lib/reversetunnel/rc_manager.go @@ -45,7 +45,7 @@ type RemoteClusterTunnelManager struct { pools map[remoteClusterKey]*AgentPool stopRun func() - newAgentPool func(ctx context.Context, cluster, addr string) (*AgentPool, error) + newAgentPool func(ctx context.Context, cfg RemoteClusterTunnelManagerConfig, cluster, addr string) (*AgentPool, error) } type remoteClusterKey struct { @@ -106,17 +106,17 @@ func (c *RemoteClusterTunnelManagerConfig) CheckAndSetDefaults() error { return nil } -// NewRemoteClusterTunnelManager creates a new unstarted tunnel manager with +// NewRemoteClusterTunnelManager creates a new stopped tunnel manager with // the provided config. Call Run() to start the manager. func NewRemoteClusterTunnelManager(cfg RemoteClusterTunnelManagerConfig) (*RemoteClusterTunnelManager, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } w := &RemoteClusterTunnelManager{ - cfg: cfg, - pools: make(map[remoteClusterKey]*AgentPool), + cfg: cfg, + pools: make(map[remoteClusterKey]*AgentPool), + newAgentPool: realNewAgentPool, } - w.newAgentPool = w.realNewAgentPool return w, nil } @@ -198,7 +198,7 @@ func (w *RemoteClusterTunnelManager) Sync(ctx context.Context) error { continue } - pool, err := w.newAgentPool(ctx, k.cluster, k.addr) + pool, err := w.newAgentPool(ctx, w.cfg, k.cluster, k.addr) if err != nil { errs = append(errs, trace.Wrap(err)) continue @@ -208,29 +208,33 @@ func (w *RemoteClusterTunnelManager) Sync(ctx context.Context) error { return trace.NewAggregate(errs...) 
} -func (w *RemoteClusterTunnelManager) realNewAgentPool(ctx context.Context, cluster, addr string) (*AgentPool, error) { +func realNewAgentPool(ctx context.Context, cfg RemoteClusterTunnelManagerConfig, cluster, addr string) (*AgentPool, error) { pool, err := NewAgentPool(ctx, AgentPoolConfig{ // Configs for our cluster. - Client: w.cfg.AuthClient, - AccessPoint: w.cfg.AccessPoint, - HostSigner: w.cfg.HostSigner, - HostUUID: w.cfg.HostUUID, - LocalCluster: w.cfg.LocalCluster, - Clock: w.cfg.Clock, - KubeDialAddr: w.cfg.KubeDialAddr, - ReverseTunnelServer: w.cfg.ReverseTunnelServer, - FIPS: w.cfg.FIPS, + Client: cfg.AuthClient, + AccessPoint: cfg.AccessPoint, + HostSigner: cfg.HostSigner, + HostUUID: cfg.HostUUID, + LocalCluster: cfg.LocalCluster, + Clock: cfg.Clock, + KubeDialAddr: cfg.KubeDialAddr, + ReverseTunnelServer: cfg.ReverseTunnelServer, + FIPS: cfg.FIPS, // RemoteClusterManager only runs on proxies. Component: teleport.ComponentProxy, // Configs for remote cluster. - Cluster: cluster, - ProxyAddr: addr, + Cluster: cluster, + Resolver: StaticResolver(addr), }) if err != nil { return nil, trace.Wrap(err, "failed creating reverse tunnel pool for remote cluster %q at address %q: %v", cluster, addr, err) } - go pool.Start() + go func() { + if err := pool.Start(); err != nil { + cfg.Log.WithError(err).Error("Failed to start agent pool") + } + }() return pool, nil } diff --git a/lib/reversetunnel/rc_manager_test.go b/lib/reversetunnel/rc_manager_test.go index 231728d0eda4d..950d9d6df194a 100644 --- a/lib/reversetunnel/rc_manager_test.go +++ b/lib/reversetunnel/rc_manager_test.go @@ -19,9 +19,11 @@ import ( "errors" "testing" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -30,17 +32,27 @@ import ( func TestRemoteClusterTunnelManagerSync(t *testing.T) { t.Parallel() + resolverFn := func(addr string) Resolver { + return func() (*utils.NetAddr, error) { + return &utils.NetAddr{ + Addr: addr, + AddrNetwork: "tcp", + Path: "", + }, nil + } + } + var newAgentPoolErr error w := &RemoteClusterTunnelManager{ pools: make(map[remoteClusterKey]*AgentPool), - newAgentPool: func(ctx context.Context, cluster, addr string) (*AgentPool, error) { + newAgentPool: func(ctx context.Context, cfg RemoteClusterTunnelManagerConfig, cluster, addr string) (*AgentPool, error) { return &AgentPool{ - cfg: AgentPoolConfig{Cluster: cluster, ProxyAddr: addr}, + cfg: AgentPoolConfig{Cluster: cluster, Resolver: resolverFn(addr)}, cancel: func() {}, }, newAgentPoolErr }, } - defer w.Close() + t.Cleanup(func() { require.NoError(t, w.Close()) }) tests := []struct { desc string @@ -61,7 +73,7 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-a", []string{"addr-a"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, }, assertErr: require.NoError, }, @@ -71,9 +83,9 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-a", []string{"addr-a", "addr-b", "addr-c"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: 
AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, }, assertErr: require.NoError, }, @@ -83,7 +95,7 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-b", []string{"addr-b"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.NoError, }, @@ -94,10 +106,10 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-b", []string{"addr-b"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, - {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.NoError, }, @@ -105,10 +117,10 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { desc: "GetReverseTunnels error, keep existing pools", reverseTunnelsErr: errors.New("nah"), wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, - {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.Error, }, @@ -121,16 +133,16 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { }, newAgentPoolErr: errors.New("nah"), wantPools: 
map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, - {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.Error, }, } - ctx := context.TODO() + ctx := context.Background() for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { w.cfg.AuthClient = mockAuthClient{ @@ -148,8 +160,18 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { // Tweaks to get comparison working with our complex types. cmp.AllowUnexported(remoteClusterKey{}), cmp.Comparer(func(a, b *AgentPool) bool { + aAddr, aErr := a.cfg.Resolver() + bAddr, bErr := b.cfg.Resolver() + + if aAddr != bAddr && aErr != bErr { + return false + } + // Only check the supplied configs of AgentPools. - return cmp.Equal(a.cfg, b.cfg) + return cmp.Equal( + a.cfg, + b.cfg, + cmpopts.IgnoreFields(AgentPoolConfig{}, "Resolver")) }), )) }) diff --git a/lib/reversetunnel/resolver.go b/lib/reversetunnel/resolver.go new file mode 100644 index 0000000000000..91b7608883123 --- /dev/null +++ b/lib/reversetunnel/resolver.go @@ -0,0 +1,67 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reversetunnel + +import ( + "context" + + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" +) + +// Resolver looks up reverse tunnel addresses +type Resolver func() (*utils.NetAddr, error) + +// ResolveViaWebClient returns a Resolver which uses the web proxy to +// discover where the SSH reverse tunnel server is running. +func ResolveViaWebClient(ctx context.Context, addrs []utils.NetAddr, insecureTLS bool) Resolver { + return func() (*utils.NetAddr, error) { + var errs []error + for _, addr := range addrs { + // In insecure mode, any certificate is accepted. In secure mode the hosts + // CAs are used to validate the certificate on the proxy. 
+ tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil) + if err != nil { + errs = append(errs, err) + continue + } + + addr, err := utils.ParseAddr(tunnelAddr) + if err != nil { + errs = append(errs, err) + continue + } + + addr.Addr = utils.ReplaceUnspecifiedHost(addr, defaults.HTTPListenPort) + return addr, nil + } + return nil, trace.NewAggregate(errs...) + } +} + +// StaticResolver returns a Resolver which will always resolve to +// the provided address +func StaticResolver(address string) Resolver { + addr, err := utils.ParseAddr(address) + if err == nil { + addr.Addr = utils.ReplaceUnspecifiedHost(addr, defaults.HTTPListenPort) + } + + return func() (*utils.NetAddr, error) { + return addr, err + } +} diff --git a/lib/reversetunnel/resolver_test.go b/lib/reversetunnel/resolver_test.go new file mode 100644 index 0000000000000..43b809c32a6c0 --- /dev/null +++ b/lib/reversetunnel/resolver_test.go @@ -0,0 +1,116 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reversetunnel + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/lib/utils" + "github.com/stretchr/testify/require" +) + +func TestStaticResolver(t *testing.T) { + cases := []struct { + name string + address string + errorAssertionFn require.ErrorAssertionFunc + expected *utils.NetAddr + }{ + { + name: "invalid address yields error", + address: "", + errorAssertionFn: require.Error, + }, + { + name: "valid address yields NetAddr", + address: "localhost:80", + errorAssertionFn: require.NoError, + expected: &utils.NetAddr{ + Addr: "localhost:80", + AddrNetwork: "tcp", + Path: "", + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr, err := StaticResolver(tt.address)() + tt.errorAssertionFn(t, err) + if err != nil { + return + } + + require.Empty(t, cmp.Diff(tt.expected, addr)) + }) + } +} + +func TestResolveViaWebClient(t *testing.T) { + + fakeAddr := utils.NetAddr{} + + cases := []struct { + name string + addrs []utils.NetAddr + address string + errorAssertionFn require.ErrorAssertionFunc + expected *utils.NetAddr + }{ + { + name: "no addrs yields no results", + errorAssertionFn: require.NoError, + }, + { + name: "unreachable proxy yields errors", + addrs: []utils.NetAddr{fakeAddr}, + address: "", + errorAssertionFn: require.Error, + }, + { + name: "invalid address yields errors", + addrs: []utils.NetAddr{fakeAddr}, + address: "fake://test", + errorAssertionFn: require.Error, + }, + { + name: "valid address yields NetAddr", + addrs: []utils.NetAddr{fakeAddr}, + address: "localhost:80", + errorAssertionFn: require.NoError, + expected: &utils.NetAddr{ + Addr: "localhost:80", + AddrNetwork: "tcp", + Path: "", + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Setenv(defaults.TunnelPublicAddrEnvar, tt.address) + addr, err := ResolveViaWebClient(context.Background(), 
tt.addrs, true)() + tt.errorAssertionFn(t, err) + if err != nil { + return + } + + require.Empty(t, cmp.Diff(tt.expected, addr)) + }) + } +} diff --git a/lib/reversetunnel/track/tracker.go b/lib/reversetunnel/track/tracker.go index ea4fb94c9d30b..863bfe94f7431 100644 --- a/lib/reversetunnel/track/tracker.go +++ b/lib/reversetunnel/track/tracker.go @@ -22,7 +22,6 @@ import ( "sync" "time" - "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/workpool" "github.com/gravitational/trace" ) @@ -31,7 +30,7 @@ type Lease = workpool.Lease // Config configures basic Tracker parameters. type Config struct { - // ProxyExpiry is the duration an entry will be held sice the last + // ProxyExpiry is the duration an entry will be held since the last // successful connection to, or message about, a given proxy. ProxyExpiry time.Duration // TickRate is the rate at which expired entries are cleared from @@ -41,7 +40,7 @@ type Config struct { ClusterName string } -// SetDefaults set default values for Config. +// CheckAndSetDefaults set default values for Config. func (c *Config) CheckAndSetDefaults() error { if c.ProxyExpiry < 1 { c.ProxyExpiry = 3 * time.Minute @@ -66,7 +65,7 @@ type Tracker struct { Config mu sync.Mutex wp *workpool.Pool - sets map[utils.NetAddr]*proxySet + sets *proxySet cancel context.CancelFunc } @@ -79,7 +78,6 @@ func New(ctx context.Context, c Config) (*Tracker, error) { t := &Tracker{ Config: c, wp: workpool.NewPool(ctx), - sets: make(map[utils.NetAddr]*proxySet), cancel: cancel, } go t.run(ctx) @@ -110,38 +108,38 @@ func (t *Tracker) Acquire() <-chan Lease { func (t *Tracker) TrackExpected(lease Lease, proxies ...string) { t.mu.Lock() defer t.mu.Unlock() - addr := lease.Key().(utils.NetAddr) - set, ok := t.sets[addr] - if !ok { + + if t.sets == nil { return } now := time.Now() for _, name := range proxies { - set.markSeen(now, name) + t.sets.markSeen(now, name) } - count := len(set.proxies) + count := len(t.sets.proxies) if count < 1 { count = 1 } - t.wp.Set(addr, uint64(count)) + t.wp.Set(uint64(count)) } // Start starts tracking for specified proxy address. -func (t *Tracker) Start(addr utils.NetAddr) { +func (t *Tracker) Start() { t.mu.Lock() defer t.mu.Unlock() - t.getOrCreate(addr) + t.getOrCreate() } // Stop stops tracking for specified proxy address. 
-func (t *Tracker) Stop(addr utils.NetAddr) { +func (t *Tracker) Stop() { t.mu.Lock() defer t.mu.Unlock() - if _, ok := t.sets[addr]; !ok { + if t.sets == nil { return } - delete(t.sets, addr) - t.wp.Set(addr, 0) + + t.sets = nil + t.wp.Set(0) } // StopAll permanently deactivates this tracker and cleans @@ -153,59 +151,59 @@ func (t *Tracker) StopAll() { func (t *Tracker) tick() { t.mu.Lock() defer t.mu.Unlock() + if t.sets == nil { + return + } + cutoff := time.Now().Add(-1 * t.ProxyExpiry) - for addr, set := range t.sets { - if set.expire(cutoff) > 0 { - count := len(set.proxies) - if count < 1 { - count = 1 - } - t.wp.Set(addr, uint64(count)) + if t.sets.expire(cutoff) > 0 { + count := len(t.sets.proxies) + if count < 1 { + count = 1 } + t.wp.Set(uint64(count)) } + } -func (t *Tracker) getOrCreate(addr utils.NetAddr) *proxySet { - if s, ok := t.sets[addr]; ok { - return s +func (t *Tracker) getOrCreate() *proxySet { + if t.sets == nil { + t.sets = newProxySet(t.ClusterName) + t.wp.Set(1) } - set := newProxySet(addr, t.ClusterName) - t.sets[addr] = set - t.wp.Set(addr, 1) - return set + + return t.sets } // WithProxy runs the supplied closure if and only if // no other work is currently being done with the proxy // identified by principals. -func (t *Tracker) WithProxy(work func(), lease Lease, principals ...string) (didWork bool) { - addr := lease.Key().(utils.NetAddr) - if ok := t.claim(addr, principals...); !ok { +func (t *Tracker) WithProxy(work func(), principals ...string) (didWork bool) { + if ok := t.claim(principals...); !ok { return false } - defer t.unclaim(addr, principals...) + defer t.release(principals...) work() return true } -func (t *Tracker) claim(addr utils.NetAddr, principals ...string) (ok bool) { +func (t *Tracker) claim(principals ...string) (ok bool) { t.mu.Lock() defer t.mu.Unlock() - set, ok := t.sets[addr] - if !ok { + if t.sets == nil { return false } - return set.claim(principals...) + return t.sets.claim(principals...) } -func (t *Tracker) unclaim(addr utils.NetAddr, principals ...string) { +func (t *Tracker) release(principals ...string) { t.mu.Lock() defer t.mu.Unlock() - set, ok := t.sets[addr] - if !ok { + if t.sets == nil { return } - set.unclaim(principals...) + + t.sets.release(principals...) } type entry struct { @@ -213,16 +211,14 @@ type entry struct { claimed bool } -func newProxySet(addr utils.NetAddr, clusterName string) *proxySet { +func newProxySet(clusterName string) *proxySet { return &proxySet{ - addr: addr, clusterName: clusterName, proxies: make(map[string]entry), } } type proxySet struct { - addr utils.NetAddr clusterName string proxies map[string]entry } @@ -244,7 +240,7 @@ func (p *proxySet) claim(principals ...string) (ok bool) { return true } -func (p *proxySet) unclaim(principals ...string) { +func (p *proxySet) release(principals ...string) { proxy := p.resolveName(principals) p.proxies[proxy] = entry{ lastSeen: time.Now(), diff --git a/lib/reversetunnel/track/tracker_test.go b/lib/reversetunnel/track/tracker_test.go index fd4dae8f039a2..7736dd27218c9 100644 --- a/lib/reversetunnel/track/tracker_test.go +++ b/lib/reversetunnel/track/tracker_test.go @@ -20,12 +20,12 @@ import ( "context" "fmt" pr "math/rand" + "os" "sync" "testing" "time" - "github.com/gravitational/teleport/lib/utils" - "gopkg.in/check.v1" + "github.com/stretchr/testify/require" ) type simpleTestProxies struct { @@ -102,7 +102,7 @@ func (s *simpleTestProxies) ProxyLoop(tracker *Tracker, lease Lease, proxy testP break Loop } } - }, lease, proxy.principals...) 
+ }, proxy.principals...) return } @@ -133,29 +133,26 @@ func jitter(t time.Duration) time.Duration { return t + j } -func Test(t *testing.T) { +func TestMain(m *testing.M) { pr.Seed(time.Now().UnixNano()) - check.TestingT(t) + os.Exit(m.Run()) } -type StateSuite struct{} +func TestBasic(t *testing.T) { + const ( + timeout = time.Second * 16 + proxyCount = 16 + ) -var _ = check.Suite(&StateSuite{}) + ctx := context.Background() -func (s *StateSuite) TestBasic(c *check.C) { - s.runBasicProxyTest(c, time.Second*16) -} - -func (s *StateSuite) runBasicProxyTest(c *check.C, timeout time.Duration) { - const proxyCount = 16 timeoutC := time.After(timeout) ticker := time.NewTicker(time.Millisecond * 100) - defer ticker.Stop() - tracker, err := New(context.TODO(), Config{ClusterName: "test-cluster"}) - c.Assert(err, check.IsNil) - defer tracker.StopAll() - addr := utils.NetAddr{Addr: "test-cluster"} - tracker.Start(addr) + t.Cleanup(ticker.Stop) + tracker, err := New(ctx, Config{ClusterName: "test-cluster"}) + require.NoError(t, err) + t.Cleanup(tracker.StopAll) + tracker.Start() min, max := time.Duration(0), timeout var proxies simpleTestProxies proxies.AddRandProxies(proxyCount, min, max) @@ -165,18 +162,18 @@ Discover: case lease := <-tracker.Acquire(): go proxies.Discover(tracker, lease) case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts: %+v", counts) if counts.Active == proxyCount { break Discover } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } } -func (s *StateSuite) TestFullRotation(c *check.C) { +func TestFullRotation(t *testing.T) { const ( proxyCount = 8 minConnA = time.Second * 2 @@ -185,37 +182,37 @@ func (s *StateSuite) TestFullRotation(c *check.C) { maxConnB = time.Second * 25 timeout = time.Second * 30 ) + + ctx := context.Background() ticker := time.NewTicker(time.Millisecond * 100) - defer ticker.Stop() + t.Cleanup(ticker.Stop) var proxies simpleTestProxies proxies.AddRandProxies(proxyCount, minConnA, maxConnA) - tracker, err := New(context.TODO(), Config{ClusterName: "test-cluster"}) - c.Assert(err, check.IsNil) - defer tracker.StopAll() - addr := utils.NetAddr{Addr: "test-cluster"} - tracker.Start(addr) + tracker, err := New(ctx, Config{ClusterName: "test-cluster"}) + require.NoError(t, err) + t.Cleanup(tracker.StopAll) + tracker.Start() timeoutC := time.After(timeout) Loop0: for { select { case lease := <-tracker.Acquire(): - c.Assert(lease.Key(), check.DeepEquals, addr) // get our "discovered" proxy in the foreground // to prevent race with the call to RemoveRandProxies // that comes after this loop. 
proxy, ok := proxies.GetRandProxy() if !ok { - c.Fatal("failed to get test proxy") + t.Fatal("failed to get test proxy") } go proxies.ProxyLoop(tracker, lease, proxy) case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts0: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts0: %+v", counts) if counts.Active == proxyCount { break Loop0 } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } proxies.RemoveRandProxies(proxyCount) @@ -223,13 +220,13 @@ Loop1: for { select { case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts1: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts1: %+v", counts) if counts.Active < 1 { break Loop1 } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } proxies.AddRandProxies(proxyCount, minConnB, maxConnB) @@ -239,13 +236,13 @@ Loop2: case lease := <-tracker.Acquire(): go proxies.Discover(tracker, lease) case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts2: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts2: %+v", counts) if counts.Active >= proxyCount { break Loop2 } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } } @@ -253,41 +250,40 @@ Loop2: // TestUUIDHandling verifies that host UUIDs are correctly extracted // from the expected teleport principal format, and that gossip messages // consisting only of uuid don't create duplicate entries. -func (s *StateSuite) TestUUIDHandling(c *check.C) { - ctx, cancel := context.WithTimeout(context.TODO(), time.Second*6) - defer cancel() +func TestUUIDHandling(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*6) + t.Cleanup(cancel) ticker := time.NewTicker(time.Millisecond * 10) - defer ticker.Stop() - tracker, err := New(context.TODO(), Config{ClusterName: "test-cluster"}) - c.Assert(err, check.IsNil) - defer tracker.StopAll() - addr := utils.NetAddr{Addr: "test-cluster"} - tracker.Start(addr) + t.Cleanup(ticker.Stop) + tracker, err := New(context.Background(), Config{ClusterName: "test-cluster"}) + require.NoError(t, err) + t.Cleanup(tracker.StopAll) + tracker.Start() lease := <-tracker.Acquire() // claim a proxy using principal of the form . 
go tracker.WithProxy(func() { - c.Logf("Successfully claimed proxy") + t.Logf("Successfully claimed proxy") <-ctx.Done() - }, lease, "my-proxy.test-cluster") + }, "my-proxy.test-cluster") // Wait for proxy to be claimed Wait: for { select { case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts: %+v", counts) if counts.Active == counts.Target { break Wait } case <-ctx.Done(): - c.Errorf("pool never reached expected state") + t.Errorf("pool never reached expected state") } } // Send a gossip message containing host UUID only tracker.TrackExpected(lease, "my-proxy") - c.Logf("Sent uuid-only gossip message; watching status...") + t.Logf("Sent uuid-only gossip message; watching status...") // Let pool go through a few ticks, monitoring status to ensure that // we don't incorrectly enter seek mode (entering seek mode here would @@ -296,13 +292,13 @@ Wait: for i := 0; i < 3; i++ { select { case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts: %+v", counts) if counts.Active != counts.Target { - c.Errorf("incorrectly entered seek mode") + t.Errorf("incorrectly entered seek mode") } case <-ctx.Done(): - c.Errorf("timeout") + t.Errorf("timeout") } } } diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index bf6f42bdbfbd0..a2d5cbc0a1117 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -34,7 +34,6 @@ import ( "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/proxy" @@ -56,8 +55,8 @@ func NewTunnelAuthDialer(config TunnelAuthDialerConfig) (*TunnelAuthDialer, erro // TunnelAuthDialerConfig specifies TunnelAuthDialer configuration. type TunnelAuthDialerConfig struct { - // ProxyAddr is the address of the proxy - ProxyAddr string + // Resolver retrieves the address of the proxy + Resolver Resolver // ClientConfig is SSH tunnel client config ClientConfig *ssh.ClientConfig // Log is used for logging. @@ -65,14 +64,9 @@ type TunnelAuthDialerConfig struct { } func (c *TunnelAuthDialerConfig) CheckAndSetDefaults() error { - if c.ProxyAddr == "" { - return trace.BadParameter("missing proxy address") + if c.Resolver == nil { + return trace.BadParameter("missing tunnel address resolver") } - parsedAddr, err := utils.ParseAddr(c.ProxyAddr) - if err != nil { - return trace.Wrap(err) - } - c.ProxyAddr = utils.ReplaceUnspecifiedHost(parsedAddr, defaults.HTTPListenPort) return nil } @@ -83,22 +77,28 @@ type TunnelAuthDialer struct { } // DialContext dials auth server via SSH tunnel -func (t *TunnelAuthDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { // Connect to the reverse tunnel server. var opts []proxy.DialerOptionFunc + addr, err := t.Resolver() + if err != nil { + t.Log.Errorf("Failed to resolve tunnel address %v", err) + return nil, trace.Wrap(err) + } + // Check if t.ProxyAddr is ProxyWebPort and remote Proxy supports TLS ALPNSNIListener. 
- resp, err := webclient.Find(ctx, t.ProxyAddr, lib.IsInsecureDevMode(), nil) + resp, err := webclient.Find(ctx, addr.Addr, lib.IsInsecureDevMode(), nil) if err != nil { // If TLS Routing is disabled the address is the proxy reverse tunnel // address thus the ping call will always fail. - t.Log.Debugf("Failed to ping web proxy %q addr: %v", t.ProxyAddr, err) + t.Log.Debugf("Failed to ping web proxy %q addr: %v", addr.Addr, err) } else if resp.Proxy.TLSRoutingEnabled { opts = append(opts, proxy.WithALPNDialer()) } - dialer := proxy.DialerFromEnvironment(t.ProxyAddr, opts...) - sconn, err := dialer.Dial("tcp", t.ProxyAddr, t.ClientConfig) + dialer := proxy.DialerFromEnvironment(addr.Addr, opts...) + sconn, err := dialer.Dial(addr.AddrNetwork, addr.Addr, t.ClientConfig) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/service/connect.go b/lib/service/connect.go index f23fe44c070f0..b691db1859e45 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -21,15 +21,13 @@ import ( "math" "path/filepath" - "golang.org/x/crypto/ssh" - "github.com/coreos/go-semver/semver" "github.com/gravitational/roundtrip" "github.com/gravitational/teleport" + "golang.org/x/crypto/ssh" apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" @@ -485,28 +483,13 @@ func (process *TeleportProcess) periodicSyncRotationState() error { }) defer periodic.Stop() - errors := utils.NewTimedCounter(process.Clock, process.Config.RestartThreshold.Time) - for { err := process.syncRotationStateCycle() if err == nil { return nil } - process.log.WithError(err).Warning("Sync rotation state cycle failed") - - // If we have had a *lot* of failures very recently, then it's likely that our - // route to the auth server is gone. If we're using a tunnel then it's possible - // that the proxy has been reconfigured and the tunnel address has moved. - count := errors.Increment() - process.log.Warnf("%d connection errors in last %v.", count, process.Config.RestartThreshold.Time) - if count > process.Config.RestartThreshold.Amount { - // signal quit - process.log.Error("Connection error threshold exceeded. Asking for a graceful restart.") - process.BroadcastEvent(Event{Name: TeleportReloadEvent}) - return nil - } - process.log.Warningf("Retrying in ~%v", process.Config.RotationConnectionInterval) + process.log.Warningf("Sync rotation state cycle failed. 
Retrying in ~%v", process.Config.RotationConnectionInterval) select { case <-periodic.Next(): @@ -881,20 +864,12 @@ func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity logger.Debug("Attempting to discover reverse tunnel address.") - proxyAddr, err := process.findReverseTunnel(authServers) - if err != nil { - directErrLogger.Debug("Failed to connect to Auth Server directly.") - logger.WithError(err).Debug("Failed to discover reverse tunnel address.") - return nil, trace.Errorf("Failed to connect to Auth Server directly or over tunnel, no methods remaining.") - } - - logger = process.log.WithField("proxy-addr", proxyAddr) logger.Debug("Attempting to connect to Auth Server through tunnel.") sshClientConfig, err := identity.SSHClientConfig(process.Config.FIPS) if err != nil { return nil, trace.Wrap(err) } - tunnelClient, err := process.newClientThroughTunnel(proxyAddr, tlsConfig, sshClientConfig) + tunnelClient, err := process.newClientThroughTunnel(authServers, tlsConfig, sshClientConfig) if err != nil { directErrLogger.Debug("Failed to connect to Auth Server directly.") logger.WithError(err).Debug("Failed to connect to Auth Server through tunnel.") @@ -905,25 +880,9 @@ func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity return tunnelClient, nil } -// findReverseTunnel uses the web proxy to discover where the SSH reverse tunnel -// server is running. -func (process *TeleportProcess) findReverseTunnel(addrs []utils.NetAddr) (string, error) { - var errs []error - for _, addr := range addrs { - // In insecure mode, any certificate is accepted. In secure mode the hosts - // CAs are used to validate the certificate on the proxy. - tunnelAddr, err := webclient.GetTunnelAddr(process.ExitContext(), addr.String(), lib.IsInsecureDevMode(), nil) - if err == nil { - return tunnelAddr, nil - } - errs = append(errs, err) - } - return "", trace.NewAggregate(errs...) 
-} - -func (process *TeleportProcess) newClientThroughTunnel(proxyAddr string, tlsConfig *tls.Config, sshConfig *ssh.ClientConfig) (*auth.Client, error) { +func (process *TeleportProcess) newClientThroughTunnel(authServers []utils.NetAddr, tlsConfig *tls.Config, sshConfig *ssh.ClientConfig) (*auth.Client, error) { dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{ - ProxyAddr: proxyAddr, + Resolver: reversetunnel.ResolveViaWebClient(process.ExitContext(), authServers, lib.IsInsecureDevMode()), ClientConfig: sshConfig, Log: process.log, }) diff --git a/lib/service/db.go b/lib/service/db.go index 7054fb7622dca..5a15fa5d54627 100644 --- a/lib/service/db.go +++ b/lib/service/db.go @@ -69,14 +69,17 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) { return trace.Wrap(err) } - var tunnelAddr string - if conn.TunnelProxy() != "" { - tunnelAddr = conn.TunnelProxy() - } else { - if tunnelAddr, ok = process.singleProcessMode(resp.GetProxyListenerMode()); !ok { - return trace.BadParameter("failed to find reverse tunnel address, " + - "if running in a single-process mode, make sure auth_service, " + - "proxy_service, and db_service are all enabled") + tunnelAddrResolver := conn.TunnelProxyResolver() + if tunnelAddrResolver == nil { + tunnelAddrResolver = func() (*utils.NetAddr, error) { + addr, ok := process.singleProcessMode(resp.GetProxyListenerMode()) + if !ok { + return nil, trace.BadParameter("failed to find reverse tunnel address, " + + "if running in a single-process mode, make sure auth_service, " + + "proxy_service, and db_service are all enabled") + } + + return addr, nil } } @@ -208,11 +211,12 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) { }() // Create and start the agent pool. - agentPool, err := reversetunnel.NewAgentPool(process.ExitContext(), + agentPool, err := reversetunnel.NewAgentPool( + process.ExitContext(), reversetunnel.AgentPoolConfig{ Component: teleport.ComponentDatabase, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: tunnelAddr, + Resolver: tunnelAddrResolver, Client: conn.Client, Server: dbService, AccessPoint: conn.Client, diff --git a/lib/service/desktop.go b/lib/service/desktop.go index 7cde2b1f3fd05..5b3be1ac3528a 100644 --- a/lib/service/desktop.go +++ b/lib/service/desktop.go @@ -83,6 +83,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(log *logrus. return trace.Wrap(err) } + useTunnel := conn.UseTunnel() // This service can run in 2 modes: // 1. Reachable (by the proxy) - registers with auth server directly and // creates a local listener to accept proxy conns. @@ -95,13 +96,13 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(log *logrus. switch { // Filter out cases where both listen_addr and tunnel are set or both are // not set. - case conn.UseTunnel() && !cfg.WindowsDesktop.ListenAddr.IsEmpty(): + case useTunnel && !cfg.WindowsDesktop.ListenAddr.IsEmpty(): return trace.BadParameter("either set windows_desktop_service.listen_addr if this process can be reached from a teleport proxy or point teleport.auth_servers to a proxy to dial out, but don't set both") - case !conn.UseTunnel() && cfg.WindowsDesktop.ListenAddr.IsEmpty(): + case !useTunnel && cfg.WindowsDesktop.ListenAddr.IsEmpty(): return trace.BadParameter("set windows_desktop_service.listen_addr if this process can be reached from a teleport proxy or point teleport.auth_servers to a proxy to dial out") // Start a local listener and let proxies dial in. 
- case !conn.UseTunnel() && !cfg.WindowsDesktop.ListenAddr.IsEmpty(): + case !useTunnel && !cfg.WindowsDesktop.ListenAddr.IsEmpty(): log.Info("Using local listener and registering directly with auth server") listener, err = process.importOrCreateListener(listenerWindowsDesktop, cfg.WindowsDesktop.ListenAddr.Addr) if err != nil { @@ -114,7 +115,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(log *logrus. }() // Dialed out to a proxy, start servicing the reverse tunnel as a listener. - case conn.UseTunnel() && cfg.WindowsDesktop.ListenAddr.IsEmpty(): + case useTunnel && cfg.WindowsDesktop.ListenAddr.IsEmpty(): // create an adapter, from reversetunnel.ServerHandler to net.Listener. shtl := reversetunnel.NewServerHandlerToListener(reversetunnel.LocalWindowsDesktop) listener = shtl @@ -123,7 +124,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(log *logrus. reversetunnel.AgentPoolConfig{ Component: teleport.ComponentWindowsDesktop, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: conn.TunnelProxy(), + Resolver: conn.TunnelProxyResolver(), Client: conn.Client, AccessPoint: accessPoint, HostSigner: conn.ServerIdentity.KeySigner, @@ -194,7 +195,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(log *logrus. } var publicAddr string switch { - case conn.UseTunnel(): + case useTunnel: publicAddr = listener.Addr().String() case len(cfg.WindowsDesktop.PublicAddrs) > 0: publicAddr = cfg.WindowsDesktop.PublicAddrs[0].String() @@ -234,7 +235,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(log *logrus. } }() process.RegisterCriticalFunc("windows_desktop.serve", func() error { - if conn.UseTunnel() { + if useTunnel { log.Info("Starting Windows desktop service via proxy reverse tunnel.") utils.Consolef(cfg.Console, log, teleport.ComponentWindowsDesktop, "Windows desktop service %s:%s is starting via proxy reverse tunnel.", diff --git a/lib/service/kubernetes.go b/lib/service/kubernetes.go index 265e26d48b6ad..70d4cca899d50 100644 --- a/lib/service/kubernetes.go +++ b/lib/service/kubernetes.go @@ -137,7 +137,7 @@ func (process *TeleportProcess) initKubernetesService(log *logrus.Entry, conn *C reversetunnel.AgentPoolConfig{ Component: teleport.ComponentKube, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: conn.TunnelProxy(), + Resolver: conn.TunnelProxyResolver(), Client: conn.Client, AccessPoint: accessPoint, HostSigner: conn.ServerIdentity.KeySigner, diff --git a/lib/service/service.go b/lib/service/service.go index c07d6aac43c3c..0fb969d7354d8 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -230,23 +230,25 @@ type Connector struct { Client *auth.Client } -// TunnelProxy if non-empty, indicates that the client is connected to the Auth Server +// TunnelProxyResolver if non-nil, indicates that the client is connected to the Auth Server // through the reverse SSH tunnel proxy -func (c *Connector) TunnelProxy() string { +func (c *Connector) TunnelProxyResolver() reversetunnel.Resolver { if c.Client == nil || c.Client.Dialer() == nil { - return "" + return nil } - tun, ok := c.Client.Dialer().(*reversetunnel.TunnelAuthDialer) - if !ok { - return "" + + switch dialer := c.Client.Dialer().(type) { + case *reversetunnel.TunnelAuthDialer: + return dialer.Resolver + default: + return nil } - return tun.ProxyAddr } // UseTunnel indicates if the client is connected directly to the Auth Server // (false) or through the proxy (true). 
func (c *Connector) UseTunnel() bool { - return c.TunnelProxy() != "" + return c.TunnelProxyResolver() != nil } // Close closes resources associated with connector @@ -1999,7 +2001,7 @@ func (process *TeleportProcess) initSSH() error { reversetunnel.AgentPoolConfig{ Component: teleport.ComponentNode, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: conn.TunnelProxy(), + Resolver: conn.TunnelProxyResolver(), Client: conn.Client, AccessPoint: conn.Client, HostSigner: conn.ServerIdentity.KeySigner, @@ -3523,12 +3525,14 @@ func (process *TeleportProcess) initApps() { // If it was not, it is running in single process mode which is used for // development and demos. In that case, wait until all dependencies (like // auth and reverse tunnel server) are ready before starting. - var tunnelAddr string - if conn.TunnelProxy() != "" { - tunnelAddr = conn.TunnelProxy() - } else { - if tunnelAddr, ok = process.singleProcessMode(resp.GetProxyListenerMode()); !ok { - return trace.BadParameter(`failed to find reverse tunnel address, if running in single process mode, make sure "auth_service", "proxy_service", and "app_service" are all enabled`) + tunnelAddrResolver := conn.TunnelProxyResolver() + if tunnelAddrResolver == nil { + tunnelAddrResolver = func() (*utils.NetAddr, error) { + addr, ok := process.singleProcessMode(resp.GetProxyListenerMode()) + if !ok { + return nil, trace.BadParameter(`failed to find reverse tunnel address, if running in single process mode, make sure "auth_service", "proxy_service", and "app_service" are all enabled`) + } + return addr, nil } // Block and wait for all dependencies to start before starting. @@ -3651,11 +3655,12 @@ func (process *TeleportProcess) initApps() { } // Create and start an agent pool. - agentPool, err = reversetunnel.NewAgentPool(process.ExitContext(), + agentPool, err = reversetunnel.NewAgentPool( + process.ExitContext(), reversetunnel.AgentPoolConfig{ Component: teleport.ComponentApp, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: tunnelAddr, + Resolver: tunnelAddrResolver, Client: conn.Client, Server: appServer, AccessPoint: accessPoint, @@ -3912,29 +3917,35 @@ func (process *TeleportProcess) initDebugApp() { // singleProcessMode returns true when running all components needed within // the same process. It's used for development and demo purposes. -func (process *TeleportProcess) singleProcessMode(mode types.ProxyListenerMode) (string, bool) { +func (process *TeleportProcess) singleProcessMode(mode types.ProxyListenerMode) (*utils.NetAddr, bool) { if !process.Config.Proxy.Enabled || !process.Config.Auth.Enabled { - return "", false + return nil, false } if process.Config.Proxy.DisableReverseTunnel { - return "", false + return nil, false } if !process.Config.Proxy.DisableTLS && !process.Config.Proxy.DisableALPNSNIListener && mode == types.ProxyListenerMode_Multiplex { if len(process.Config.Proxy.PublicAddrs) != 0 { - return process.Config.Proxy.PublicAddrs[0].String(), true + return &process.Config.Proxy.PublicAddrs[0], true } // If WebAddress is unspecified "0.0.0.0" replace 0.0.0.0 with localhost since 0.0.0.0 is never a valid // principal (auth server explicitly removes it when issuing host certs) and when WebPort is used // in the single process mode to establish SSH reverse tunnel connection the host is validated against // the valid principal list. 
- return utils.ReplaceUnspecifiedHost(&process.Config.Proxy.WebAddr, defaults.HTTPListenPort), true + addr := process.Config.Proxy.WebAddr + addr.Addr = utils.ReplaceUnspecifiedHost(&addr, defaults.HTTPListenPort) + return &addr, true } if len(process.Config.Proxy.TunnelPublicAddrs) == 0 { - return net.JoinHostPort(string(teleport.PrincipalLocalhost), strconv.Itoa(defaults.SSHProxyTunnelListenPort)), true + addr, err := utils.ParseHostPortAddr(string(teleport.PrincipalLocalhost), defaults.SSHProxyTunnelListenPort) + if err != nil { + return nil, false + } + return addr, true } - return process.Config.Proxy.TunnelPublicAddrs[0].String(), true + return &process.Config.Proxy.TunnelPublicAddrs[0], true } // dumperHandler is an Application Access debugging application that will diff --git a/lib/utils/addr.go b/lib/utils/addr.go index 1580667818ee4..6cdaa16b8f0a9 100644 --- a/lib/utils/addr.go +++ b/lib/utils/addr.go @@ -344,7 +344,7 @@ func guessHostIP(addrs []net.Addr) (ip net.IP) { return ip } -// ReplaceUnspecifiedHost replaces unspecified "0.0.0.0" host localhost since 0.0.0.0 is never a valid +// ReplaceUnspecifiedHost replaces unspecified "0.0.0.0" with localhost since "0.0.0.0" is never a valid // principal (auth server explicitly removes it when issuing host certs) and when a reverse tunnel client used // establishes SSH reverse tunnel connection the host is validated against // the valid principal list. diff --git a/lib/utils/timed_counter.go b/lib/utils/timed_counter.go index 2798b7cd1f62f..c83ff15339a92 100644 --- a/lib/utils/timed_counter.go +++ b/lib/utils/timed_counter.go @@ -32,7 +32,7 @@ type TimedCounter struct { events []time.Time } -// TimedCounted creates a new timed counter with the specified timeout +// NewTimedCounter creates a new timed counter with the specified timeout func NewTimedCounter(clock clockwork.Clock, timeout time.Duration) *TimedCounter { return &TimedCounter{ clock: clock, diff --git a/lib/utils/workpool/workpool.go b/lib/utils/workpool/workpool.go index 8236ac91e1726..12801fab6d4b2 100644 --- a/lib/utils/workpool/workpool.go +++ b/lib/utils/workpool/workpool.go @@ -30,7 +30,7 @@ import ( type Pool struct { mu sync.Mutex leaseIDs *atomic.Uint64 - groups map[interface{}]*group + groups *group // grantC is an unbuffered channel that funnels available leases from the // workgroups to the outside world grantC chan Lease @@ -42,7 +42,6 @@ func NewPool(ctx context.Context) *Pool { ctx, cancel := context.WithCancel(ctx) return &Pool{ leaseIDs: atomic.NewUint64(0), - groups: make(map[interface{}]*group), grantC: make(chan Lease), ctx: ctx, cancel: cancel, @@ -65,34 +64,37 @@ func (p *Pool) Done() <-chan struct{} { } // Get gets the current counts for the specified key. -func (p *Pool) Get(key interface{}) Counts { +func (p *Pool) Get() Counts { p.mu.Lock() defer p.mu.Unlock() - if g, ok := p.groups[key]; ok { - return g.loadCounts() + + if p.groups == nil { + return Counts{} } - return Counts{} + + return p.groups.loadCounts() } // Set sets the target for the specified key. -func (p *Pool) Set(key interface{}, target uint64) { +func (p *Pool) Set(target uint64) { p.mu.Lock() defer p.mu.Unlock() if target < 1 { - p.del(key) + p.del() return } - g, ok := p.groups[key] - if !ok { - p.start(key, target) + + if p.groups == nil { + p.start(target) return } - g.setTarget(target) + + p.groups.setTarget(target) } // Start starts a new work group with the specified initial target. // If Start returns false, the group already exists. 
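The workpool changes above and in the following hunks collapse the per-key group map into a single group, so Get and Set no longer take a key and Lease loses its Key accessor. A short usage sketch against the keyless API, using only the exported names visible in this diff (NewPool, Set, Acquire, Release, ID, Stop):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/gravitational/teleport/lib/utils/workpool"
)

func main() {
	pool := workpool.NewPool(context.Background())
	defer pool.Stop()

	// A single target now applies to the whole pool; there is no per-key group.
	pool.Set(2)

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(worker int) {
			defer wg.Done()
			lease := <-pool.Acquire()
			defer lease.Release()
			fmt.Printf("worker %d holds lease %d\n", worker, lease.ID())
			// Hold the lease briefly so at most two workers run concurrently.
			time.Sleep(10 * time.Millisecond)
		}(i)
	}
	wg.Wait()
}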
-func (p *Pool) start(key interface{}, target uint64) { +func (p *Pool) start(target uint64) { ctx, cancel := context.WithCancel(p.ctx) notifyC := make(chan struct{}, 1) g := &group{ @@ -101,13 +103,12 @@ func (p *Pool) start(key interface{}, target uint64) { Target: target, }, leaseIDs: p.leaseIDs, - key: key, grantC: p.grantC, notifyC: notifyC, ctx: ctx, cancel: cancel, } - p.groups[key] = g + p.groups = g // Start a routine to monitor the group's lease acquisition // and handle notifications when a lease is returned to the @@ -115,13 +116,13 @@ func (p *Pool) start(key interface{}, target uint64) { go g.run() } -func (p *Pool) del(key interface{}) (ok bool) { - group, ok := p.groups[key] - if !ok { +func (p *Pool) del() (ok bool) { + if p.groups == nil { return false } - group.cancel() - delete(p.groups, key) + + p.groups.cancel() + p.groups = nil return true } @@ -146,7 +147,6 @@ type group struct { cmu sync.Mutex counts Counts leaseIDs *atomic.Uint64 - key interface{} grantC chan Lease notifyC chan struct{} ctx context.Context @@ -231,7 +231,7 @@ func (g *group) run() { // is called (usually by a lease being returned), or the context is // canceled. // - // Otherwise we post the lease to the outside world and block on + // Otherwise, we post the lease to the outside world and block on // someone picking it up (or cancellation) select { case grant <- nextLease: @@ -267,11 +267,6 @@ func (l Lease) ID() uint64 { return l.id } -// Key returns the key that this lease is associated with. -func (l Lease) Key() interface{} { - return l.key -} - // IsZero checks if this is the zero value of Lease. func (l Lease) IsZero() bool { return l == Lease{} diff --git a/lib/utils/workpool/workpool_test.go b/lib/utils/workpool/workpool_test.go index dac46a41fc914..d50510350cef3 100644 --- a/lib/utils/workpool/workpool_test.go +++ b/lib/utils/workpool/workpool_test.go @@ -18,65 +18,21 @@ package workpool import ( "context" - "fmt" "sync" "testing" "time" - - "gopkg.in/check.v1" ) -func Example() { - pool := NewPool(context.TODO()) - defer pool.Stop() - // create two keys with different target counts - pool.Set("spam", 2) - pool.Set("eggs", 1) - // track how many workers are spawned for each key - counts := make(map[string]int) - var mu sync.Mutex - var wg sync.WaitGroup - for i := 0; i < 12; i++ { - wg.Add(1) - go func() { - lease := <-pool.Acquire() - defer lease.Release() - mu.Lock() - counts[lease.Key().(string)]++ - mu.Unlock() - // in order to demonstrate the differing spawn rates we need - // work to take some time, otherwise pool will end up granting - // leases in a "round robin" fashion. - time.Sleep(time.Millisecond * 10) - wg.Done() - }() - } - wg.Wait() - // exact counts will vary, but leases with key `spam` - // will end up being generated approximately twice as - // often as leases with key `eggs`. - fmt.Println(counts["spam"] > counts["eggs"]) // Output: true -} - -func Test(t *testing.T) { - check.TestingT(t) -} - -type WorkSuite struct{} - -var _ = check.Suite(&WorkSuite{}) - // TestFull runs a pool though a round of normal usage, // and verifies expected state along the way: -// - A group of workers acquire leases, do some work, and release them. -// - A second group of workers receieve leases as the first group finishes. -// - The expected amout of leases are in play after this churn. +// - A group of workers acquires leases, do some work, and release them. +// - A second group of workers receives leases as the first group finishes. 
+// - The expected amount of leases is in play after this churn. // - Updating the target lease count has the expected effect. -func (s *WorkSuite) TestFull(c *check.C) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() +func TestFull(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) p := NewPool(ctx) - key := "some-key" var wg sync.WaitGroup // signal channel to cause the first group of workers to // release their leases. @@ -88,7 +44,7 @@ func (s *WorkSuite) TestFull(c *check.C) { time.Sleep(time.Millisecond * 500) close(g1timeout) }() - p.Set(key, 200) + p.Set(200) // spawn first group of workers. for i := 0; i < 200; i++ { wg.Add(1) @@ -98,7 +54,7 @@ func (s *WorkSuite) TestFull(c *check.C) { <-g1done l.Release() case <-g1timeout: - c.Errorf("Timeout waiting for lease") + t.Errorf("Timeout waiting for lease") } wg.Done() }() @@ -107,7 +63,7 @@ func (s *WorkSuite) TestFull(c *check.C) { // no additional leases should exist select { case l := <-p.Acquire(): - c.Errorf("unexpected lease: %+v", l) + t.Errorf("unexpected lease: %+v", l) default: } // spawn a second group of workers that won't be able to @@ -119,7 +75,7 @@ func (s *WorkSuite) TestFull(c *check.C) { case <-p.Acquire(): // leak deliberately case <-time.After(time.Millisecond * 512): - c.Errorf("Timeout waiting for lease") + t.Errorf("Timeout waiting for lease") } wg.Done() }() @@ -132,46 +88,43 @@ func (s *WorkSuite) TestFull(c *check.C) { select { case l := <-p.Acquire(): counts := l.loadCounts() - c.Errorf("unexpected lease grant: %+v, counts=%+v", l, counts) + t.Errorf("unexpected lease grant: %+v, counts=%+v", l, counts) case <-time.After(time.Millisecond * 128): } // make one additional lease available - p.Set(key, 201) + p.Set(201) select { case l := <-p.Acquire(): - c.Assert(l.Key().(string), check.Equals, key) l.Release() case <-time.After(time.Millisecond * 128): - c.Errorf("timeout waiting for lease grant") + t.Errorf("timeout waiting for lease grant") } } -// TestZeroed varifies that a zeroed pool stops granting +// TestZeroed verifies that a zeroed pool stops granting // leases as expected. -func (s *WorkSuite) TestZeroed(c *check.C) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() +func TestZeroed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) p := NewPool(ctx) - key := "some-key" - p.Set(key, 1) + p.Set(1) var l Lease select { case l = <-p.Acquire(): - c.Assert(l.Key().(string), check.Equals, key) l.Release() case <-time.After(time.Millisecond * 128): - c.Errorf("timeout waiting for lease grant") + t.Errorf("timeout waiting for lease grant") } - p.Set(key, 0) + p.Set(0) // modifications to counts are *ordered*, but asynchronous, - // so we could actually receieve a lease here if we don't sleep + // so we could actually receive a lease here if we don't sleep // briefly. if we opted for condvars instead of channels, this // issue could be avoided at the cost of more cumbersome // composition/cancellation. 
time.Sleep(time.Millisecond * 10) select { case l := <-p.Acquire(): - c.Errorf("unexpected lease grant: %+v", l) + t.Errorf("unexpected lease grant: %+v", l) case <-time.After(time.Millisecond * 128): } } diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index 05e046a013249..374f346ff8f3c 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -27,7 +27,6 @@ import ( "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth" @@ -224,17 +223,10 @@ func connectToAuthService(ctx context.Context, cfg *service.Config, clientConfig errs := []error{err} - // Figure out the reverse tunnel address on the proxy first. - tunAddr, err := findReverseTunnel(ctx, cfg.AuthServers, clientConfig.TLS.InsecureSkipVerify) - if err != nil { - errs = append(errs, trace.Wrap(err, "failed lookup of proxy reverse tunnel address: %v", err)) - return nil, trace.NewAggregate(errs...) - } - log.Debugf("Attempting to connect using reverse tunnel address %v.", tunAddr) // reversetunnel.TunnelAuthDialer will take care of creating a net.Conn // within an SSH tunnel. dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{ - ProxyAddr: tunAddr, + Resolver: reversetunnel.ResolveViaWebClient(ctx, cfg.AuthServers, clientConfig.TLS.InsecureSkipVerify), ClientConfig: clientConfig.SSH, Log: cfg.Log, }) @@ -260,22 +252,6 @@ func connectToAuthService(ctx context.Context, cfg *service.Config, clientConfig return client, nil } -// findReverseTunnel uses the web proxy to discover where the SSH reverse tunnel -// server is running. -func findReverseTunnel(ctx context.Context, addrs []utils.NetAddr, insecureTLS bool) (string, error) { - var errs []error - for _, addr := range addrs { - // In insecure mode, any certificate is accepted. In secure mode the hosts - // CAs are used to validate the certificate on the proxy. - tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil) - if err == nil { - return tunnelAddr, nil - } - errs = append(errs, err) - } - return "", trace.NewAggregate(errs...) -} - // applyConfig takes configuration values from the config file and applies // them to 'service.Config' object. //
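With findReverseTunnel removed, the address discovery it performed moves behind the resolver used above. A rough sketch, not taken from the new resolver.go (which is not part of this excerpt), of the lookup ResolveViaWebClient presumably performs: the same webclient.GetTunnelAddr probe the deleted helper used, but run on every resolution so a changed tunnel_public_address is picked up without restarting the agent:

package main

import (
	"context"

	"github.com/gravitational/trace"

	"github.com/gravitational/teleport/api/client/webclient"
	"github.com/gravitational/teleport/lib/utils"
)

// resolveTunnelAddr probes each proxy's web endpoint for the current reverse
// tunnel address and returns the first successful answer as a *utils.NetAddr.
func resolveTunnelAddr(ctx context.Context, proxies []utils.NetAddr, insecureTLS bool) (*utils.NetAddr, error) {
	var errs []error
	for _, proxy := range proxies {
		// In insecure mode any certificate is accepted; otherwise the host CAs
		// validate the proxy certificate, the same trade-off as the removed helper.
		tunnelAddr, err := webclient.GetTunnelAddr(ctx, proxy.String(), insecureTLS, nil)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		return utils.ParseAddr(tunnelAddr)
	}
	return nil, trace.NewAggregate(errs...)
}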