diff --git a/integration/restart_test.go b/integration/restart_test.go deleted file mode 100644 index 67c359c84f143..0000000000000 --- a/integration/restart_test.go +++ /dev/null @@ -1,111 +0,0 @@ -/* -Copyright 2021 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package integration - -import ( - "context" - "testing" - "time" - - "github.com/gravitational/teleport/lib" - "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/service" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" -) - -// TestLostConnectionToAuthCausesReload tests that a lost connection to the auth server -// will eventually restart a node -func TestLostConnectionToAuthCausesReload(t *testing.T) { - // Because testing that the node does a full restart is a bit tricky when - // running a cluster from inside a test runner (i.e. we don't want to - // SIGTERM the test runner), we will watch for the node emitting a - // `TeleportReload` even instead. In a proper Teleport instance, this - // event would be picked up at the Supervisor level and would eventually - // cause the instance to gracefully restart. - - require := require.New(t) - log := log.StandardLogger() - - log.Info(">>> Entering Test") - - // InsecureDevMode needed for SSH connections - // TODO(tcsc): surface this as per-server config (see also issue #8913) - lib.SetInsecureDevMode(true) - defer lib.SetInsecureDevMode(false) - - // GIVEN a cluster with a running auth+proxy instance.... - log.Info(">>> Creating cluster") - keygen := testauthority.New() - privateKey, publicKey, err := keygen.GenerateKeyPair("") - require.NoError(err) - auth := NewInstance(InstanceConfig{ - ClusterName: "test-tunnel-collapse", - HostID: "auth", - Priv: privateKey, - Pub: publicKey, - Ports: ports.PopIntSlice(6), - log: log, - }) - - log.Info(">>> Creating auth-proxy...") - authConf := service.MakeDefaultConfig() - authConf.Hostname = Host - authConf.Auth.Enabled = true - authConf.Proxy.Enabled = true - authConf.SSH.Enabled = false - authConf.Proxy.DisableWebInterface = true - authConf.Proxy.DisableDatabaseProxy = true - require.NoError(auth.CreateEx(t, nil, authConf)) - t.Cleanup(func() { require.NoError(auth.StopAll()) }) - - log.Info(">>> Start auth-proxy...") - require.NoError(auth.Start()) - - // ... 
and an SSH node connected via a reverse tunnel configured to - // reload after only a few failed connection attempts per minute - log.Info(">>> Creating and starting node...") - nodeCfg := service.MakeDefaultConfig() - nodeCfg.Hostname = Host - nodeCfg.SSH.Enabled = true - nodeCfg.RotationConnectionInterval = 1 * time.Second - nodeCfg.RestartThreshold = service.Rate{Amount: 3, Time: 1 * time.Minute} - node, err := auth.StartReverseTunnelNode(nodeCfg) - require.NoError(err) - - // WHEN I stop the auth node (and, by implication, disrupt the ssh node's - // connection to it) - log.Info(">>> Stopping auth node") - auth.StopAuth(false) - - // EXPECT THAT the ssh node will eventually issue a reload request - log.Info(">>> Waiting for node restart request.") - waitCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) - defer cancel() - - eventCh := make(chan service.Event) - node.WaitForEvent(waitCtx, service.TeleportReloadEvent, eventCh) - select { - case e := <-eventCh: - log.Infof(">>> Received Reload event: %v. Test passed.", e) - - case <-waitCtx.Done(): - require.FailNow("Timed out", "Timed out waiting for reload event") - } - - log.Info(">>> TEST COMPLETE") -} diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 428b215f1ccb8..5ce5b5c021e15 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -303,7 +303,7 @@ type Cache struct { // fnCache is used to perform short ttl-based caching of the results of // regularly called methods. - fnCache *fnCache + fnCache *utils.FnCache trustCache services.Trust clusterConfigCache services.ClusterConfiguration @@ -568,6 +568,14 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } + fnCache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: time.Second, + Clock: config.Clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } + ctx, cancel := context.WithCancel(config.Context) cs := &Cache{ wrapper: wrapper, @@ -576,7 +584,7 @@ func New(config Config) (*Cache, error) { Config: config, generation: atomic.NewUint64(0), initC: make(chan struct{}), - fnCache: newFnCache(time.Second), + fnCache: fnCache, trustCache: local.NewCAService(wrapper), clusterConfigCache: clusterConfigCache, provisionerCache: local.NewProvisioningService(wrapper), diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 216f9b725aac4..0d2cba35adec1 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -345,12 +345,19 @@ func (a *Agent) run() { a.log.Warningf("Failed to create remote tunnel: %v, conn: %v.", err, conn) return } - defer conn.Close() + + local := conn.LocalAddr().String() + remote := conn.RemoteAddr().String() + defer func() { + if err := conn.Close(); err != nil { + a.log.Warnf("Failed to close remote tunnel: %v, local addr: %s remote addr: %s", err, local, remote) + } + }() // Successfully connected to remote cluster. a.log.WithFields(log.Fields{ - "addr": conn.LocalAddr().String(), - "remote-addr": conn.RemoteAddr().String(), + "addr": local, + "remote-addr": remote, }).Info("Connected.") // wrap up remaining business logic in closure for easy @@ -375,14 +382,14 @@ func (a *Agent) run() { // or permanent loss of a proxy. err = a.processRequests(conn) if err != nil { - a.log.Warnf("Unable to continue processesing requests: %v.", err) + a.log.Warnf("Unable to continue processioning requests: %v.", err) return } } // if Tracker was provided, then the agent shouldn't continue unless // no other agents hold a claim. 
if a.Tracker != nil { - if !a.Tracker.WithProxy(doWork, a.Lease, a.getPrincipalsList()...) { + if !a.Tracker.WithProxy(doWork, a.getPrincipalsList()...) { a.log.Debugf("Proxy already held by other agent: %v, releasing.", a.getPrincipalsList()) } } else { @@ -478,7 +485,7 @@ func (a *Agent) processRequests(conn *ssh.Client) error { } } -// handleDisovery receives discovery requests from the reverse tunnel +// handleDiscovery receives discovery requests from the reverse tunnel // server, that informs agent about proxies registered in the remote // cluster and the reverse tunnels already established // @@ -486,7 +493,11 @@ func (a *Agent) processRequests(conn *ssh.Client) error { // reqC : request payload func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { a.log.Debugf("handleDiscovery requests channel.") - defer ch.Close() + defer func() { + if err := ch.Close(); err != nil { + a.log.Warnf("Failed to close discovery channel:: %v", err) + } + }() for { var req *ssh.Request @@ -507,7 +518,7 @@ func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { if a.Tracker != nil { // Notify tracker of all known proxies. for _, p := range r.Proxies { - a.Tracker.TrackExpected(a.Lease, p.GetName()) + a.Tracker.TrackExpected(p.GetName()) } } } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 9ea5cb1fe27a9..58b2389cd316c 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -59,7 +59,7 @@ type AgentPool struct { spawnLimiter utils.Retry mu sync.Mutex - agents map[utils.NetAddr][]*Agent + agents []*Agent } // AgentPoolConfig holds configuration parameters for the agent pool @@ -89,8 +89,8 @@ type AgentPoolConfig struct { Component string // ReverseTunnelServer holds all reverse tunnel connections. ReverseTunnelServer Server - // ProxyAddr points to the address of the ssh proxy - ProxyAddr string + // Resolver retrieves the reverse tunnel address + Resolver Resolver // Cluster is a cluster name of the proxy. Cluster string // FIPS indicates if Teleport was started in FIPS mode. 
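For context on the AgentPoolConfig hunk above: the fixed ProxyAddr string is replaced by a Resolver, which this diff defines later in lib/reversetunnel/resolver.go as func() (*utils.NetAddr, error). A minimal sketch of building one, assuming only the StaticResolver and CachingResolver helpers added by this diff; the address is a placeholder:

package main

import (
	"fmt"

	"github.com/gravitational/teleport/lib/reversetunnel"
	"github.com/jonboulle/clockwork"
)

func main() {
	// A fixed tunnel address (placeholder); StaticResolver parses it once and
	// returns the same *utils.NetAddr (or parse error) on every call.
	static := reversetunnel.StaticResolver("proxy.example.com:3024")

	// Wrap any Resolver with a short-lived cache so repeated resolutions do
	// not overwhelm the underlying source (web proxy lookup, etc.).
	cached, err := reversetunnel.CachingResolver(static, clockwork.NewRealClock())
	if err != nil {
		panic(err)
	}

	addr, err := cached()
	if err != nil {
		panic(err)
	}
	fmt.Println("reverse tunnel address:", addr.Addr)
}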
@@ -135,11 +135,6 @@ func NewAgentPool(ctx context.Context, cfg AgentPoolConfig) (*AgentPool, error) return nil, trace.Wrap(err) } - proxyAddr, err := utils.ParseAddr(cfg.ProxyAddr) - if err != nil { - return nil, trace.Wrap(err) - } - ctx, cancel := context.WithCancel(ctx) tr, err := track.New(ctx, track.Config{ClusterName: cfg.Cluster}) if err != nil { @@ -148,7 +143,6 @@ func NewAgentPool(ctx context.Context, cfg AgentPoolConfig) (*AgentPool, error) } pool := &AgentPool{ - agents: make(map[utils.NetAddr][]*Agent), proxyTracker: tr, cfg: cfg, ctx: ctx, @@ -161,7 +155,7 @@ func NewAgentPool(ctx context.Context, cfg AgentPoolConfig) (*AgentPool, error) }, }), } - pool.proxyTracker.Start(*proxyAddr) + pool.proxyTracker.Start() return pool, nil } @@ -204,11 +198,11 @@ func (m *AgentPool) processSeekEvents() { // The proxy tracker has given us permission to act on a given // tunnel address case lease := <-m.proxyTracker.Acquire(): - m.log.Debugf("Seeking: %+v.", lease.Key()) m.withLock(func() { // Note that ownership of the lease is transferred to agent // pool for the lifetime of the connection if err := m.addAgent(lease); err != nil { + lease.Release() m.log.WithError(err).Errorf("Failed to add agent.") } }) @@ -232,12 +226,7 @@ func (m *AgentPool) withLock(f func()) { type matchAgentFn func(a *Agent) bool func (m *AgentPool) closeAgents() { - for key, agents := range m.agents { - m.agents[key] = filterAndClose(agents, func(*Agent) bool { return true }) - if len(m.agents[key]) == 0 { - delete(m.agents, key) - } - } + m.agents = filterAndClose(m.agents, func(*Agent) bool { return true }) } func filterAndClose(agents []*Agent, matchAgent matchAgentFn) []*Agent { @@ -246,11 +235,18 @@ func filterAndClose(agents []*Agent, matchAgent matchAgentFn) []*Agent { agent := agents[i] if matchAgent(agent) { agent.log.Debugf("Pool is closing agent.") - agent.Close() + if err := agent.Close(); err != nil { + agent.log.WithError(err).Warnf("Failed to close agent") + } } else { filtered = append(filtered, agent) } } + + if len(filtered) <= 0 { + return nil + } + return filtered } @@ -273,9 +269,13 @@ func (m *AgentPool) pollAndSyncAgents() { // transfers into the AgentPool, and will be released when the AgentPool // is done with it. func (m *AgentPool) addAgent(lease track.Lease) error { - addr := lease.Key().(utils.NetAddr) + addr, err := m.cfg.Resolver() + if err != nil { + return trace.Wrap(err) + } + agent, err := NewAgent(AgentConfig{ - Addr: addr, + Addr: *addr, ClusterName: m.cfg.Cluster, Username: m.cfg.HostUUID, Signer: m.cfg.HostSigner, @@ -293,28 +293,25 @@ func (m *AgentPool) addAgent(lease track.Lease) error { }) if err != nil { // ensure that lease has been released; OK to call multiple times. - lease.Release() return trace.Wrap(err) } m.log.Debugf("Adding %v.", agent) // start the agent in a goroutine. no need to handle Start() errors: Start() will be // retrying itself until the agent is closed go agent.Start() - m.agents[addr] = append(m.agents[addr], agent) + m.agents = append(m.agents, agent) return nil } -// Counts returns a count of the number of proxies a outbound tunnel is +// Count returns a count of the number of proxies an outbound tunnel is // connected to. Used in tests to determine if a proxy has been found and/or // removed. 
func (m *AgentPool) Count() int { var out int m.withLock(func() { - for _, agents := range m.agents { - for _, agent := range agents { - if agent.getState() == agentStateConnected { - out++ - } + for _, agent := range m.agents { + if agent.getState() == agentStateConnected { + out++ } } }) @@ -325,19 +322,10 @@ func (m *AgentPool) Count() int { // removeDisconnected removes disconnected agents from the list of agents. // This function should be called under a lock. func (m *AgentPool) removeDisconnected() { - for agentKey, agentSlice := range m.agents { - // Filter and close all disconnected agents. - validAgents := filterAndClose(agentSlice, func(agent *Agent) bool { - return agent.getState() == agentStateDisconnected - }) - - // Update (or delete) agent key with filter applied. - if len(validAgents) > 0 { - m.agents[agentKey] = validAgents - } else { - delete(m.agents, agentKey) - } - } + // Filter and close all disconnected agents. + m.agents = filterAndClose(m.agents, func(agent *Agent) bool { + return agent.getState() == agentStateDisconnected + }) } // Make sure ServerHandlerToListener implements both interfaces. diff --git a/lib/reversetunnel/rc_manager.go b/lib/reversetunnel/rc_manager.go index 1559a0f6efa06..deebef8525af6 100644 --- a/lib/reversetunnel/rc_manager.go +++ b/lib/reversetunnel/rc_manager.go @@ -44,7 +44,7 @@ type RemoteClusterTunnelManager struct { pools map[remoteClusterKey]*AgentPool stopRun func() - newAgentPool func(ctx context.Context, cluster, addr string) (*AgentPool, error) + newAgentPool func(ctx context.Context, cfg RemoteClusterTunnelManagerConfig, cluster, addr string) (*AgentPool, error) } type remoteClusterKey struct { @@ -105,17 +105,17 @@ func (c *RemoteClusterTunnelManagerConfig) CheckAndSetDefaults() error { return nil } -// NewRemoteClusterTunnelManager creates a new unstarted tunnel manager with +// NewRemoteClusterTunnelManager creates a new stopped tunnel manager with // the provided config. Call Run() to start the manager. func NewRemoteClusterTunnelManager(cfg RemoteClusterTunnelManagerConfig) (*RemoteClusterTunnelManager, error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } w := &RemoteClusterTunnelManager{ - cfg: cfg, - pools: make(map[remoteClusterKey]*AgentPool), + cfg: cfg, + pools: make(map[remoteClusterKey]*AgentPool), + newAgentPool: realNewAgentPool, } - w.newAgentPool = w.realNewAgentPool return w, nil } @@ -197,7 +197,7 @@ func (w *RemoteClusterTunnelManager) Sync(ctx context.Context) error { continue } - pool, err := w.newAgentPool(ctx, k.cluster, k.addr) + pool, err := w.newAgentPool(ctx, w.cfg, k.cluster, k.addr) if err != nil { errs = append(errs, trace.Wrap(err)) continue @@ -207,29 +207,32 @@ func (w *RemoteClusterTunnelManager) Sync(ctx context.Context) error { return trace.NewAggregate(errs...) } -func (w *RemoteClusterTunnelManager) realNewAgentPool(ctx context.Context, cluster, addr string) (*AgentPool, error) { +func realNewAgentPool(ctx context.Context, cfg RemoteClusterTunnelManagerConfig, cluster, addr string) (*AgentPool, error) { pool, err := NewAgentPool(ctx, AgentPoolConfig{ // Configs for our cluster. 
- Client: w.cfg.AuthClient, - AccessPoint: w.cfg.AccessPoint, - HostSigner: w.cfg.HostSigner, - HostUUID: w.cfg.HostUUID, - LocalCluster: w.cfg.LocalCluster, - Clock: w.cfg.Clock, - KubeDialAddr: w.cfg.KubeDialAddr, - ReverseTunnelServer: w.cfg.ReverseTunnelServer, - FIPS: w.cfg.FIPS, + Client: cfg.AuthClient, + AccessPoint: cfg.AccessPoint, + HostSigner: cfg.HostSigner, + HostUUID: cfg.HostUUID, + LocalCluster: cfg.LocalCluster, + Clock: cfg.Clock, + KubeDialAddr: cfg.KubeDialAddr, + ReverseTunnelServer: cfg.ReverseTunnelServer, + FIPS: cfg.FIPS, // RemoteClusterManager only runs on proxies. Component: teleport.ComponentProxy, // Configs for remote cluster. - Cluster: cluster, - ProxyAddr: addr, + Cluster: cluster, + Resolver: StaticResolver(addr), }) if err != nil { return nil, trace.Wrap(err, "failed creating reverse tunnel pool for remote cluster %q at address %q: %v", cluster, addr, err) } - go pool.Start() + + if err := pool.Start(); err != nil { + cfg.Log.WithError(err).Error("Failed to start agent pool") + } return pool, nil } diff --git a/lib/reversetunnel/rc_manager_test.go b/lib/reversetunnel/rc_manager_test.go index db9766f651846..f492cf4e82085 100644 --- a/lib/reversetunnel/rc_manager_test.go +++ b/lib/reversetunnel/rc_manager_test.go @@ -8,25 +8,37 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/require" ) func TestRemoteClusterTunnelManagerSync(t *testing.T) { t.Parallel() + resolverFn := func(addr string) Resolver { + return func() (*utils.NetAddr, error) { + return &utils.NetAddr{ + Addr: addr, + AddrNetwork: "tcp", + Path: "", + }, nil + } + } + var newAgentPoolErr error w := &RemoteClusterTunnelManager{ pools: make(map[remoteClusterKey]*AgentPool), - newAgentPool: func(ctx context.Context, cluster, addr string) (*AgentPool, error) { + newAgentPool: func(ctx context.Context, cfg RemoteClusterTunnelManagerConfig, cluster, addr string) (*AgentPool, error) { return &AgentPool{ - cfg: AgentPoolConfig{Cluster: cluster, ProxyAddr: addr}, + cfg: AgentPoolConfig{Cluster: cluster, Resolver: resolverFn(addr)}, cancel: func() {}, }, newAgentPoolErr }, } - defer w.Close() + t.Cleanup(func() { require.NoError(t, w.Close()) }) tests := []struct { desc string @@ -47,7 +59,7 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-a", []string{"addr-a"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, }, assertErr: require.NoError, }, @@ -57,9 +69,9 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-a", []string{"addr-a", "addr-b", "addr-c"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: 
"cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, }, assertErr: require.NoError, }, @@ -69,7 +81,7 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-b", []string{"addr-b"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.NoError, }, @@ -80,10 +92,10 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { mustNewReverseTunnel(t, "cluster-b", []string{"addr-b"}), }, wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, - {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.NoError, }, @@ -91,10 +103,10 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { desc: "GetReverseTunnels error, keep existing pools", reverseTunnelsErr: errors.New("nah"), wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, - {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.Error, }, @@ -107,16 +119,16 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { }, newAgentPoolErr: errors.New("nah"), wantPools: map[remoteClusterKey]*AgentPool{ - {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-a"}}, - {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-b"}}, - {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", ProxyAddr: "addr-c"}}, - {cluster: "cluster-b", addr: "addr-b"}: {cfg: 
AgentPoolConfig{Cluster: "cluster-b", ProxyAddr: "addr-b"}}, + {cluster: "cluster-a", addr: "addr-a"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-a")}}, + {cluster: "cluster-a", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-b")}}, + {cluster: "cluster-a", addr: "addr-c"}: {cfg: AgentPoolConfig{Cluster: "cluster-a", Resolver: resolverFn("addr-c")}}, + {cluster: "cluster-b", addr: "addr-b"}: {cfg: AgentPoolConfig{Cluster: "cluster-b", Resolver: resolverFn("addr-b")}}, }, assertErr: require.Error, }, } - ctx := context.TODO() + ctx := context.Background() for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { w.cfg.AuthClient = mockAuthClient{ @@ -134,8 +146,18 @@ func TestRemoteClusterTunnelManagerSync(t *testing.T) { // Tweaks to get comparison working with our complex types. cmp.AllowUnexported(remoteClusterKey{}), cmp.Comparer(func(a, b *AgentPool) bool { + aAddr, aErr := a.cfg.Resolver() + bAddr, bErr := b.cfg.Resolver() + + if aAddr != bAddr && aErr != bErr { + return false + } + // Only check the supplied configs of AgentPools. - return cmp.Equal(a.cfg, b.cfg) + return cmp.Equal( + a.cfg, + b.cfg, + cmpopts.IgnoreFields(AgentPoolConfig{}, "Resolver")) }), )) }) diff --git a/lib/reversetunnel/resolver.go b/lib/reversetunnel/resolver.go new file mode 100644 index 0000000000000..a467f5214577b --- /dev/null +++ b/lib/reversetunnel/resolver.go @@ -0,0 +1,91 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reversetunnel + +import ( + "context" + "time" + + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" +) + +// Resolver looks up reverse tunnel addresses +type Resolver func() (*utils.NetAddr, error) + +// CachingResolver wraps the provided Resolver with one that will cache the previous result +// for 3 seconds to reduce the number of resolutions in an effort to mitigate potentially +// overwhelming the Resolver source. +func CachingResolver(resolver Resolver, clock clockwork.Clock) (Resolver, error) { + cache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: 3 * time.Second, + Clock: clock, + }) + if err != nil { + return nil, err + } + return func() (*utils.NetAddr, error) { + a, err := cache.Get(context.TODO(), "resolver", func() (interface{}, error) { + return resolver() + }) + if err != nil { + return nil, err + } + return a.(*utils.NetAddr), nil + }, nil +} + +// WebClientResolver returns a Resolver which uses the web proxy to +// discover where the SSH reverse tunnel server is running. +func WebClientResolver(ctx context.Context, addrs []utils.NetAddr, insecureTLS bool) Resolver { + return func() (*utils.NetAddr, error) { + var errs []error + for _, addr := range addrs { + // In insecure mode, any certificate is accepted. 
In secure mode the hosts + // CAs are used to validate the certificate on the proxy. + tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil) + if err != nil { + errs = append(errs, err) + continue + } + + addr, err := utils.ParseAddr(tunnelAddr) + if err != nil { + errs = append(errs, err) + continue + } + + addr.Addr = utils.ReplaceUnspecifiedHost(addr, defaults.HTTPListenPort) + return addr, nil + } + return nil, trace.NewAggregate(errs...) + } +} + +// StaticResolver returns a Resolver which will always resolve to +// the provided address +func StaticResolver(address string) Resolver { + addr, err := utils.ParseAddr(address) + if err == nil { + addr.Addr = utils.ReplaceUnspecifiedHost(addr, defaults.HTTPListenPort) + } + + return func() (*utils.NetAddr, error) { + return addr, err + } +} diff --git a/lib/reversetunnel/resolver_test.go b/lib/reversetunnel/resolver_test.go new file mode 100644 index 0000000000000..e9dee2209e43e --- /dev/null +++ b/lib/reversetunnel/resolver_test.go @@ -0,0 +1,156 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reversetunnel + +import ( + "context" + "os" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/lib/utils" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +func TestStaticResolver(t *testing.T) { + cases := []struct { + name string + address string + errorAssertionFn require.ErrorAssertionFunc + expected *utils.NetAddr + }{ + { + name: "invalid address yields error", + address: "", + errorAssertionFn: require.Error, + }, + { + name: "valid address yields NetAddr", + address: "localhost:80", + errorAssertionFn: require.NoError, + expected: &utils.NetAddr{ + Addr: "localhost:80", + AddrNetwork: "tcp", + Path: "", + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr, err := StaticResolver(tt.address)() + tt.errorAssertionFn(t, err) + if err != nil { + return + } + + require.Empty(t, cmp.Diff(tt.expected, addr)) + }) + } +} + +func TestResolveViaWebClient(t *testing.T) { + + fakeAddr := utils.NetAddr{} + + cases := []struct { + name string + addrs []utils.NetAddr + address string + errorAssertionFn require.ErrorAssertionFunc + expected *utils.NetAddr + }{ + { + name: "no addrs yields no results", + errorAssertionFn: require.NoError, + }, + { + name: "unreachable proxy yields errors", + addrs: []utils.NetAddr{fakeAddr}, + address: "", + errorAssertionFn: require.Error, + }, + { + name: "invalid address yields errors", + addrs: []utils.NetAddr{fakeAddr}, + address: "fake://test", + errorAssertionFn: require.Error, + }, + { + name: "valid address yields NetAddr", + addrs: []utils.NetAddr{fakeAddr}, + address: "localhost:80", + errorAssertionFn: require.NoError, + expected: &utils.NetAddr{ + Addr: "localhost:80", + AddrNetwork: "tcp", + Path: "", + }, + }, + } + + for _, tt := range 
cases { + t.Run(tt.name, func(t *testing.T) { + os.Setenv(defaults.TunnelPublicAddrEnvar, tt.address) + t.Cleanup(func() { os.Unsetenv(defaults.TunnelPublicAddrEnvar) }) + + addr, err := WebClientResolver(context.Background(), tt.addrs, true)() + tt.errorAssertionFn(t, err) + if err != nil { + return + } + + require.Empty(t, cmp.Diff(tt.expected, addr)) + }) + } +} + +func TestCachingResolver(t *testing.T) { + randomResolver := func() (*utils.NetAddr, error) { + return &utils.NetAddr{ + Addr: uuid.New().String(), + AddrNetwork: uuid.New().String(), + Path: uuid.New().String(), + }, nil + } + + clock := clockwork.NewFakeClock() + resolver, err := CachingResolver(randomResolver, clock) + require.NoError(t, err) + + addr, err := resolver() + require.NoError(t, err) + + addr2, err := resolver() + require.NoError(t, err) + + require.Equal(t, addr, addr2) + + clock.Advance(time.Hour) + + addr3, err := resolver() + require.NoError(t, err) + + require.NotEqual(t, addr2, addr3) + + addr4, err := resolver() + require.NoError(t, err) + + require.Equal(t, addr3, addr4) +} diff --git a/lib/reversetunnel/track/tracker.go b/lib/reversetunnel/track/tracker.go index ea4fb94c9d30b..ff0a71b8b9acb 100644 --- a/lib/reversetunnel/track/tracker.go +++ b/lib/reversetunnel/track/tracker.go @@ -22,7 +22,6 @@ import ( "sync" "time" - "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/workpool" "github.com/gravitational/trace" ) @@ -31,7 +30,7 @@ type Lease = workpool.Lease // Config configures basic Tracker parameters. type Config struct { - // ProxyExpiry is the duration an entry will be held sice the last + // ProxyExpiry is the duration an entry will be held since the last // successful connection to, or message about, a given proxy. ProxyExpiry time.Duration // TickRate is the rate at which expired entries are cleared from @@ -41,7 +40,7 @@ type Config struct { ClusterName string } -// SetDefaults set default values for Config. +// CheckAndSetDefaults set default values for Config. func (c *Config) CheckAndSetDefaults() error { if c.ProxyExpiry < 1 { c.ProxyExpiry = 3 * time.Minute @@ -66,7 +65,7 @@ type Tracker struct { Config mu sync.Mutex wp *workpool.Pool - sets map[utils.NetAddr]*proxySet + sets *proxySet cancel context.CancelFunc } @@ -79,7 +78,6 @@ func New(ctx context.Context, c Config) (*Tracker, error) { t := &Tracker{ Config: c, wp: workpool.NewPool(ctx), - sets: make(map[utils.NetAddr]*proxySet), cancel: cancel, } go t.run(ctx) @@ -107,41 +105,41 @@ func (t *Tracker) Acquire() <-chan Lease { // TrackExpected starts/refreshes tracking for expected proxies. Called by // agents when gossip messages are received. -func (t *Tracker) TrackExpected(lease Lease, proxies ...string) { +func (t *Tracker) TrackExpected(proxies ...string) { t.mu.Lock() defer t.mu.Unlock() - addr := lease.Key().(utils.NetAddr) - set, ok := t.sets[addr] - if !ok { + + if t.sets == nil { return } now := time.Now() for _, name := range proxies { - set.markSeen(now, name) + t.sets.markSeen(now, name) } - count := len(set.proxies) + count := len(t.sets.proxies) if count < 1 { count = 1 } - t.wp.Set(addr, uint64(count)) + t.wp.Set(uint64(count)) } // Start starts tracking for specified proxy address. -func (t *Tracker) Start(addr utils.NetAddr) { +func (t *Tracker) Start() { t.mu.Lock() defer t.mu.Unlock() - t.getOrCreate(addr) + t.getOrCreate() } // Stop stops tracking for specified proxy address. 
-func (t *Tracker) Stop(addr utils.NetAddr) { +func (t *Tracker) Stop() { t.mu.Lock() defer t.mu.Unlock() - if _, ok := t.sets[addr]; !ok { + if t.sets == nil { return } - delete(t.sets, addr) - t.wp.Set(addr, 0) + + t.sets = nil + t.wp.Set(0) } // StopAll permanently deactivates this tracker and cleans @@ -153,59 +151,59 @@ func (t *Tracker) StopAll() { func (t *Tracker) tick() { t.mu.Lock() defer t.mu.Unlock() + if t.sets == nil { + return + } + cutoff := time.Now().Add(-1 * t.ProxyExpiry) - for addr, set := range t.sets { - if set.expire(cutoff) > 0 { - count := len(set.proxies) - if count < 1 { - count = 1 - } - t.wp.Set(addr, uint64(count)) + if t.sets.expire(cutoff) > 0 { + count := len(t.sets.proxies) + if count < 1 { + count = 1 } + t.wp.Set(uint64(count)) } + } -func (t *Tracker) getOrCreate(addr utils.NetAddr) *proxySet { - if s, ok := t.sets[addr]; ok { - return s +func (t *Tracker) getOrCreate() *proxySet { + if t.sets == nil { + t.sets = newProxySet(t.ClusterName) + t.wp.Set(1) } - set := newProxySet(addr, t.ClusterName) - t.sets[addr] = set - t.wp.Set(addr, 1) - return set + + return t.sets } // WithProxy runs the supplied closure if and only if // no other work is currently being done with the proxy // identified by principals. -func (t *Tracker) WithProxy(work func(), lease Lease, principals ...string) (didWork bool) { - addr := lease.Key().(utils.NetAddr) - if ok := t.claim(addr, principals...); !ok { +func (t *Tracker) WithProxy(work func(), principals ...string) (didWork bool) { + if ok := t.claim(principals...); !ok { return false } - defer t.unclaim(addr, principals...) + defer t.release(principals...) work() return true } -func (t *Tracker) claim(addr utils.NetAddr, principals ...string) (ok bool) { +func (t *Tracker) claim(principals ...string) (ok bool) { t.mu.Lock() defer t.mu.Unlock() - set, ok := t.sets[addr] - if !ok { + if t.sets == nil { return false } - return set.claim(principals...) + return t.sets.claim(principals...) } -func (t *Tracker) unclaim(addr utils.NetAddr, principals ...string) { +func (t *Tracker) release(principals ...string) { t.mu.Lock() defer t.mu.Unlock() - set, ok := t.sets[addr] - if !ok { + if t.sets == nil { return } - set.unclaim(principals...) + + t.sets.release(principals...) 
} type entry struct { @@ -213,16 +211,14 @@ type entry struct { claimed bool } -func newProxySet(addr utils.NetAddr, clusterName string) *proxySet { +func newProxySet(clusterName string) *proxySet { return &proxySet{ - addr: addr, clusterName: clusterName, proxies: make(map[string]entry), } } type proxySet struct { - addr utils.NetAddr clusterName string proxies map[string]entry } @@ -244,7 +240,7 @@ func (p *proxySet) claim(principals ...string) (ok bool) { return true } -func (p *proxySet) unclaim(principals ...string) { +func (p *proxySet) release(principals ...string) { proxy := p.resolveName(principals) p.proxies[proxy] = entry{ lastSeen: time.Now(), diff --git a/lib/reversetunnel/track/tracker_test.go b/lib/reversetunnel/track/tracker_test.go index fd4dae8f039a2..2a788bde63b44 100644 --- a/lib/reversetunnel/track/tracker_test.go +++ b/lib/reversetunnel/track/tracker_test.go @@ -20,12 +20,12 @@ import ( "context" "fmt" pr "math/rand" + "os" "sync" "testing" "time" - "github.com/gravitational/teleport/lib/utils" - "gopkg.in/check.v1" + "github.com/stretchr/testify/require" ) type simpleTestProxies struct { @@ -96,13 +96,13 @@ func (s *simpleTestProxies) ProxyLoop(tracker *Tracker, lease Lease, proxy testP select { case <-ticker.C: if p, ok := s.GetRandProxy(); ok { - tracker.TrackExpected(lease, p.principals[0]) + tracker.TrackExpected(p.principals[0]) } case <-timeout: break Loop } } - }, lease, proxy.principals...) + }, proxy.principals...) return } @@ -133,29 +133,26 @@ func jitter(t time.Duration) time.Duration { return t + j } -func Test(t *testing.T) { +func TestMain(m *testing.M) { pr.Seed(time.Now().UnixNano()) - check.TestingT(t) + os.Exit(m.Run()) } -type StateSuite struct{} +func TestBasic(t *testing.T) { + const ( + timeout = time.Second * 16 + proxyCount = 16 + ) -var _ = check.Suite(&StateSuite{}) + ctx := context.Background() -func (s *StateSuite) TestBasic(c *check.C) { - s.runBasicProxyTest(c, time.Second*16) -} - -func (s *StateSuite) runBasicProxyTest(c *check.C, timeout time.Duration) { - const proxyCount = 16 timeoutC := time.After(timeout) ticker := time.NewTicker(time.Millisecond * 100) - defer ticker.Stop() - tracker, err := New(context.TODO(), Config{ClusterName: "test-cluster"}) - c.Assert(err, check.IsNil) - defer tracker.StopAll() - addr := utils.NetAddr{Addr: "test-cluster"} - tracker.Start(addr) + t.Cleanup(ticker.Stop) + tracker, err := New(ctx, Config{ClusterName: "test-cluster"}) + require.NoError(t, err) + t.Cleanup(tracker.StopAll) + tracker.Start() min, max := time.Duration(0), timeout var proxies simpleTestProxies proxies.AddRandProxies(proxyCount, min, max) @@ -165,18 +162,18 @@ Discover: case lease := <-tracker.Acquire(): go proxies.Discover(tracker, lease) case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts: %+v", counts) if counts.Active == proxyCount { break Discover } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } } -func (s *StateSuite) TestFullRotation(c *check.C) { +func TestFullRotation(t *testing.T) { const ( proxyCount = 8 minConnA = time.Second * 2 @@ -185,37 +182,37 @@ func (s *StateSuite) TestFullRotation(c *check.C) { maxConnB = time.Second * 25 timeout = time.Second * 30 ) + + ctx := context.Background() ticker := time.NewTicker(time.Millisecond * 100) - defer ticker.Stop() + t.Cleanup(ticker.Stop) var proxies simpleTestProxies proxies.AddRandProxies(proxyCount, minConnA, maxConnA) - tracker, err := New(context.TODO(), Config{ClusterName: 
"test-cluster"}) - c.Assert(err, check.IsNil) - defer tracker.StopAll() - addr := utils.NetAddr{Addr: "test-cluster"} - tracker.Start(addr) + tracker, err := New(ctx, Config{ClusterName: "test-cluster"}) + require.NoError(t, err) + t.Cleanup(tracker.StopAll) + tracker.Start() timeoutC := time.After(timeout) Loop0: for { select { case lease := <-tracker.Acquire(): - c.Assert(lease.Key(), check.DeepEquals, addr) // get our "discovered" proxy in the foreground // to prevent race with the call to RemoveRandProxies // that comes after this loop. proxy, ok := proxies.GetRandProxy() if !ok { - c.Fatal("failed to get test proxy") + t.Fatal("failed to get test proxy") } go proxies.ProxyLoop(tracker, lease, proxy) case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts0: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts0: %+v", counts) if counts.Active == proxyCount { break Loop0 } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } proxies.RemoveRandProxies(proxyCount) @@ -223,13 +220,13 @@ Loop1: for { select { case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts1: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts1: %+v", counts) if counts.Active < 1 { break Loop1 } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } proxies.AddRandProxies(proxyCount, minConnB, maxConnB) @@ -239,13 +236,13 @@ Loop2: case lease := <-tracker.Acquire(): go proxies.Discover(tracker, lease) case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts2: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts2: %+v", counts) if counts.Active >= proxyCount { break Loop2 } case <-timeoutC: - c.Fatal("timeout") + t.Fatal("timeout") } } } @@ -253,41 +250,40 @@ Loop2: // TestUUIDHandling verifies that host UUIDs are correctly extracted // from the expected teleport principal format, and that gossip messages // consisting only of uuid don't create duplicate entries. -func (s *StateSuite) TestUUIDHandling(c *check.C) { - ctx, cancel := context.WithTimeout(context.TODO(), time.Second*6) - defer cancel() +func TestUUIDHandling(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*6) + t.Cleanup(cancel) ticker := time.NewTicker(time.Millisecond * 10) - defer ticker.Stop() - tracker, err := New(context.TODO(), Config{ClusterName: "test-cluster"}) - c.Assert(err, check.IsNil) - defer tracker.StopAll() - addr := utils.NetAddr{Addr: "test-cluster"} - tracker.Start(addr) - lease := <-tracker.Acquire() + t.Cleanup(ticker.Stop) + tracker, err := New(context.Background(), Config{ClusterName: "test-cluster"}) + require.NoError(t, err) + t.Cleanup(tracker.StopAll) + tracker.Start() + <-tracker.Acquire() // claim a proxy using principal of the form . 
go tracker.WithProxy(func() { - c.Logf("Successfully claimed proxy") + t.Logf("Successfully claimed proxy") <-ctx.Done() - }, lease, "my-proxy.test-cluster") + }, "my-proxy.test-cluster") // Wait for proxy to be claimed Wait: for { select { case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts: %+v", counts) if counts.Active == counts.Target { break Wait } case <-ctx.Done(): - c.Errorf("pool never reached expected state") + t.Errorf("pool never reached expected state") } } // Send a gossip message containing host UUID only - tracker.TrackExpected(lease, "my-proxy") - c.Logf("Sent uuid-only gossip message; watching status...") + tracker.TrackExpected("my-proxy") + t.Logf("Sent uuid-only gossip message; watching status...") // Let pool go through a few ticks, monitoring status to ensure that // we don't incorrectly enter seek mode (entering seek mode here would @@ -296,13 +292,13 @@ Wait: for i := 0; i < 3; i++ { select { case <-ticker.C: - counts := tracker.wp.Get(addr) - c.Logf("Counts: %+v", counts) + counts := tracker.wp.Get() + t.Logf("Counts: %+v", counts) if counts.Active != counts.Target { - c.Errorf("incorrectly entered seek mode") + t.Errorf("incorrectly entered seek mode") } case <-ctx.Done(): - c.Errorf("timeout") + t.Errorf("timeout") } } } diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 91a9ec422b4d1..34fa2f084851b 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -39,19 +39,50 @@ import ( "github.com/sirupsen/logrus" ) -// TunnelAuthDialer connects to the Auth Server through the reverse tunnel. -type TunnelAuthDialer struct { - // ProxyAddr is the address of the proxy - ProxyAddr string +// NewTunnelAuthDialer creates a new instance of TunnelAuthDialer +func NewTunnelAuthDialer(config TunnelAuthDialerConfig) (*TunnelAuthDialer, error) { + if err := config.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &TunnelAuthDialer{ + TunnelAuthDialerConfig: config, + }, nil +} + +// TunnelAuthDialerConfig specifies TunnelAuthDialer configuration. +type TunnelAuthDialerConfig struct { + // Resolver retrieves the address of the proxy + Resolver Resolver // ClientConfig is SSH tunnel client config ClientConfig *ssh.ClientConfig + // Log is used for logging. + Log logrus.FieldLogger +} + +func (c *TunnelAuthDialerConfig) CheckAndSetDefaults() error { + if c.Resolver == nil { + return trace.BadParameter("missing tunnel address resolver") + } + return nil +} + +// TunnelAuthDialer connects to the Auth Server through the reverse tunnel. +type TunnelAuthDialer struct { + // TunnelAuthDialerConfig is the TunnelAuthDialer configuration. + TunnelAuthDialerConfig } // DialContext dials auth server via SSH tunnel -func (t *TunnelAuthDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { + addr, err := t.Resolver() + if err != nil { + t.Log.Errorf("Failed to resolve tunnel address %v", err) + return nil, trace.Wrap(err) + } + // Connect to the reverse tunnel server. 
- dialer := proxy.DialerFromEnvironment(t.ProxyAddr) - sconn, err := dialer.Dial("tcp", t.ProxyAddr, t.ClientConfig) + dialer := proxy.DialerFromEnvironment(addr.Addr) + sconn, err := dialer.Dial("tcp", addr.Addr, t.ClientConfig) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/service/connect.go b/lib/service/connect.go index 25fc860d2097b..2d2247fbf348d 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -20,12 +20,12 @@ import ( "crypto/tls" "path/filepath" + "github.com/gravitational/roundtrip" + "github.com/gravitational/teleport/api/constants" + "golang.org/x/crypto/ssh" - "github.com/gravitational/roundtrip" apiclient "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/webclient" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" @@ -472,28 +472,13 @@ func (process *TeleportProcess) periodicSyncRotationState() error { }) defer periodic.Stop() - errors := utils.NewTimedCounter(process.Clock, process.Config.RestartThreshold.Time) - for { err := process.syncRotationStateCycle() if err == nil { return nil } - process.log.WithError(err).Warning("Sync rotation state cycle failed") - - // If we have had a *lot* of failures very recently, then it's likely that our - // route to the auth server is gone. If we're using a tunnel then it's possible - // that the proxy has been reconfigured and the tunnel address has moved. - count := errors.Increment() - process.log.Warnf("%d connection errors in last %v.", count, process.Config.RestartThreshold.Time) - if count > process.Config.RestartThreshold.Amount { - // signal quit - process.log.Error("Connection error threshold exceeded. Asking for a graceful restart.") - process.BroadcastEvent(Event{Name: TeleportReloadEvent}) - return nil - } - process.log.Warningf("Retrying in ~%v", process.Config.RotationConnectionInterval) + process.log.Warningf("Sync rotation state cycle failed. Retrying in ~%v", process.Config.RotationConnectionInterval) select { case <-periodic.Next(): @@ -867,20 +852,13 @@ func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity } logger.Debug("Attempting to discover reverse tunnel address.") - proxyAddr, err := process.findReverseTunnel(authServers) - if err != nil { - directErrLogger.Debug("Failed to connect to Auth Server directly.") - logger.WithError(err).Debug("Failed to discover reverse tunnel address.") - return nil, trace.Errorf("Failed to connect to Auth Server directly or over tunnel, no methods remaining.") - } - logger = process.log.WithField("proxy-addr", proxyAddr) logger.Debug("Attempting to connect to Auth Server through tunnel.") sshClientConfig, err := identity.SSHClientConfig(process.Config.FIPS) if err != nil { return nil, trace.Wrap(err) } - tunnelClient, err := process.newClientThroughTunnel(proxyAddr, tlsConfig, sshClientConfig) + tunnelClient, err := process.newClientThroughTunnel(authServers, tlsConfig, sshClientConfig) if err != nil { directErrLogger.Debug("Failed to connect to Auth Server directly.") logger.WithError(err).Debug("Failed to connect to Auth Server through tunnel.") @@ -891,28 +869,24 @@ func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity return tunnelClient, nil } -// findReverseTunnel uses the web proxy to discover where the SSH reverse tunnel -// server is running. 
-func (process *TeleportProcess) findReverseTunnel(addrs []utils.NetAddr) (string, error) { - var errs []error - for _, addr := range addrs { - // In insecure mode, any certificate is accepted. In secure mode the hosts - // CAs are used to validate the certificate on the proxy. - tunnelAddr, err := webclient.GetTunnelAddr(process.ExitContext(), addr.String(), lib.IsInsecureDevMode(), nil) - if err == nil { - return tunnelAddr, nil - } - errs = append(errs, err) +func (process *TeleportProcess) newClientThroughTunnel(authServers []utils.NetAddr, tlsConfig *tls.Config, sshConfig *ssh.ClientConfig) (*auth.Client, error) { + resolver := reversetunnel.WebClientResolver(process.ExitContext(), authServers, lib.IsInsecureDevMode()) + + resolver, err := reversetunnel.CachingResolver(resolver, process.Clock) + if err != nil { + return nil, trace.Wrap(err) } - return "", trace.NewAggregate(errs...) -} -func (process *TeleportProcess) newClientThroughTunnel(proxyAddr string, tlsConfig *tls.Config, sshConfig *ssh.ClientConfig) (*auth.Client, error) { + dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{ + Resolver: resolver, + ClientConfig: sshConfig, + Log: process.log, + }) + if err != nil { + return nil, trace.Wrap(err) + } clt, err := auth.NewClient(apiclient.Config{ - Dialer: &reversetunnel.TunnelAuthDialer{ - ProxyAddr: proxyAddr, - ClientConfig: sshConfig, - }, + Dialer: dialer, Credentials: []apiclient.Credentials{ apiclient.LoadTLS(tlsConfig), }, diff --git a/lib/service/db.go b/lib/service/db.go index 6557604f5ae39..ac87c086a00c9 100644 --- a/lib/service/db.go +++ b/lib/service/db.go @@ -59,15 +59,9 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) { return trace.BadParameter("unsupported event payload type %q", event.Payload) } - var tunnelAddr string - if conn.TunnelProxy() != "" { - tunnelAddr = conn.TunnelProxy() - } else { - if tunnelAddr, ok = process.singleProcessMode(); !ok { - return trace.BadParameter("failed to find reverse tunnel address, " + - "if running in a single-process mode, make sure auth_service, " + - "proxy_service, and db_service are all enabled") - } + tunnelAddrResolver := conn.TunnelProxyResolver() + if tunnelAddrResolver == nil { + tunnelAddrResolver = process.singleProcessModeResolver() } accessPoint, err := process.newLocalCache(conn.Client, cache.ForDatabases, []string{teleport.ComponentDatabase}) @@ -191,11 +185,12 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) { }() // Create and start the agent pool. 
- agentPool, err := reversetunnel.NewAgentPool(process.ExitContext(), + agentPool, err := reversetunnel.NewAgentPool( + process.ExitContext(), reversetunnel.AgentPoolConfig{ Component: teleport.ComponentDatabase, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: tunnelAddr, + Resolver: tunnelAddrResolver, Client: conn.Client, Server: dbService, AccessPoint: conn.Client, diff --git a/lib/service/kubernetes.go b/lib/service/kubernetes.go index 30955bb11394c..09b3ed6d0b5fc 100644 --- a/lib/service/kubernetes.go +++ b/lib/service/kubernetes.go @@ -138,7 +138,7 @@ func (process *TeleportProcess) initKubernetesService(log *logrus.Entry, conn *C reversetunnel.AgentPoolConfig{ Component: teleport.ComponentKube, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: conn.TunnelProxy(), + Resolver: conn.TunnelProxyResolver(), Client: conn.Client, AccessPoint: accessPoint, HostSigner: conn.ServerIdentity.KeySigner, diff --git a/lib/service/service.go b/lib/service/service.go index 9487994c6618d..0d313788f4883 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -213,23 +213,25 @@ type Connector struct { Client *auth.Client } -// TunnelProxy if non-empty, indicates that the client is connected to the Auth Server +// TunnelProxyResolver if non-nil, indicates that the client is connected to the Auth Server // through the reverse SSH tunnel proxy -func (c *Connector) TunnelProxy() string { +func (c *Connector) TunnelProxyResolver() reversetunnel.Resolver { if c.Client == nil || c.Client.Dialer() == nil { - return "" + return nil } - tun, ok := c.Client.Dialer().(*reversetunnel.TunnelAuthDialer) - if !ok { - return "" + + switch dialer := c.Client.Dialer().(type) { + case *reversetunnel.TunnelAuthDialer: + return dialer.Resolver + default: + return nil } - return tun.ProxyAddr } // UseTunnel indicates if the client is connected directly to the Auth Server // (false) or through the proxy (true). func (c *Connector) UseTunnel() bool { - return c.TunnelProxy() != "" + return c.TunnelProxyResolver() != nil } // Close closes resources associated with connector @@ -1857,7 +1859,7 @@ func (process *TeleportProcess) initSSH() error { reversetunnel.AgentPoolConfig{ Component: teleport.ComponentNode, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: conn.TunnelProxy(), + Resolver: conn.TunnelProxyResolver(), Client: conn.Client, AccessPoint: conn.Client, HostSigner: conn.ServerIdentity.KeySigner, @@ -3071,16 +3073,11 @@ func (process *TeleportProcess) initApps() { // If this process connected through the web proxy, it will discover the // reverse tunnel address correctly and store it in the connector. // - // If it was not, it is running in single process mode which is used for // development and demos. In that case, wait until all dependencies (like // auth and reverse tunnel server) are ready before starting. - var tunnelAddr string - if conn.TunnelProxy() != "" { - tunnelAddr = conn.TunnelProxy() - } else { - if tunnelAddr, ok = process.singleProcessMode(); !ok { - return trace.BadParameter(`failed to find reverse tunnel address, if running in single process mode, make sure "auth_service", "proxy_service", and "app_service" are all enabled`) - } + tunnelAddrResolver := conn.TunnelProxyResolver() + if tunnelAddrResolver == nil { + tunnelAddrResolver = process.singleProcessModeResolver() // Block and wait for all dependencies to start before starting. 
log.Debugf("Waiting for application service dependencies to start.") @@ -3215,11 +3212,12 @@ func (process *TeleportProcess) initApps() { appServer.Start() // Create and start an agent pool. - agentPool, err = reversetunnel.NewAgentPool(process.ExitContext(), + agentPool, err = reversetunnel.NewAgentPool( + process.ExitContext(), reversetunnel.AgentPoolConfig{ Component: teleport.ComponentApp, HostUUID: conn.ServerIdentity.ID.HostUUID, - ProxyAddr: tunnelAddr, + Resolver: tunnelAddrResolver, Client: conn.Client, Server: appServer, AccessPoint: accessPoint, @@ -3480,19 +3478,35 @@ func (process *TeleportProcess) initDebugApp() { }) } +// singleProcessModeResolver returns the reversetunnel.Resolver that should be used when running all components needed +// within the same process. It's used for development and demo purposes. +func (process *TeleportProcess) singleProcessModeResolver() reversetunnel.Resolver { + return func() (*utils.NetAddr, error) { + addr, ok := process.singleProcessMode() + if !ok { + return nil, trace.BadParameter(`failed to find reverse tunnel address, if running in single process mode, make sure "auth_service", "proxy_service", and "app_service" are all enabled`) + } + return addr, nil + } +} + // singleProcessMode returns true when running all components needed within // the same process. It's used for development and demo purposes. -func (process *TeleportProcess) singleProcessMode() (string, bool) { +func (process *TeleportProcess) singleProcessMode() (*utils.NetAddr, bool) { if !process.Config.Proxy.Enabled || !process.Config.Auth.Enabled { - return "", false + return nil, false } if process.Config.Proxy.DisableReverseTunnel { - return "", false + return nil, false } if len(process.Config.Proxy.TunnelPublicAddrs) == 0 { - return net.JoinHostPort(string(teleport.PrincipalLocalhost), strconv.Itoa(defaults.SSHProxyTunnelListenPort)), true + addr, err := utils.ParseHostPortAddr(string(teleport.PrincipalLocalhost), defaults.SSHProxyTunnelListenPort) + if err != nil { + return nil, false + } + return addr, true } - return process.Config.Proxy.TunnelPublicAddrs[0].String(), true + return &process.Config.Proxy.TunnelPublicAddrs[0], true } // dumperHandler is an Application Access debugging application that will diff --git a/lib/utils/addr.go b/lib/utils/addr.go index 28a984b485d5b..71b39da321c51 100644 --- a/lib/utils/addr.go +++ b/lib/utils/addr.go @@ -342,3 +342,15 @@ func guessHostIP(addrs []net.Addr) (ip net.IP) { } return ip } + +// ReplaceUnspecifiedHost replaces unspecified "0.0.0.0" with localhost since "0.0.0.0" is never a valid +// principal (auth server explicitly removes it when issuing host certs) and when a reverse tunnel client used +// establishes SSH reverse tunnel connection the host is validated against +// the valid principal list. +func ReplaceUnspecifiedHost(addr *NetAddr, defaultPort int) string { + if !addr.IsHostUnspecified() { + return addr.String() + } + port := addr.Port(defaultPort) + return net.JoinHostPort("localhost", strconv.Itoa(port)) +} diff --git a/lib/cache/fncache.go b/lib/utils/fncache.go similarity index 75% rename from lib/cache/fncache.go rename to lib/utils/fncache.go index 3dbd3c2c592c0..c901de00eab05 100644 --- a/lib/cache/fncache.go +++ b/lib/utils/fncache.go @@ -14,20 +14,23 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package cache +package utils import ( "context" "sync" "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" ) -// fnCache is a helper for temporarily storing the results of regularly called functions. This helper is +// FnCache is a helper for temporarily storing the results of regularly called functions. This helper is // used to limit the amount of backend reads that occur while the primary cache is unhealthy. Most resources // do not require this treatment, but certain resources (cas, nodes, etc) can be loaded on a per-request // basis and can cause significant numbers of backend reads if the cache is unhealthy or taking a while to init. -type fnCache struct { - ttl time.Duration +type FnCache struct { + cfg FnCacheConfig mu sync.Mutex nextCleanup time.Time entries map[interface{}]*fnCacheEntry @@ -39,11 +42,32 @@ type fnCache struct { // removed upon subsequent reads of the same key. const cleanupMultiplier time.Duration = 16 -func newFnCache(ttl time.Duration) *fnCache { - return &fnCache{ - ttl: ttl, - entries: make(map[interface{}]*fnCacheEntry), +type FnCacheConfig struct { + TTL time.Duration + Clock clockwork.Clock +} + +func (c *FnCacheConfig) CheckAndSetDefaults() error { + if c.TTL <= 0 { + return trace.BadParameter("missing TTL parameter") + } + + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() } + + return nil +} + +func NewFnCache(cfg FnCacheConfig) (*FnCache, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &FnCache{ + cfg: cfg, + entries: make(map[interface{}]*fnCacheEntry), + }, nil } type fnCacheEntry struct { @@ -53,11 +77,11 @@ type fnCacheEntry struct { loaded chan struct{} } -func (c *fnCache) removeExpiredLocked(now time.Time) { +func (c *FnCache) removeExpiredLocked(now time.Time) { for key, entry := range c.entries { select { case <-entry.loaded: - if now.After(entry.t.Add(c.ttl)) { + if now.After(entry.t.Add(c.cfg.TTL)) { delete(c.entries, key) } default: @@ -71,15 +95,15 @@ func (c *fnCache) removeExpiredLocked(now time.Time) { // block until the first call updates the entry. Note that the supplied context can cancel the call to Get, but will // not cancel loading. The supplied loadfn should not be canceled just because the specific request happens to have // been canceled. 
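With the move into utils, FnCache construction is now validated and clock-injectable. The sketch below shows the two consequences of the FnCacheConfig.CheckAndSetDefaults logic above: a zero TTL is rejected, and supplying a clockwork fake clock makes expiry deterministic in tests, while omitting Clock falls back to the real clock. Nothing here goes beyond the API introduced in this hunk.

// A minimal sketch of constructing the relocated FnCache.
package main

import (
	"fmt"
	"time"

	"github.com/gravitational/teleport/lib/utils"
	"github.com/jonboulle/clockwork"
)

func main() {
	// A zero TTL is rejected by CheckAndSetDefaults.
	if _, err := utils.NewFnCache(utils.FnCacheConfig{}); err != nil {
		fmt.Println("rejected:", err)
	}

	// Injecting a fake clock makes TTL expiry deterministic in tests.
	clock := clockwork.NewFakeClock()
	cache, err := utils.NewFnCache(utils.FnCacheConfig{TTL: time.Second, Clock: clock})
	if err != nil {
		panic(err)
	}
	fmt.Printf("cache ready: %T (entries expire once the clock advances past the TTL)\n", cache)
}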
-func (c *fnCache) Get(ctx context.Context, key interface{}, loadfn func() (interface{}, error)) (interface{}, error) { +func (c *FnCache) Get(ctx context.Context, key interface{}, loadfn func() (interface{}, error)) (interface{}, error) { c.mu.Lock() - now := time.Now() + now := c.cfg.Clock.Now() // check if we need to perform periodic cleanup if now.After(c.nextCleanup) { c.removeExpiredLocked(now) - c.nextCleanup = now.Add(c.ttl * cleanupMultiplier) + c.nextCleanup = now.Add(c.cfg.TTL * cleanupMultiplier) } entry := c.entries[key] @@ -89,7 +113,7 @@ func (c *fnCache) Get(ctx context.Context, key interface{}, loadfn func() (inter if entry != nil { select { case <-entry.loaded: - needsReload = now.After(entry.t.Add(c.ttl)) + needsReload = now.After(entry.t.Add(c.cfg.TTL)) default: // reload is already in progress needsReload = false @@ -105,7 +129,7 @@ func (c *fnCache) Get(ctx context.Context, key interface{}, loadfn func() (inter c.entries[key] = entry go func() { entry.v, entry.e = loadfn() - entry.t = time.Now() + entry.t = c.cfg.Clock.Now() close(entry.loaded) }() } diff --git a/lib/cache/fncache_test.go b/lib/utils/fncache_test.go similarity index 85% rename from lib/cache/fncache_test.go rename to lib/utils/fncache_test.go index b5574b4d6ab26..e54d039e7d27f 100644 --- a/lib/cache/fncache_test.go +++ b/lib/utils/fncache_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package cache +package utils import ( "context" @@ -27,7 +27,34 @@ import ( "go.uber.org/atomic" ) -// TestFnCacheSanity runs basic fnCache test cases. +func TestFnCache_New(t *testing.T) { + cases := []struct { + desc string + config FnCacheConfig + assertion require.ErrorAssertionFunc + }{ + { + desc: "invalid ttl", + config: FnCacheConfig{TTL: 0}, + assertion: require.Error, + }, + + { + desc: "valid ttl", + config: FnCacheConfig{TTL: time.Second}, + assertion: require.NoError, + }, + } + + for _, tt := range cases { + t.Run(tt.desc, func(t *testing.T) { + _, err := NewFnCache(tt.config) + tt.assertion(t, err) + }) + } +} + +// TestFnCacheSanity runs basic FnCache test cases. func TestFnCacheSanity(t *testing.T) { tts := []struct { ttl time.Duration @@ -55,7 +82,8 @@ func testFnCacheSimple(t *testing.T, ttl time.Duration, delay time.Duration, msg ctx, cancel := context.WithCancel(context.Background()) defer cancel() - cache := newFnCache(ttl) + cache, err := NewFnCache(FnCacheConfig{TTL: ttl}) + require.NoError(t, err) // readCounter is incremented upon each cache miss. 
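Get keeps its original contract after the rename: the first caller for a key starts loadfn, concurrent callers block on the same in-flight entry, and the entry is only reloaded once the now clock-driven TTL has passed. Here is a minimal sketch of that deduplication; the "ca/example" key and the loadExpensive stand-in are illustrative.

// A minimal sketch of FnCache.Get deduplicating concurrent loads.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/gravitational/teleport/lib/utils"
	"go.uber.org/atomic"
)

func main() {
	cache, err := utils.NewFnCache(utils.FnCacheConfig{TTL: time.Second})
	if err != nil {
		panic(err)
	}

	loads := atomic.NewInt64(0)
	loadExpensive := func() (interface{}, error) {
		loads.Inc() // stands in for a backend read
		return "value", nil
	}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if _, err := cache.Get(context.Background(), "ca/example", loadExpensive); err != nil {
				panic(err)
			}
		}()
	}
	wg.Wait()

	// All eight readers share a single backend read while the entry is fresh.
	fmt.Println("loads:", loads.Load())
}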
readCounter := atomic.NewInt64(0) @@ -123,7 +151,8 @@ func testFnCacheSimple(t *testing.T, ttl time.Duration, delay time.Duration, msg func TestFnCacheCancellation(t *testing.T) { const timeout = time.Millisecond * 10 - cache := newFnCache(time.Minute) + cache, err := NewFnCache(FnCacheConfig{TTL: time.Minute}) + require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() diff --git a/lib/utils/timed_counter.go b/lib/utils/timed_counter.go index 2798b7cd1f62f..c83ff15339a92 100644 --- a/lib/utils/timed_counter.go +++ b/lib/utils/timed_counter.go @@ -32,7 +32,7 @@ type TimedCounter struct { events []time.Time } -// TimedCounted creates a new timed counter with the specified timeout +// NewTimedCounter creates a new timed counter with the specified timeout func NewTimedCounter(clock clockwork.Clock, timeout time.Duration) *TimedCounter { return &TimedCounter{ clock: clock, diff --git a/lib/utils/workpool/workpool.go b/lib/utils/workpool/workpool.go index 8236ac91e1726..d2737783734fb 100644 --- a/lib/utils/workpool/workpool.go +++ b/lib/utils/workpool/workpool.go @@ -23,14 +23,14 @@ import ( "go.uber.org/atomic" ) -// Pool manages a collection of work groups by key and is the primary means -// by which groups are managed. Each work group has an adjustable target value +// Pool manages a collection of work group by key and is the primary means +// by which group are managed. Each work group has an adjustable target value // which is the number of target leases which should be active for the given // group. type Pool struct { mu sync.Mutex leaseIDs *atomic.Uint64 - groups map[interface{}]*group + group *group // grantC is an unbuffered channel that funnels available leases from the // workgroups to the outside world grantC chan Lease @@ -42,7 +42,6 @@ func NewPool(ctx context.Context) *Pool { ctx, cancel := context.WithCancel(ctx) return &Pool{ leaseIDs: atomic.NewUint64(0), - groups: make(map[interface{}]*group), grantC: make(chan Lease), ctx: ctx, cancel: cancel, @@ -53,7 +52,7 @@ func NewPool(ctx context.Context) *Pool { // new leases. Each lease acquired in this way *must* have its // Release method called when the lease is no longer needed. // Note this channel will deliver leases from all active work -// groups. It's up to the receiver to differentiate what group +// group. It's up to the receiver to differentiate what group // the lease refers to and act accordingly. func (p *Pool) Acquire() <-chan Lease { return p.grantC @@ -65,34 +64,37 @@ func (p *Pool) Done() <-chan struct{} { } // Get gets the current counts for the specified key. -func (p *Pool) Get(key interface{}) Counts { +func (p *Pool) Get() Counts { p.mu.Lock() defer p.mu.Unlock() - if g, ok := p.groups[key]; ok { - return g.loadCounts() + + if p.group == nil { + return Counts{} } - return Counts{} + + return p.group.loadCounts() } // Set sets the target for the specified key. -func (p *Pool) Set(key interface{}, target uint64) { +func (p *Pool) Set(target uint64) { p.mu.Lock() defer p.mu.Unlock() if target < 1 { - p.del(key) + p.del() return } - g, ok := p.groups[key] - if !ok { - p.start(key, target) + + if p.group == nil { + p.start(target) return } - g.setTarget(target) + + p.group.setTarget(target) } // Start starts a new work group with the specified initial target. // If Start returns false, the group already exists. 
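The workpool change above collapses the keyed map of groups into a single group, so Set and Get lose their key parameter and Lease no longer carries one. A minimal sketch of the simplified API, using only the methods visible in this diff (NewPool, Set, Acquire, Release, ID, Get, Stop); the worker body is illustrative.

// A minimal sketch of the single-group Pool API after this change.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/gravitational/teleport/lib/utils/workpool"
)

func main() {
	pool := workpool.NewPool(context.Background())
	defer pool.Stop()

	// Size the one group at three concurrent leases.
	pool.Set(3)

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			lease := <-pool.Acquire()
			defer lease.Release()
			fmt.Println("working under lease", lease.ID())
			time.Sleep(10 * time.Millisecond) // stands in for real work
		}()
	}
	wg.Wait()

	fmt.Printf("counts after churn: %+v\n", pool.Get())
}

Setting the target back to zero tears the group down and stops further grants, which is exactly what the rewritten TestZeroed below exercises.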
-func (p *Pool) start(key interface{}, target uint64) { +func (p *Pool) start(target uint64) { ctx, cancel := context.WithCancel(p.ctx) notifyC := make(chan struct{}, 1) g := &group{ @@ -101,13 +103,12 @@ func (p *Pool) start(key interface{}, target uint64) { Target: target, }, leaseIDs: p.leaseIDs, - key: key, grantC: p.grantC, notifyC: notifyC, ctx: ctx, cancel: cancel, } - p.groups[key] = g + p.group = g // Start a routine to monitor the group's lease acquisition // and handle notifications when a lease is returned to the @@ -115,17 +116,17 @@ func (p *Pool) start(key interface{}, target uint64) { go g.run() } -func (p *Pool) del(key interface{}) (ok bool) { - group, ok := p.groups[key] - if !ok { +func (p *Pool) del() (ok bool) { + if p.group == nil { return false } - group.cancel() - delete(p.groups, key) + + p.group.cancel() + p.group = nil return true } -// Stop permanently halts all associated groups. +// Stop permanently halts all associated group. func (p *Pool) Stop() { p.cancel() } @@ -146,7 +147,6 @@ type group struct { cmu sync.Mutex counts Counts leaseIDs *atomic.Uint64 - key interface{} grantC chan Lease notifyC chan struct{} ctx context.Context @@ -231,7 +231,7 @@ func (g *group) run() { // is called (usually by a lease being returned), or the context is // canceled. // - // Otherwise we post the lease to the outside world and block on + // Otherwise, we post the lease to the outside world and block on // someone picking it up (or cancellation) select { case grant <- nextLease: @@ -267,11 +267,6 @@ func (l Lease) ID() uint64 { return l.id } -// Key returns the key that this lease is associated with. -func (l Lease) Key() interface{} { - return l.key -} - // IsZero checks if this is the zero value of Lease. func (l Lease) IsZero() bool { return l == Lease{} diff --git a/lib/utils/workpool/workpool_test.go b/lib/utils/workpool/workpool_test.go index dac46a41fc914..d50510350cef3 100644 --- a/lib/utils/workpool/workpool_test.go +++ b/lib/utils/workpool/workpool_test.go @@ -18,65 +18,21 @@ package workpool import ( "context" - "fmt" "sync" "testing" "time" - - "gopkg.in/check.v1" ) -func Example() { - pool := NewPool(context.TODO()) - defer pool.Stop() - // create two keys with different target counts - pool.Set("spam", 2) - pool.Set("eggs", 1) - // track how many workers are spawned for each key - counts := make(map[string]int) - var mu sync.Mutex - var wg sync.WaitGroup - for i := 0; i < 12; i++ { - wg.Add(1) - go func() { - lease := <-pool.Acquire() - defer lease.Release() - mu.Lock() - counts[lease.Key().(string)]++ - mu.Unlock() - // in order to demonstrate the differing spawn rates we need - // work to take some time, otherwise pool will end up granting - // leases in a "round robin" fashion. - time.Sleep(time.Millisecond * 10) - wg.Done() - }() - } - wg.Wait() - // exact counts will vary, but leases with key `spam` - // will end up being generated approximately twice as - // often as leases with key `eggs`. - fmt.Println(counts["spam"] > counts["eggs"]) // Output: true -} - -func Test(t *testing.T) { - check.TestingT(t) -} - -type WorkSuite struct{} - -var _ = check.Suite(&WorkSuite{}) - // TestFull runs a pool though a round of normal usage, // and verifies expected state along the way: -// - A group of workers acquire leases, do some work, and release them. -// - A second group of workers receieve leases as the first group finishes. -// - The expected amout of leases are in play after this churn. +// - A group of workers acquires leases, do some work, and release them. 
+// - A second group of workers receives leases as the first group finishes. +// - The expected amount of leases is in play after this churn. // - Updating the target lease count has the expected effect. -func (s *WorkSuite) TestFull(c *check.C) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() +func TestFull(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) p := NewPool(ctx) - key := "some-key" var wg sync.WaitGroup // signal channel to cause the first group of workers to // release their leases. @@ -88,7 +44,7 @@ func (s *WorkSuite) TestFull(c *check.C) { time.Sleep(time.Millisecond * 500) close(g1timeout) }() - p.Set(key, 200) + p.Set(200) // spawn first group of workers. for i := 0; i < 200; i++ { wg.Add(1) @@ -98,7 +54,7 @@ func (s *WorkSuite) TestFull(c *check.C) { <-g1done l.Release() case <-g1timeout: - c.Errorf("Timeout waiting for lease") + t.Errorf("Timeout waiting for lease") } wg.Done() }() @@ -107,7 +63,7 @@ func (s *WorkSuite) TestFull(c *check.C) { // no additional leases should exist select { case l := <-p.Acquire(): - c.Errorf("unexpected lease: %+v", l) + t.Errorf("unexpected lease: %+v", l) default: } // spawn a second group of workers that won't be able to @@ -119,7 +75,7 @@ func (s *WorkSuite) TestFull(c *check.C) { case <-p.Acquire(): // leak deliberately case <-time.After(time.Millisecond * 512): - c.Errorf("Timeout waiting for lease") + t.Errorf("Timeout waiting for lease") } wg.Done() }() @@ -132,46 +88,43 @@ func (s *WorkSuite) TestFull(c *check.C) { select { case l := <-p.Acquire(): counts := l.loadCounts() - c.Errorf("unexpected lease grant: %+v, counts=%+v", l, counts) + t.Errorf("unexpected lease grant: %+v, counts=%+v", l, counts) case <-time.After(time.Millisecond * 128): } // make one additional lease available - p.Set(key, 201) + p.Set(201) select { case l := <-p.Acquire(): - c.Assert(l.Key().(string), check.Equals, key) l.Release() case <-time.After(time.Millisecond * 128): - c.Errorf("timeout waiting for lease grant") + t.Errorf("timeout waiting for lease grant") } } -// TestZeroed varifies that a zeroed pool stops granting +// TestZeroed verifies that a zeroed pool stops granting // leases as expected. -func (s *WorkSuite) TestZeroed(c *check.C) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() +func TestZeroed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) p := NewPool(ctx) - key := "some-key" - p.Set(key, 1) + p.Set(1) var l Lease select { case l = <-p.Acquire(): - c.Assert(l.Key().(string), check.Equals, key) l.Release() case <-time.After(time.Millisecond * 128): - c.Errorf("timeout waiting for lease grant") + t.Errorf("timeout waiting for lease grant") } - p.Set(key, 0) + p.Set(0) // modifications to counts are *ordered*, but asynchronous, - // so we could actually receieve a lease here if we don't sleep + // so we could actually receive a lease here if we don't sleep // briefly. if we opted for condvars instead of channels, this // issue could be avoided at the cost of more cumbersome // composition/cancellation. 
time.Sleep(time.Millisecond * 10) select { case l := <-p.Acquire(): - c.Errorf("unexpected lease grant: %+v", l) + t.Errorf("unexpected lease grant: %+v", l) case <-time.After(time.Millisecond * 128): } } diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index 1da70c878d30d..edc6bf97b9652 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -25,7 +25,6 @@ import ( "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" @@ -222,20 +221,18 @@ func connectToAuthService(ctx context.Context, cfg *service.Config, clientConfig errs := []error{err} - // Figure out the reverse tunnel address on the proxy first. - tunAddr, err := findReverseTunnel(ctx, cfg.AuthServers, clientConfig.TLS.InsecureSkipVerify) - if err != nil { - errs = append(errs, trace.Wrap(err, "failed lookup of proxy reverse tunnel address: %v", err)) - return nil, trace.NewAggregate(errs...) - } - log.Debugf("Attempting to connect using reverse tunnel address %v.", tunAddr) // reversetunnel.TunnelAuthDialer will take care of creating a net.Conn // within an SSH tunnel. + dialer, err := reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{ + Resolver: reversetunnel.WebClientResolver(ctx, cfg.AuthServers, clientConfig.TLS.InsecureSkipVerify), + ClientConfig: clientConfig.SSH, + Log: cfg.Log, + }) + if err != nil { + return nil, trace.Wrap(err) + } client, err = auth.NewClient(apiclient.Config{ - Dialer: &reversetunnel.TunnelAuthDialer{ - ProxyAddr: tunAddr, - ClientConfig: clientConfig.SSH, - }, + Dialer: dialer, Credentials: []apiclient.Credentials{ apiclient.LoadTLS(clientConfig.TLS), }, @@ -253,22 +250,6 @@ func connectToAuthService(ctx context.Context, cfg *service.Config, clientConfig return client, nil } -// findReverseTunnel uses the web proxy to discover where the SSH reverse tunnel -// server is running. -func findReverseTunnel(ctx context.Context, addrs []utils.NetAddr, insecureTLS bool) (string, error) { - var errs []error - for _, addr := range addrs { - // In insecure mode, any certificate is accepted. In secure mode the hosts - // CAs are used to validate the certificate on the proxy. - tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil) - if err == nil { - return tunnelAddr, nil - } - errs = append(errs, err) - } - return "", trace.NewAggregate(errs...) -} - // applyConfig takes configuration values from the config file and applies // them to 'service.Config' object. //
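In tctl the up-front findReverseTunnel lookup is gone; the TunnelAuthDialer now owns a WebClientResolver and resolves the reverse tunnel address when it actually dials, so proxy address changes no longer require a fresh discovery pass. Below is a compile-only sketch of that wiring. It assumes NewTunnelAuthDialer returns (*reversetunnel.TunnelAuthDialer, error); buildTunnelDialer is a hypothetical helper, and the config fields simply mirror the connectToAuthService hunk above.

// A minimal sketch of building the resolver-backed auth dialer used by tctl.
package tctlsketch

import (
	"context"

	"github.com/gravitational/teleport/lib/reversetunnel"
	"github.com/gravitational/teleport/lib/service"
	"golang.org/x/crypto/ssh"
)

// buildTunnelDialer is a hypothetical helper that mirrors the wiring in
// connectToAuthService: the resolver queries each configured proxy's web
// endpoint for the SSH reverse tunnel address, and resolution errors surface
// when the dialer dials rather than at construction time.
func buildTunnelDialer(ctx context.Context, cfg *service.Config, sshConfig *ssh.ClientConfig, insecureTLS bool) (*reversetunnel.TunnelAuthDialer, error) {
	return reversetunnel.NewTunnelAuthDialer(reversetunnel.TunnelAuthDialerConfig{
		Resolver:     reversetunnel.WebClientResolver(ctx, cfg.AuthServers, insecureTLS),
		ClientConfig: sshConfig,
		Log:          cfg.Log,
	})
}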