From b01894e3e659505f277fa556e6b4a51e61a12f63 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Wed, 11 May 2022 14:18:47 -0400 Subject: [PATCH] Stop loading the enitre node set into memory per tsh ssh connection (#12014) (#12571) * Prevent proxy from loading entire node set into memory more than once When establishing a new session to a node, the proxy would load the entire node set into memory in an attempt to find the matching host. For smaller clusters this may not be that problematic. But on larger clusters, loading >40k nodes into memory from the cache can be quite expensive. This problem is compounded by the fact that it happened**per** session, which could potentially cause the proxy to consume all available memory and be OOM killed. A new `NodeWatcher` is introduced which will maintain an in memory list of all nodes per process. The watcher leverages the existing resource watcher system and stores all nodes as types.Server, to eliminate the cost incurred by unmarshalling the nodes from the cache. The `NodeWatcher` provides a way to retrieve a filtered list of nodes in order to reduce the number of copies made to only the matches. (cherry picked from commit fa12352214ea382633277e92d3217853e715e2ac) --- lib/reversetunnel/api.go | 3 + lib/reversetunnel/localsite.go | 30 +- lib/reversetunnel/peer.go | 12 + lib/reversetunnel/remotesite.go | 13 +- lib/reversetunnel/srv.go | 34 +- lib/service/service.go | 13 + lib/services/presence.go | 13 +- lib/services/watcher.go | 156 ++++ lib/services/watcher_test.go | 78 +- lib/srv/regular/proxy.go | 55 +- lib/srv/regular/proxy_test.go | 21 +- lib/srv/regular/sshserver.go | 11 + lib/srv/regular/sshserver_test.go | 18 + lib/web/apiserver.go | 13 +- lib/web/apiserver_test.go | 1149 +++++++++++++++-------------- lib/web/files.go | 15 +- lib/web/terminal.go | 65 +- lib/web/ui/cluster.go | 5 +- tool/tsh/proxy_test.go | 20 +- tool/tsh/tsh.go | 26 +- 20 files changed, 1078 insertions(+), 672 deletions(-) diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 8cbe8b31c6b2c..85b50fac56b7e 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" ) @@ -94,6 +95,8 @@ type RemoteSite interface { // CachingAccessPoint returns access point that is lightweight // but is resilient to auth server crashes CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) + // NodeWatcher returns the node watcher that maintains the node set for the site + NodeWatcher() (*services.NodeWatcher, error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index f96e8a00e9c5a..d8c73f616990c 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -22,8 +22,6 @@ import ( "sync" "time" - "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -34,11 +32,12 @@ import ( "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/proxy" - "github.com/prometheus/client_golang/prometheus" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" ) func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) { @@ -129,6 +128,11 @@ func (s *localSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) { return s.accessPoint, nil } +// NodeWatcher returns a services.NodeWatcher for this cluster. +func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) { + return s.srv.NodeWatcher, nil +} + // GetClient returns a client to the full Auth Server API. func (s *localSite) GetClient() (auth.ClientI, error) { return s.client, nil @@ -519,14 +523,7 @@ func (s *localSite) periodicFunctions() { // sshTunnelStats reports SSH tunnel statistics for the cluster. func (s *localSite) sshTunnelStats() error { - servers, err := s.accessPoint.GetNodes(s.srv.ctx, apidefaults.Namespace) - if err != nil { - return trace.Wrap(err) - } - - var missing []string - - for _, server := range servers { + missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool { // Skip over any servers that that have a TTL larger than announce TTL (10 // minutes) and are non-IoT SSH servers (they won't have tunnels). // @@ -535,10 +532,10 @@ func (s *localSite) sshTunnelStats() error { // their TTL value. ttl := s.clock.Now().Add(-1 * apidefaults.ServerAnnounceTTL) if server.Expiry().Before(ttl) { - continue + return false } if !server.GetUseTunnel() { - continue + return false } // Check if the tunnel actually exists. @@ -546,12 +543,9 @@ func (s *localSite) sshTunnelStats() error { ServerID: fmt.Sprintf("%v.%v", server.GetName(), s.domainName), ConnType: types.NodeTunnel, }) - if err == nil { - continue - } - missing = append(missing, server.GetName()) - } + return err != nil + }) // Update Prometheus metrics and also log if any tunnels are missing. missingSSHTunnels.Set(float64(len(missing))) diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 9fab4c78201f9..1f65b2404e5a0 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -87,6 +87,14 @@ func (p *clusterPeers) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) return peer.CachingAccessPoint() } +func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) { + peer, err := p.pickPeer() + if err != nil { + return nil, trace.Wrap(err) + } + return peer.NodeWatcher() +} + func (p *clusterPeers) GetClient() (auth.ClientI, error) { peer, err := p.pickPeer() if err != nil { @@ -191,6 +199,10 @@ func (s *clusterPeer) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) } +func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) { + return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) +} + func (s *clusterPeer) GetClient() (auth.ClientI, error) { return nil, trace.ConnectionProblem(nil, "unable to fetch client, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 5419f4b3cbbca..eca009ced1230 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -34,13 +34,14 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) -// remoteSite is a remote site that established the inbound connecton to +// remoteSite is a remote site that established the inbound connection to // the local reverse tunnel server, and now it can provide access to the // cluster behind it. type remoteSite struct { @@ -77,6 +78,9 @@ type remoteSite struct { // the remote cluster this site belongs to. remoteAccessPoint auth.RemoteProxyAccessPoint + // nodeWatcher provides access the node set for the remote site + nodeWatcher *services.NodeWatcher + // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation // state has been changed, the tunnel will reconnect to re-create the client @@ -138,6 +142,11 @@ func (s *remoteSite) CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) { return s.remoteAccessPoint, nil } +// NodeWatcher returns the services.NodeWatcher for the remote cluster. +func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) { + return s.nodeWatcher, nil +} + func (s *remoteSite) GetClient() (auth.ClientI, error) { return s.remoteClient, nil } @@ -379,7 +388,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch } else { s.WithFields(log.Fields{"nodeID": conn.nodeID}).Debugf("Ping <- %v", conn.conn.RemoteAddr()) } - tm := time.Now().UTC() + tm := s.clock.Now().UTC() conn.setLastHeartbeat(tm) go s.registerHeartbeat(tm) // Note that time.After is re-created everytime a request is processed. diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index d101e40e96403..829d23d9b228f 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" @@ -201,6 +202,9 @@ type Config struct { // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher + + // NodeWatcher is a node watcher. + NodeWatcher *services.NodeWatcher } // CheckAndSetDefaults checks parameters and sets default values @@ -252,6 +256,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.LockWatcher == nil { return trace.BadParameter("missing parameter LockWatcher") } + if cfg.NodeWatcher == nil { + return trace.BadParameter("missing parameter NodeWatcher") + } return nil } @@ -890,7 +897,7 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r // treat first connection as a registered heartbeat, // otherwise the connection information will appear after initial // heartbeat delay - go site.registerHeartbeat(time.Now()) + go site.registerHeartbeat(s.Clock.Now()) return site, remoteConn, nil } @@ -1023,7 +1030,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, types.TunnelConnectionSpecV2{ ClusterName: domainName, ProxyName: srv.ID, - LastHeartbeat: time.Now().UTC(), + LastHeartbeat: srv.Clock.Now().UTC(), }, ) if err != nil { @@ -1055,27 +1062,42 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, clt, _, err := remoteSite.getRemoteClient() if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.remoteClient = clt remoteVersion, err := getRemoteAuthVersion(closeContext, sconn) if err != nil { + cancel() return nil, trace.Wrap(err) } accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName) if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.remoteAccessPoint = accessPoint - + nodeWatcher, err := services.NewNodeWatcher(closeContext, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: srv.Component, + Client: accessPoint, + Log: srv.Log, + }, + }) + if err != nil { + cancel() + return nil, trace.Wrap(err) + } + remoteSite.nodeWatcher = nodeWatcher // instantiate a cache of host certificates for the forwarding server. the // certificate cache is created in each site (instead of creating it in // reversetunnel.server and passing it along) so that the host certificate // is signed by the correct certificate authority. certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient) if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.certificateCache = certificateCache @@ -1088,7 +1110,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - return nil, err + cancel() + return nil, trace.Wrap(err) } go remoteSite.updateCertAuthorities(caRetry) @@ -1101,7 +1124,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - return nil, err + cancel() + return nil, trace.Wrap(err) } go remoteSite.updateLocks(lockRetry) diff --git a/lib/service/service.go b/lib/service/service.go index 2d87714ded8fe..8ee3698444693 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2809,6 +2809,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + nodeWatcher, err := services.NewNodeWatcher(process.ExitContext(), services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: process.log.WithField(trace.Component, teleport.ComponentProxy), + Client: conn.Client, + }, + }) + if err != nil { + return trace.Wrap(err) + } + serverTLSConfig, err := conn.ServerIdentity.TLSConfig(cfg.CipherSuites) if err != nil { return trace.Wrap(err) @@ -2848,6 +2859,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Emitter: streamEmitter, Log: process.log, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) if err != nil { return trace.Wrap(err) @@ -2976,6 +2988,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { regular.SetOnHeartbeat(process.onHeartbeat(teleport.ComponentProxy)), regular.SetEmitter(streamEmitter), regular.SetLockWatcher(lockWatcher), + regular.SetNodeWatcher(nodeWatcher), ) if err != nil { return trace.Wrap(err) diff --git a/lib/services/presence.go b/lib/services/presence.go index b218e87058d68..2889cb7768ac3 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -29,6 +29,12 @@ type ProxyGetter interface { GetProxies() ([]types.Server, error) } +// NodesGetter is a service that gets nodes. +type NodesGetter interface { + // GetNodes returns a list of registered servers. + GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error) +} + // Presence records and reports the presence of all components // of the cluster - Nodes, Proxies and SSH nodes type Presence interface { @@ -43,13 +49,12 @@ type Presence interface { // GetNode returns a node by name and namespace. GetNode(ctx context.Context, namespace, name string) (types.Server, error) - - // GetNodes returns a list of registered servers. - GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error) - // ListNodes returns a paginated list of registered servers. ListNodes(ctx context.Context, req proto.ListNodesRequest) (nodes []types.Server, nextKey string, err error) + // NodesGetter gets nodes + NodesGetter + // DeleteAllNodes deletes all nodes in a namespace. DeleteAllNodes(ctx context.Context, namespace string) error diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 16cb6dbd88801..d472f76da09c8 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -22,6 +22,7 @@ import ( "time" "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" @@ -1087,3 +1088,158 @@ func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertA } return slice } + +// NodeWatcherConfig is a NodeWatcher configuration. +type NodeWatcherConfig struct { + ResourceWatcherConfig + // NodesGetter is used to directly fetch the list of active nodes. + NodesGetter +} + +// CheckAndSetDefaults checks parameters and sets default values. +func (cfg *NodeWatcherConfig) CheckAndSetDefaults() error { + if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if cfg.NodesGetter == nil { + getter, ok := cfg.Client.(NodesGetter) + if !ok { + return trace.BadParameter("missing parameter NodesGetter and Client not usable as NodesGetter") + } + cfg.NodesGetter = getter + } + return nil +} + +// NewNodeWatcher returns a new instance of NodeWatcher. +func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*NodeWatcher, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + collector := &nodeCollector{ + NodeWatcherConfig: cfg, + current: map[string]types.Server{}, + } + watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) + if err != nil { + return nil, trace.Wrap(err) + } + return &NodeWatcher{watcher, collector}, nil +} + +// NodeWatcher is built on top of resourceWatcher to monitor additions +// and deletions to the set of nodes. +type NodeWatcher struct { + *resourceWatcher + *nodeCollector +} + +// nodeCollector accompanies resourceWatcher when monitoring nodes. +type nodeCollector struct { + NodeWatcherConfig + // current holds a map of the currently known nodes (keyed by server name, + // RWMutex protected). + current map[string]types.Server + rw sync.RWMutex +} + +// Node is a readonly subset of the types.Server interface which +// users may filter by in GetNodes. +type Node interface { + // Resource provides common resource headers + types.Resource + // GetTeleportVersion returns the teleport version the server is running on + GetTeleportVersion() string + // GetAddr return server address + GetAddr() string + // GetHostname returns server hostname + GetHostname() string + // GetNamespace returns server namespace + GetNamespace() string + // GetLabels returns server's static label key pairs + GetLabels() map[string]string + // GetCmdLabels gets command labels + GetCmdLabels() map[string]types.CommandLabel + // GetPublicAddr is an optional field that returns the public address this cluster can be reached at. + GetPublicAddr() string + // GetRotation gets the state of certificate authority rotation. + GetRotation() types.Rotation + // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. + GetUseTunnel() bool +} + +// GetNodes allows callers to retrieve a subset of nodes that match the filter provided. The +// returned servers are a copy and can be safely modified. It is intentionally hard to retrieve +// the full set of nodes to reduce the number of copies needed since the number of nodes can get +// quite large and doing so can be expensive. +func (n *nodeCollector) GetNodes(fn func(n Node) bool) []types.Server { + n.rw.RLock() + defer n.rw.RUnlock() + + var matched []types.Server + for _, server := range n.current { + if fn(server) { + matched = append(matched, server.DeepCopy()) + } + } + + return matched +} + +func (n *nodeCollector) NodeCount() int { + n.rw.RLock() + defer n.rw.RUnlock() + return len(n.current) +} + +// resourceKind specifies the resource kind to watch. +func (n *nodeCollector) resourceKind() string { + return types.KindNode +} + +// getResourcesAndUpdateCurrent is called when the resources should be +// (re-)fetched directly. +func (n *nodeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { + nodes, err := n.NodesGetter.GetNodes(ctx, apidefaults.Namespace) + if err != nil { + return trace.Wrap(err) + } + if len(nodes) == 0 { + return nil + } + newCurrent := make(map[string]types.Server, len(nodes)) + for _, node := range nodes { + newCurrent[node.GetName()] = node + } + n.rw.Lock() + defer n.rw.Unlock() + n.current = newCurrent + return nil +} + +// processEventAndUpdateCurrent is called when a watcher event is received. +func (n *nodeCollector) processEventAndUpdateCurrent(ctx context.Context, event types.Event) { + if event.Resource == nil || event.Resource.GetKind() != types.KindNode { + n.Log.Warningf("Unexpected event: %v.", event) + return + } + + n.rw.Lock() + defer n.rw.Unlock() + + switch event.Type { + case types.OpDelete: + delete(n.current, event.Resource.GetName()) + case types.OpPut: + server, ok := event.Resource.(types.Server) + if !ok { + n.Log.Warningf("Unexpected type %T.", event.Resource) + return + } + n.current[server.GetName()] = server + default: + n.Log.Warningf("Skipping unsupported event type %s.", event.Type) + } +} + +func (n *nodeCollector) notifyStale() {} diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 4d7b3cfbcacc9..6521e65aa30d9 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -20,13 +20,19 @@ import ( "context" "crypto/x509/pkix" "errors" + "fmt" "sync" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" @@ -34,9 +40,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" ) var _ types.Events = (*errorWatcher)(nil) @@ -853,3 +856,72 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type require.NoError(t, err) return ca } + +func TestNodeWatcher(t *testing.T) { + t.Parallel() + ctx := context.Background() + + bk, err := lite.NewWithConfig(ctx, lite.Config{ + Path: t.TempDir(), + PollStreamPeriod: 200 * time.Millisecond, + }) + require.NoError(t, err) + + type client struct { + services.Presence + types.Events + } + + presence := local.NewPresenceService(bk) + w, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: &client{ + Presence: presence, + Events: local.NewEventsService(bk), + }, + }, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + + // Add some node servers. + nodes := make([]types.Server, 0, 5) + for i := 0; i < 5; i++ { + node := newNodeServer(t, fmt.Sprintf("node%d", i), "127.0.0.1:2023", i%2 == 0) + _, err = presence.UpsertNode(ctx, node) + require.NoError(t, err) + nodes = append(nodes, node) + } + + require.Eventually(t, func() bool { + filtered := w.GetNodes(func(n services.Node) bool { + return true + }) + return len(filtered) == len(nodes) + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + require.Len(t, w.GetNodes(func(n services.Node) bool { return n.GetUseTunnel() }), 3) + + require.NoError(t, presence.DeleteNode(ctx, apidefaults.Namespace, nodes[0].GetName())) + + require.Eventually(t, func() bool { + filtered := w.GetNodes(func(n services.Node) bool { + return true + }) + return len(filtered) == len(nodes)-1 + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + require.Empty(t, w.GetNodes(func(n services.Node) bool { return n.GetName() == nodes[0].GetName() })) + +} + +func newNodeServer(t *testing.T, name, addr string, tunnel bool) types.Server { + s, err := types.NewServer(name, types.KindNode, types.ServerSpecV2{ + Addr: addr, + PublicAddr: addr, + UseTunnel: tunnel, + }) + require.NoError(t, err) + return s +} diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 75313bc9b8e74..b3502079c2fac 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -34,6 +34,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -322,18 +323,15 @@ func (t *proxySubsys) proxyToHost( // network resolution (by IP or DNS) // var ( - strategy types.RoutingStrategy - servers []types.Server - err error + strategy types.RoutingStrategy + nodeWatcher *services.NodeWatcher + err error ) localCluster, _ := t.srv.proxyAccessPoint.GetClusterName() // going to "local" CA? lets use the caching 'auth service' directly and avoid // hitting the reverse tunnel link (it can be offline if the CA is down) if site.GetName() == localCluster.GetName() { - servers, err = t.srv.proxyAccessPoint.GetNodes(ctx.CancelContext(), t.namespace) - if err != nil { - t.log.Warn(err) - } + nodeWatcher = t.srv.nodeWatcher cfg, err := t.srv.authService.GetClusterNetworkingConfig(ctx.CancelContext()) if err != nil { @@ -347,9 +345,11 @@ func (t *proxySubsys) proxyToHost( if err != nil { t.log.Warn(err) } else { - servers, err = siteClient.GetNodes(ctx.CancelContext(), t.namespace) + watcher, err := site.NodeWatcher() if err != nil { t.log.Warn(err) + } else { + nodeWatcher = watcher } cfg, err := siteClient.GetClusterNetworkingConfig(ctx.CancelContext()) @@ -366,7 +366,7 @@ func (t *proxySubsys) proxyToHost( t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v, strategy=%s", t.host, t.port, t.SpecifiedPort(), strategy) // determine which server to connect to - server, err := t.getMatchingServer(servers, strategy) + server, err := t.getMatchingServer(nodeWatcher, strategy) if err != nil { return trace.Wrap(err) } @@ -453,47 +453,58 @@ func (t *proxySubsys) proxyToHost( return nil } +// NodesGetter is a function that retrieves a subset of nodes matching +// the filter criteria. +type NodesGetter interface { + GetNodes(fn func(n services.Node) bool) []types.Server +} + // getMatchingServer determines the server to connect to from the provided servers. Duplicate entries are treated // differently based on strategy. Legacy behavior of returning an ambiguous error occurs if the strategy // is types.RoutingStrategy_UNAMBIGUOUS_MATCH. When the strategy is types.RoutingStrategy_MOST_RECENT then // the server that has heartbeated most recently will be returned instead of an error. If no matches are found then // both the types.Server and error returned will be nil. -func (t *proxySubsys) getMatchingServer(servers []types.Server, strategy types.RoutingStrategy) (types.Server, error) { +func (t *proxySubsys) getMatchingServer(watcher NodesGetter, strategy types.RoutingStrategy) (types.Server, error) { + if watcher == nil { + return nil, trace.NotFound("unable to retrieve nodes matching host %s", t.host) + } + // check if hostname is a valid uuid or EC2 node ID. If it is, we will // preferentially match by node ID over node hostname. hostIsUniqueID := uuid.Parse(t.host) != nil || utils.IsEC2NodeID(t.host) ips, _ := net.LookupHost(t.host) + var unambiguousIDMatch bool // enumerate and try to find a server with self-registered with a matching name/IP: - var matches []types.Server - for _, server := range servers { + matches := watcher.GetNodes(func(server services.Node) bool { + if unambiguousIDMatch { + return false + } + // If the host parameter is a UUID or EC2 node ID, and it matches the // Node ID, treat this as an unambiguous match. if hostIsUniqueID && server.GetName() == t.host { - matches = []types.Server{server} - break + unambiguousIDMatch = true + return true } // If the server has connected over a reverse tunnel, match only on hostname. if server.GetUseTunnel() { - if t.host == server.GetHostname() { - matches = append(matches, server) - } - continue + return t.host == server.GetHostname() } ip, port, err := net.SplitHostPort(server.GetAddr()) if err != nil { t.log.Errorf("Failed to parse address %q: %v.", server.GetAddr(), err) - continue + return false } if t.host == ip || t.host == server.GetHostname() || apiutils.SliceContainsStr(ips, ip) { if !t.SpecifiedPort() || t.port == port { - matches = append(matches, server) - continue + return true } } - } + return false + }) var server types.Server switch { diff --git a/lib/srv/regular/proxy_test.go b/lib/srv/regular/proxy_test.go index fcfc406f5e58b..d6259bbb209cf 100644 --- a/lib/srv/regular/proxy_test.go +++ b/lib/srv/regular/proxy_test.go @@ -21,10 +21,12 @@ import ( "time" "github.com/google/uuid" + "github.com/stretchr/testify/require" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" - "github.com/stretchr/testify/require" ) func TestParseProxyRequest(t *testing.T) { @@ -129,6 +131,21 @@ func TestParseBadRequests(t *testing.T) { } } +type nodeGetter struct { + servers []types.Server +} + +func (n nodeGetter) GetNodes(fn func(n services.Node) bool) []types.Server { + var servers []types.Server + for _, s := range n.servers { + if fn(s) { + servers = append(servers, s) + } + } + + return servers +} + func TestProxySubsys_getMatchingServer(t *testing.T) { t.Parallel() @@ -312,7 +329,7 @@ func TestProxySubsys_getMatchingServer(t *testing.T) { srv: &Server{}, } - server, err := subsystem.getMatchingServer(tt.servers, tt.strategy) + server, err := subsystem.getMatchingServer(nodeGetter{tt.servers}, tt.strategy) tt.expectError(t, err) if tt.expectServer != nil { require.Equal(t, tt.expectServer(tt.servers), server) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 8fa6af96840f8..18b16d36f9f38 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -188,6 +188,9 @@ type Server struct { // lockWatcher is the server's lock watcher. lockWatcher *services.LockWatcher + + // nodeWatcher is the server's node watcher. + nodeWatcher *services.NodeWatcher } // GetClock returns server clock implementation @@ -555,6 +558,14 @@ func SetLockWatcher(lockWatcher *services.LockWatcher) ServerOption { } } +// SetNodeWatcher sets the server's node watcher. +func SetNodeWatcher(nodeWatcher *services.NodeWatcher) ServerOption { + return func(s *Server) error { + s.nodeWatcher = nodeWatcher + return nil + } +} + // SetX11ForwardingConfig sets the server's X11 forwarding configuration func SetX11ForwardingConfig(xc *x11.ServerConfig) ServerOption { return func(s *Server) error { diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 2e2a422f6a095..c26946ecca119 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1124,6 +1124,7 @@ func TestProxyRoundRobin(t *testing.T) { listener, reverseTunnelAddress := mustListen(t) defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) + nodeWatcher := newNodeWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClusterName: f.testSrv.ClusterName(), @@ -1140,6 +1141,7 @@ func TestProxyRoundRobin(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) require.NoError(t, err) logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.") @@ -1164,6 +1166,7 @@ func TestProxyRoundRobin(t *testing.T) { SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), SetLockWatcher(lockWatcher), + SetNodeWatcher(nodeWatcher), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1246,6 +1249,7 @@ func TestProxyDirectAccess(t *testing.T) { logger := logrus.WithField("test", "TestProxyDirectAccess") proxyClient, _ := newProxyClient(t, f.testSrv) lockWatcher := newLockWatcher(ctx, t, proxyClient) + nodeWatcher := newNodeWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), @@ -1262,6 +1266,7 @@ func TestProxyDirectAccess(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) require.NoError(t, err) @@ -1287,6 +1292,7 @@ func TestProxyDirectAccess(t *testing.T) { SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), SetLockWatcher(lockWatcher), + SetNodeWatcher(nodeWatcher), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1941,6 +1947,18 @@ func newLockWatcher(ctx context.Context, t *testing.T, client types.Events) *ser return lockWatcher } +func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *services.NodeWatcher { + nodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(nodeWatcher.Close) + return nodeWatcher +} + // maxPipeSize is one larger than the maximum pipe size for most operating // systems which appears to be 65536 bytes. // diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 45b8d4da0e8c2..3d083346399ed 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -1914,12 +1914,12 @@ func (h *Handler) siteNodeConnect( req.ProxyHostPort = h.ProxyHostPort() req.Cluster = site.GetName() - clt, err := ctx.GetUserClient(site) + watcher, err := site.NodeWatcher() if err != nil { return nil, trace.Wrap(err) } - term, err := NewTerminal(r.Context(), *req, clt, ctx) + term, err := NewTerminal(*req, watcher, ctx) if err != nil { h.log.WithError(err).Error("Unable to create terminal.") return nil, trace.Wrap(err) @@ -1943,11 +1943,6 @@ type siteSessionGenerateResponse struct { // siteSessionCreate generates a new site session that can be used by UI // The ServerID from request can be in the form of hostname, uuid, or ip address. func (h *Handler) siteSessionGenerate(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { - clt, err := ctx.GetUserClient(site) - if err != nil { - return nil, trace.Wrap(err) - } - var req *siteSessionGenerateReq if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) @@ -1955,12 +1950,12 @@ func (h *Handler) siteSessionGenerate(w http.ResponseWriter, r *http.Request, p namespace := apidefaults.Namespace if req.Session.ServerID != "" { - servers, err := clt.GetNodes(r.Context(), namespace) + watcher, err := site.NodeWatcher() if err != nil { return nil, trace.Wrap(err) } - hostname, _, err := resolveServerHostPort(req.Session.ServerID, servers) + hostname, _, err := resolveServerHostPort(req.Session.ServerID, watcher) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 63781ec90709a..e7b1ebaa87865 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -27,7 +27,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/cookiejar" @@ -41,9 +40,21 @@ import ( "testing" "time" + "github.com/beevik/etree" + "github.com/gogo/protobuf/proto" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/gravitational/roundtrip" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + lemma_secret "github.com/mailgun/lemma/secret" + "github.com/pquerna/otp/totp" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "golang.org/x/net/websocket" "golang.org/x/text/encoding/unicode" + kyaml "k8s.io/apimachinery/pkg/util/yaml" "github.com/gravitational/teleport" apiProto "github.com/gravitational/teleport/api/client/proto" @@ -76,29 +87,10 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/ui" - - "github.com/gravitational/roundtrip" - "github.com/gravitational/trace" - - "github.com/beevik/etree" - "github.com/gogo/protobuf/proto" - "github.com/google/go-cmp/cmp" - "github.com/jonboulle/clockwork" - lemma_secret "github.com/mailgun/lemma/secret" - "github.com/pborman/uuid" - "github.com/pquerna/otp/totp" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - . "gopkg.in/check.v1" - kyaml "k8s.io/apimachinery/pkg/util/yaml" ) const hostID = "00000000-0000-0000-0000-000000000000" -func TestWeb(t *testing.T) { - TestingT(t) -} - type WebSuite struct { ctx context.Context cancel context.CancelFunc @@ -117,8 +109,6 @@ type WebSuite struct { clock clockwork.FakeClock } -var _ = Suite(&WebSuite{}) - // TestMain will re-execute Teleport to run a command if "exec" is passed to // it as an argument. Otherwise it will run tests as normal. func TestMain(m *testing.M) { @@ -135,35 +125,37 @@ func TestMain(m *testing.M) { os.Exit(code) } -func (s *WebSuite) SetUpSuite(c *C) { - os.Unsetenv(teleport.DebugEnvVar) - - var err error - s.mockU2F, err = mocku2f.Create() - c.Assert(err, IsNil) - c.Assert(s.mockU2F, NotNil) -} +func newWebSuite(t *testing.T) *WebSuite { + mockU2F, err := mocku2f.Create() + require.NoError(t, err) + require.NotNil(t, mockU2F) -func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) { - return clt, nil -} + u, err := user.Current() + require.NoError(t, err) -func (s *WebSuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + s := &WebSuite{ + mockU2F: mockU2F, + clock: clockwork.NewFakeClock(), + user: u.Username, + ctx: ctx, + cancel: cancel, + } - u, err := user.Current() - c.Assert(err, IsNil) - s.user = u.Username - s.clock = clockwork.NewFakeClock() + networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ + KeepAliveInterval: types.Duration(10 * time.Second), + }) + require.NoError(t, err) s.server, err = auth.NewTestServer(auth.TestServerConfig{ Auth: auth.TestAuthServerConfig{ - ClusterName: "localhost", - Dir: c.MkDir(), - Clock: s.clock, + ClusterName: "localhost", + Dir: t.TempDir(), + Clock: s.clock, + ClusterNetworkingConfig: networkingConfig, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) // Register the auth server, since test auth server doesn't start its own // heartbeat. @@ -180,13 +172,13 @@ func (s *WebSuite) SetUpTest(c *C) { Version: teleport.Version, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) priv, pub, err := s.server.AuthServer.AuthServer.GenerateKeyPair("") - c.Assert(err, IsNil) + require.NoError(t, err) tlsPub, err := auth.PrivateKeyToPublicKeyTLS(priv) - c.Assert(err, IsNil) + require.NoError(t, err) // start node certs, err := s.server.Auth().GenerateHostCerts(s.ctx, @@ -197,10 +189,10 @@ func (s *WebSuite) SetUpTest(c *C) { PublicSSHKey: pub, PublicTLSKey: tlsPub, }) - c.Assert(err, IsNil) + require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) - c.Assert(err, IsNil) + require.NoError(t, err) nodeID := "node" nodeClient, err := s.server.NewClient(auth.TestIdentity{ @@ -209,7 +201,7 @@ func (s *WebSuite) SetUpTest(c *C) { Username: nodeID, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) nodeLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -217,10 +209,10 @@ func (s *WebSuite) SetUpTest(c *C) { Client: nodeClient, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) // create SSH service: - nodeDataDir := c.MkDir() + nodeDataDir := t.TempDir() node, err := regular.New( utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), @@ -240,12 +232,11 @@ func (s *WebSuite) SetUpTest(c *C) { regular.SetClock(s.clock), regular.SetLockWatcher(nodeLockWatcher), ) - c.Assert(err, IsNil) + require.NoError(t, err) s.node = node s.srvID = node.ID() - c.Assert(s.node.Start(), IsNil) - - c.Assert(auth.CreateUploaderDir(nodeDataDir), IsNil) + require.NoError(t, s.node.Start()) + require.NoError(t, auth.CreateUploaderDir(nodeDataDir)) // create reverse tunnel service: proxyID := "proxy" @@ -255,10 +246,10 @@ func (s *WebSuite) SetUpTest(c *C) { Username: proxyID, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) revTunListener, err := net.Listen("tcp", fmt.Sprintf("%v:0", s.server.ClusterName())) - c.Assert(err, IsNil) + require.NoError(t, err) proxyLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -266,7 +257,15 @@ func (s *WebSuite) SetUpTest(c *C) { Client: s.proxyClient, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) + + proxyNodeWatcher, err := services.NewNodeWatcher(s.ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + }) + require.NoError(t, err) revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), @@ -279,10 +278,11 @@ func (s *WebSuite) SetUpTest(c *C) { Emitter: s.proxyClient, NewCachingAccessPoint: noCache, DirectClusters: []reversetunnel.DirectCluster{{Name: s.server.ClusterName(), Client: s.proxyClient}}, - DataDir: c.MkDir(), + DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, + NodeWatcher: proxyNodeWatcher, }) - c.Assert(err, IsNil) + require.NoError(t, err) s.proxyTunnel = revTunServer // proxy server: @@ -291,7 +291,7 @@ func (s *WebSuite) SetUpTest(c *C) { s.server.ClusterName(), []ssh.Signer{signer}, s.proxyClient, - c.MkDir(), + t.TempDir(), "", utils.NetAddr{}, regular.SetUUID(proxyID), @@ -303,13 +303,14 @@ func (s *WebSuite) SetUpTest(c *C) { regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), regular.SetLockWatcher(proxyLockWatcher), + regular.SetNodeWatcher(proxyNodeWatcher), ) - c.Assert(err, IsNil) + require.NoError(t, err) // Expired sessions are purged immediately var sessionLingeringThreshold time.Duration fs, err := NewDebugFileSystem("../../webassets/teleport") - c.Assert(err, IsNil) + require.NoError(t, err) handler, err := NewHandler(Config{ Proxy: revTunServer, AuthServers: utils.FromAddr(s.server.TLS.Addr()), @@ -324,22 +325,22 @@ func (s *WebSuite) SetUpTest(c *C) { cachedSessionLingeringThreshold: &sessionLingeringThreshold, ProxySettings: &mockProxySettings{}, }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(s.clock)) - c.Assert(err, IsNil) + require.NoError(t, err) s.webServer = httptest.NewUnstartedServer(handler) s.webServer.StartTLS() err = s.proxy.Start() - c.Assert(err, IsNil) + require.NoError(t, err) // Wait for proxy to fully register before starting the test. for start := time.Now(); ; { proxies, err := s.proxyClient.GetProxies() - c.Assert(err, IsNil) + require.NoError(t, err) if len(proxies) != 0 { break } if time.Since(start) > 5*time.Second { - c.Fatal("proxy didn't register within 5s after startup") + t.Fatal("proxy didn't register within 5s after startup") } } @@ -349,27 +350,37 @@ func (s *WebSuite) SetUpTest(c *C) { handler.handler.cfg.ProxyWebAddr = *addr handler.handler.cfg.ProxySSHAddr = *proxyAddr _, sshPort, err := net.SplitHostPort(proxyAddr.String()) - c.Assert(err, IsNil) + require.NoError(t, err) handler.handler.sshPort = sshPort -} -func (s *WebSuite) TearDownTest(c *C) { - // In particular close the lock watchers by cancelling the context. - s.cancel() + t.Cleanup(func() { + // In particular close the lock watchers by cancelling the context. + s.cancel() - var errors []error - s.proxyTunnel.Close() - if err := s.node.Close(); err != nil { - errors = append(errors, err) - } - s.webServer.Close() - if err := s.proxy.Close(); err != nil { - errors = append(errors, err) - } - if err := s.server.Shutdown(context.Background()); err != nil { - errors = append(errors, err) - } - c.Assert(errors, HasLen, 0) + s.webServer.Close() + + var errors []error + if err := s.proxyTunnel.Close(); err != nil { + errors = append(errors, err) + } + if err := s.node.Close(); err != nil { + errors = append(errors, err) + } + s.webServer.Close() + if err := s.proxy.Close(); err != nil { + errors = append(errors, err) + } + if err := s.server.Shutdown(context.Background()); err != nil { + errors = append(errors, err) + } + require.Empty(t, errors) + }) + + return s +} + +func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) { + return clt, nil } func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response { @@ -395,7 +406,7 @@ type authPack struct { // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. -func (s *WebSuite) authPack(c *C, user string) *authPack { +func (s *WebSuite) authPack(t *testing.T, user string) *authPack { login := s.user pass := "abc123" rawSecret := "def456" @@ -405,15 +416,15 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) - s.createUser(c, user, login, pass, otpSecret) + s.createUser(t, user, login, pass, otpSecret) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() req := CreateSessionReq{ @@ -424,16 +435,16 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" re, err := s.login(clt, csrfToken, csrfToken, req) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) sess, err := rawSess.response() - c.Assert(err, IsNil) + require.NoError(t, err) jar, err := cookiejar.New(nil) - c.Assert(err, IsNil) + require.NoError(t, err) clt = s.client(roundtrip.BearerAuth(sess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) @@ -448,32 +459,32 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { } } -func (s *WebSuite) createUser(c *C, user string, login string, pass string, otpSecret string) { +func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string) { teleUser, err := types.NewUser(user) - c.Assert(err, IsNil) + require.NoError(t, err) role := services.RoleForUser(teleUser) role.SetLogins(types.Allow, []string{login}) options := role.GetOptions() options.ForwardAgent = types.NewBool(true) role.SetOptions(options) err = s.server.Auth().UpsertRole(s.ctx, role) - c.Assert(err, IsNil) + require.NoError(t, err) teleUser.AddRole(role.GetName()) teleUser.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "some-auth-user"}, }) err = s.server.Auth().CreateUser(s.ctx, teleUser) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertPassword(user, []byte(pass)) - c.Assert(err, IsNil) + require.NoError(t, err) if otpSecret != "" { dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) - c.Assert(err, IsNil) + require.NoError(t, err) } } @@ -496,18 +507,20 @@ func TestValidRedirectURL(t *testing.T) { } } -func (s *WebSuite) TestSAMLSuccess(c *C) { +func TestSAMLSuccess(t *testing.T) { + t.Parallel() + s := newWebSuite(t) input := fixtures.SAMLOktaConnectorV2 decoder := kyaml.NewYAMLOrJSONDecoder(strings.NewReader(input), defaults.LookaheadBufSize) var raw services.UnknownResource err := decoder.Decode(&raw) - c.Assert(err, IsNil) + require.NoError(t, err) connector, err := services.UnmarshalSAMLConnector(raw.Raw) - c.Assert(err, IsNil) + require.NoError(t, err) err = services.ValidateSAMLConnector(connector) - c.Assert(err, IsNil) + require.NoError(t, err) role, err := types.NewRole(connector.GetAttributesToRoles()[0].Roles[0], types.RoleSpecV4{ Options: types.RoleOptions{ @@ -521,64 +534,64 @@ func (s *WebSuite) TestSAMLSuccess(c *C) { }, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) role.SetLogins(types.Allow, []string{s.user}) err = s.server.Auth().UpsertRole(s.ctx, role) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().CreateSAMLConnector(connector) - c.Assert(err, IsNil) - s.server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC))) + require.NoError(t, err) + s.server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 5, 10, 18, 53, 0, 0, time.UTC))) clt := s.clientNoRedirects() csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" - baseURL, err := url.Parse(clt.Endpoint("webapi", "saml", "sso") + `?redirect_url=http://localhost/after&connector_id=` + connector.GetName()) - c.Assert(err, IsNil) + baseURL, err := url.Parse(clt.Endpoint("webapi", "saml", "sso") + `?connector_id=` + connector.GetName() + `&redirect_url=http://localhost/after`) + require.NoError(t, err) req, err := http.NewRequest("GET", baseURL.String(), nil) - c.Assert(err, IsNil) + require.NoError(t, err) addCSRFCookieToReq(req, csrfToken) re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) // we got a redirect urlPattern := regexp.MustCompile(`URL='([^']*)'`) locationURL := urlPattern.FindStringSubmatch(string(re.Bytes()))[1] u, err := url.Parse(locationURL) - c.Assert(err, IsNil) - c.Assert(u.Scheme+"://"+u.Host+u.Path, Equals, fixtures.SAMLOktaSSO) + require.NoError(t, err) + require.Equal(t, fixtures.SAMLOktaSSO, u.Scheme+"://"+u.Host+u.Path) data, err := base64.StdEncoding.DecodeString(u.Query().Get("SAMLRequest")) - c.Assert(err, IsNil) - buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(data))) - c.Assert(err, IsNil) + require.NoError(t, err) + buf, err := io.ReadAll(flate.NewReader(bytes.NewReader(data))) + require.NoError(t, err) doc := etree.NewDocument() err = doc.ReadFromBytes(buf) - c.Assert(err, IsNil) + require.NoError(t, err) id := doc.Root().SelectAttr("ID") - c.Assert(id, NotNil) + require.NotNil(t, id) authRequest, err := s.server.Auth().GetSAMLAuthRequest(id.Value) - c.Assert(err, IsNil) + require.NoError(t, err) // now swap the request id to the hardcoded one in fixtures authRequest.ID = fixtures.SAMLOktaAuthRequestID authRequest.CSRFToken = csrfToken err = s.server.Auth().Identity.CreateSAMLAuthRequest(*authRequest, backend.Forever) - c.Assert(err, IsNil) + require.NoError(t, err) // now respond with pre-recorded request to the POST url in := &bytes.Buffer{} fw, err := flate.NewWriter(in, flate.DefaultCompression) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = fw.Write([]byte(fixtures.SAMLOktaAuthnResponseXML)) - c.Assert(err, IsNil) + require.NoError(t, err) err = fw.Close() - c.Assert(err, IsNil) + require.NoError(t, err) encodedResponse := base64.StdEncoding.EncodeToString(in.Bytes()) - c.Assert(encodedResponse, NotNil) + require.NotNil(t, encodedResponse) // now send the response to the server to exchange it for auth session form := url.Values{} @@ -586,43 +599,46 @@ func (s *WebSuite) TestSAMLSuccess(c *C) { req, err = http.NewRequest("POST", clt.Endpoint("webapi", "saml", "acs"), strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") addCSRFCookieToReq(req, csrfToken) - c.Assert(err, IsNil) + require.NoError(t, err) authRe, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) - comment := Commentf("Response: %v", string(authRe.Bytes())) - c.Assert(authRe.Code(), Equals, http.StatusFound, comment) + require.NoError(t, err) + require.Equal(t, http.StatusFound, authRe.Code(), "Response: %v", string(authRe.Bytes())) // we have got valid session - c.Assert(authRe.Headers().Get("Set-Cookie"), Not(Equals), "") - // we are being redirected to orignal URL - c.Assert(authRe.Headers().Get("Location"), Equals, "/after") + require.NotEmpty(t, authRe.Headers().Get("Set-Cookie")) + // we are being redirected to original URL + require.Equal(t, "/after", authRe.Headers().Get("Location")) } -func (s *WebSuite) TestWebSessionsCRUD(c *C) { - pack := s.authPack(c, "foo") +func TestWebSessionsCRUD(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // make sure we can use client to make authenticated requests re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var clusters []ui.Cluster - c.Assert(json.Unmarshal(re.Bytes(), &clusters), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &clusters)) // now delete session _, err = pack.clt.Delete( context.Background(), pack.clt.Endpoint("webapi", "sessions")) - c.Assert(err, IsNil) + require.NoError(t, err) // subsequent requests trying to use this session will fail _, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } -func (s *WebSuite) TestCSRF(c *C) { +func TestCSRF(t *testing.T) { + t.Parallel() + s := newWebSuite(t) type input struct { reqToken string cookieToken string @@ -632,11 +648,11 @@ func (s *WebSuite) TestCSRF(c *C) { user := "csrfuser" pass := "abc123" otpSecret := base32.StdEncoding.EncodeToString([]byte("def456")) - s.createUser(c, user, user, pass, otpSecret) + s.createUser(t, user, user, pass, otpSecret) // create a valid login form request validToken, err := totp.GenerateCode(otpSecret, time.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) loginForm := CreateSessionReq{ User: user, Pass: pass, @@ -656,23 +672,25 @@ func (s *WebSuite) TestCSRF(c *C) { // valid _, err = s.login(clt, encodedToken1, encodedToken1, loginForm) - c.Assert(err, IsNil) + require.NoError(t, err) // invalid for i := range invalid { _, err := s.login(clt, invalid[i].cookieToken, invalid[i].reqToken, loginForm) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } } -func (s *WebSuite) TestPasswordChange(c *C) { - pack := s.authPack(c, "foo") +func TestPasswordChange(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // invalidate the token s.clock.Advance(1 * time.Minute) validToken, err := totp.GenerateCode(pack.otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) req := changePasswordReq{ OldPassword: []byte("abc123"), @@ -681,26 +699,28 @@ func (s *WebSuite) TestPasswordChange(c *C) { } _, err = pack.clt.PutJSON(context.Background(), pack.clt.Endpoint("webapi", "users", "password"), req) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestWebSessionsBadInput(c *C) { +func TestWebSessionsBadInput(t *testing.T) { + t.Parallel() + s := newWebSuite(t) user := "bob" pass := "abc123" rawSecret := "def456" otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) err := s.server.Auth().UpsertPassword(user, []byte(pass)) - c.Assert(err, IsNil) + require.NoError(t, err) dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) - c.Assert(err, IsNil) + require.NoError(t, err) // create valid token validToken, err := totp.GenerateCode(otpSecret, time.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() @@ -736,9 +756,11 @@ func (s *WebSuite) TestWebSessionsBadInput(c *C) { }, } for i, req := range reqs { - _, err = clt.PostJSON(context.Background(), clt.Endpoint("webapi", "sessions"), req) - c.Assert(err, NotNil, Commentf("tc %v", i)) - c.Assert(trace.IsAccessDenied(err), Equals, true, Commentf("tc %v %T is not access denied", i, err)) + t.Run(fmt.Sprintf("tc %v", i), func(t *testing.T) { + _, err := clt.PostJSON(s.ctx, clt.Endpoint("webapi", "sessions"), req) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }) } } @@ -746,33 +768,38 @@ type getSiteNodeResponse struct { Items []ui.Server `json:"items"` } -func (s *WebSuite) TestGetSiteNodes(c *C) { - pack := s.authPack(c, "foo") +func TestGetSiteNodes(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // get site nodes re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "nodes"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) nodes := getSiteNodeResponse{} - c.Assert(json.Unmarshal(re.Bytes(), &nodes), IsNil) - c.Assert(len(nodes.Items), Equals, 1) + require.NoError(t, json.Unmarshal(re.Bytes(), &nodes)) + require.Len(t, nodes.Items, 1) // get site nodes using shortcut re, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", currentSiteShortcut, "nodes"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) nodes2 := getSiteNodeResponse{} - c.Assert(json.Unmarshal(re.Bytes(), &nodes2), IsNil) - c.Assert(len(nodes.Items), Equals, 1) - c.Assert(nodes2, DeepEquals, nodes) + require.NoError(t, json.Unmarshal(re.Bytes(), &nodes2)) + require.Len(t, nodes.Items, 1) + require.Empty(t, cmp.Diff(nodes, nodes2)) } -func (s *WebSuite) TestSiteNodeConnectInvalidSessionID(c *C) { - _, err := s.makeTerminal(s.authPack(c, "foo"), session.ID("/../../../foo")) - c.Assert(err, NotNil) +func TestSiteNodeConnectInvalidSessionID(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + _, err := s.makeTerminal(s.authPack(t, "foo"), session.ID("/../../../foo")) + require.Error(t, err) } -func (s *WebSuite) TestResolveServerHostPort(c *C) { +func TestResolveServerHostPort(t *testing.T) { + t.Parallel() sampleNode := types.ServerV2{} sampleNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") sampleNode.Spec.Hostname = "nodehostname" @@ -830,20 +857,35 @@ func (s *WebSuite) TestResolveServerHostPort(c *C) { } for _, testCase := range validCases { - host, port, err := resolveServerHostPort(testCase.server, testCase.nodes) - c.Assert(err, IsNil, Commentf(testCase.server)) - c.Assert(host, Equals, testCase.expectedHost) - c.Assert(port, Equals, testCase.expectedPort) + host, port, err := resolveServerHostPort(testCase.server, nodeGetter{servers: testCase.nodes}) + require.NoError(t, err, testCase.server) + require.Equal(t, testCase.expectedHost, host, testCase.server) + require.Equal(t, testCase.expectedPort, port, testCase.server) } for _, testCase := range invalidCases { - _, _, err := resolveServerHostPort(testCase.server, nil) - c.Assert(err, NotNil, Commentf(testCase.expectedErr)) - c.Assert(err, ErrorMatches, ".*"+testCase.expectedErr+".*") + _, _, err := resolveServerHostPort(testCase.server, nodeGetter{}) + require.Error(t, err, testCase.server) + require.Regexp(t, ".*"+testCase.expectedErr+".*", err.Error(), testCase.server) } } -func (s *WebSuite) TestNewTerminalHandler(c *C) { +type nodeGetter struct { + servers []types.Server +} + +func (n nodeGetter) GetNodes(fn func(n services.Node) bool) []types.Server { + var servers []types.Server + for _, s := range n.servers { + if fn(s) { + servers = append(servers, s) + } + } + + return servers +} + +func TestNewTerminalHandler(t *testing.T) { validNode := types.ServerV2{} validNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") validNode.Spec.Hostname = "nodehostname" @@ -856,16 +898,10 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { W: 1, } - makeProvider := func(server types.ServerV2) AuthProvider { - return authProviderMock{ - server: server, - } - } - // valid cases validCases := []struct { req TerminalRequest - authProvider AuthProvider + site reversetunnel.RemoteSite expectedHost string expectedPort int }{ @@ -876,7 +912,6 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { SessionID: validSID, Term: validParams, }, - authProvider: makeProvider(validNode), expectedHost: validServer, expectedPort: 0, }, @@ -887,7 +922,6 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { SessionID: validSID, Term: validParams, }, - authProvider: makeProvider(validNode), expectedHost: "nodehostname", expectedPort: 0, }, @@ -895,13 +929,12 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { // invalid cases invalidCases := []struct { - req TerminalRequest - authProvider AuthProvider - expectedErr string + req TerminalRequest + site reversetunnel.RemoteSite + expectedErr string }{ { - expectedErr: "invalid session", - authProvider: makeProvider(validNode), + expectedErr: "invalid session", req: TerminalRequest{ SessionID: "", Login: validLogin, @@ -910,8 +943,7 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { }, }, { - expectedErr: "bad term dimensions", - authProvider: makeProvider(validNode), + expectedErr: "bad term dimensions", req: TerminalRequest{ SessionID: validSID, Login: validLogin, @@ -923,8 +955,7 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { }, }, { - expectedErr: "invalid server name", - authProvider: makeProvider(validNode), + expectedErr: "invalid server name", req: TerminalRequest{ Server: "localhost:port", SessionID: validSID, @@ -934,98 +965,103 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { }, } + getter := nodeGetter{servers: []types.Server{&validNode}} for _, testCase := range validCases { - term, err := NewTerminal(s.ctx, testCase.req, testCase.authProvider, nil) - c.Assert(err, IsNil) - c.Assert(term.params, DeepEquals, testCase.req) - c.Assert(term.hostName, Equals, testCase.expectedHost) - c.Assert(term.hostPort, Equals, testCase.expectedPort) + term, err := NewTerminal(testCase.req, getter, nil) + require.NoError(t, err) + require.Empty(t, cmp.Diff(testCase.req, term.params)) + require.Equal(t, testCase.expectedHost, testCase.expectedHost) + require.Equal(t, testCase.expectedPort, testCase.expectedPort) } for _, testCase := range invalidCases { - _, err := NewTerminal(s.ctx, testCase.req, testCase.authProvider, nil) - c.Assert(err, ErrorMatches, ".*"+testCase.expectedErr+".*") + _, err := NewTerminal(testCase.req, getter, nil) + require.Regexp(t, ".*"+testCase.expectedErr+".*", err.Error()) } } -func (s *WebSuite) TestResizeTerminal(c *C) { +func TestResizeTerminal(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() // Create a new user "foo", open a terminal to a new session, and wait for // it to be ready. - pack1 := s.authPack(c, "foo") + pack1 := s.authPack(t, "foo") ws1, err := s.makeTerminal(pack1, sid) - c.Assert(err, IsNil) - defer ws1.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws1.Close()) }) err = s.waitForRawEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Create a new user "bar", open a terminal to the session created above, // and wait for it to be ready. - pack2 := s.authPack(c, "bar") + pack2 := s.authPack(t, "bar") ws2, err := s.makeTerminal(pack2, sid) - c.Assert(err, IsNil) - defer ws2.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws2.Close()) }) err = s.waitForRawEvent(ws2, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Look at the audit events for the first terminal. It should have two // resize events from the second terminal (80x25 default then 100x100). Only // the second terminal will get these because resize events are not sent // back to the originator. err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Look at the stream events for the second terminal. We don't expect to see // any resize events yet. It will timeout. err = s.waitForResizeEvent(ws2, 1*time.Second) - c.Assert(err, NotNil) + require.Error(t, err) // Resize the second terminal. This should be reflected on the first terminal // because resize events are not sent back to the originator. params, err := session.NewTerminalParamsFromInt(300, 120) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(events.EventFields{ events.EventType: events.ResizeEvent, events.EventNamespace: apidefaults.Namespace, events.SessionEventID: sid.String(), events.TerminalSize: params.Serialize(), }) - c.Assert(err, IsNil) + require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketResize, Payload: string(data), } envelopeBytes, err := proto.Marshal(envelope) - c.Assert(err, IsNil) + require.NoError(t, err) err = websocket.Message.Send(ws2, envelopeBytes) - c.Assert(err, IsNil) + require.NoError(t, err) // This time the first terminal will see the resize event. err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // The second terminal will not see any resize event. It will timeout. err = s.waitForResizeEvent(ws2, 1*time.Second) - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *WebSuite) TestTerminal(c *C) { - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) - defer ws.Close() +func TestTerminal(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ws, err := s.makeTerminal(s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) _, err = io.WriteString(stream, "echo vinsong\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) err = waitForOutput(stream, "vinsong") - c.Assert(err, IsNil) + require.NoError(t, err) } func TestTerminalRequireSessionMfa(t *testing.T) { @@ -1133,58 +1169,55 @@ func TestTerminalRequireSessionMfa(t *testing.T) { // Wait for websocket authn challenge event. var raw []byte - require.Nil(t, websocket.Message.Receive(ws, &raw)) + require.NoError(t, websocket.Message.Receive(ws, &raw)) var env Envelope - require.Nil(t, proto.Unmarshal(raw, &env)) + require.NoError(t, proto.Unmarshal(raw, &env)) chals := &auth.MFAAuthenticateChallenge{} - require.Nil(t, json.Unmarshal([]byte(env.Payload), &chals)) + require.NoError(t, json.Unmarshal([]byte(env.Payload), &chals)) // Send response over ws. termHandler := newTerminalHandler() _, err := termHandler.write(tc.getChallengeResponseBytes(chals, dev), ws) - require.Nil(t, err) + require.NoError(t, err) // Test we can write. stream := termHandler.asTerminalStream(ws) _, err = io.WriteString(stream, "echo alpacas\r\n") - require.Nil(t, err) - require.Nil(t, waitForOutput(stream, "alpacas")) - - require.Nil(t, ws.Close()) + require.NoError(t, err) + require.NoError(t, waitForOutput(stream, "alpacas")) }) } } -func (s *WebSuite) TestWebsocketPingLoop(c *C) { +func TestWebsocketPingLoop(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + // Change cluster networking config for keep alive interval to be run faster. netConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ KeepAliveInterval: types.NewDuration(250 * time.Millisecond), }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetClusterNetworkingConfig(s.ctx, netConfig) - c.Assert(err, IsNil) + require.NoError(t, err) recConfig, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ Mode: types.RecordAtNode, ProxyChecksHostKeys: types.NewBoolOption(true), }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetSessionRecordingConfig(s.ctx, recConfig) - c.Assert(err, IsNil) - - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) + require.NoError(t, err) - // flush out raw event (pty texts) - err = s.waitForRawEvent(ws, 5*time.Second) - c.Assert(err, IsNil) + ws, err := s.makeTerminal(s.authPack(t, "foo")) + require.NoError(t, err) var numPings int start := time.Now() for { frame, err := ws.NewFrameReader() - c.Assert(err, IsNil) + require.NoError(t, err) // We should get a mix of output (binary) and ping frames. Count only // the ping frames. if int(frame.PayloadType()) == websocket.PingFrame { @@ -1194,86 +1227,91 @@ func (s *WebSuite) TestWebsocketPingLoop(c *C) { break } if time.Since(start) > 5*time.Second { - c.Fatalf("received %d ping frames within 5s of opening a socket, expected at least 2", numPings) + t.Fatalf("received %d ping frames within 5s of opening a socket, expected at least 2", numPings) } } - err = ws.Close() - c.Assert(err, IsNil) + require.NoError(t, ws.Close()) } -func (s *WebSuite) TestWebAgentForward(c *C) { - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) - defer ws.Close() +func TestWebAgentForward(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ws, err := s.makeTerminal(s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) _, err = io.WriteString(stream, "echo $SSH_AUTH_SOCK\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) err = waitForOutput(stream, "/") - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestActiveSessions(c *C) { +func TestActiveSessions(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() - pack := s.authPack(c, "foo") + pack := s.authPack(t, "foo") ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) // To make sure we have a session. _, err = io.WriteString(stream, "echo vinsong\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) // Make sure server has replied. err = waitForOutput(stream, "vinsong") - c.Assert(err, IsNil) + require.NoError(t, err) // Make sure this session appears in the list of active sessions. var sessResp *siteSessionsGetResponse for i := 0; i < 10; i++ { // Get site nodes and make sure the node has our active party. - re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) - c.Assert(err, IsNil) + re, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) + require.NoError(t, err) - c.Assert(json.Unmarshal(re.Bytes(), &sessResp), IsNil) - c.Assert(len(sessResp.Sessions), Equals, 1) + require.NoError(t, json.Unmarshal(re.Bytes(), &sessResp)) + require.Len(t, sessResp.Sessions, 1) // Sessions do not appear momentarily as there's async heartbeat // procedure. time.Sleep(250 * time.Millisecond) } - c.Assert(len(sessResp.Sessions), Equals, 1) + require.Len(t, sessResp.Sessions, 1) sess := sessResp.Sessions[0] - c.Assert(sess.ID, Equals, sid) - c.Assert(sess.Namespace, Equals, s.node.GetNamespace()) - c.Assert(sess.Parties, NotNil) - c.Assert(sess.TerminalParams.H > 0, Equals, true) - c.Assert(sess.TerminalParams.W > 0, Equals, true) - c.Assert(sess.Login, Equals, pack.login) - c.Assert(sess.Created.IsZero(), Equals, false) - c.Assert(sess.LastActive.IsZero(), Equals, false) - c.Assert(sess.ServerID, Equals, s.srvID) - c.Assert(sess.ServerHostname, Equals, s.node.GetInfo().GetHostname()) - c.Assert(sess.ServerAddr, Equals, s.node.GetInfo().GetAddr()) - c.Assert(sess.ClusterName, Equals, s.server.ClusterName()) + require.Equal(t, sid, sess.ID) + require.Equal(t, s.node.GetNamespace(), sess.Namespace) + require.NotNil(t, sess.Parties) + require.Greater(t, sess.TerminalParams.H, 0) + require.Greater(t, sess.TerminalParams.W, 0) + require.Equal(t, pack.login, sess.Login) + require.False(t, sess.Created.IsZero()) + require.False(t, sess.LastActive.IsZero()) + require.Equal(t, s.srvID, sess.ServerID) + require.Equal(t, s.node.GetInfo().GetHostname(), sess.ServerHostname) + require.Equal(t, s.node.GetInfo().GetAddr(), sess.ServerAddr) + require.Equal(t, s.server.ClusterName(), sess.ClusterName) } // DELETE IN: 5.0.0 // Tests the code snippet from apiserver.(*Handler).siteSessionGet/siteSessionsGet // that tests empty ClusterName and ServerHostname gets set. -func (s *WebSuite) TestEmptySessionClusterHostnameIsSet(c *C) { +func TestEmptySessionClusterHostnameIsSet(t *testing.T) { + t.Parallel() + s := newWebSuite(t) nodeClient, err := s.server.NewClient(auth.TestBuiltin(types.RoleNode)) - c.Assert(err, IsNil) + require.NoError(t, err) // Create a session with empty ClusterName. sess1 := session.Session{ @@ -1287,68 +1325,67 @@ func (s *WebSuite) TestEmptySessionClusterHostnameIsSet(c *C) { TerminalParams: session.TerminalParams{W: 100, H: 100}, } err = nodeClient.CreateSession(sess1) - c.Assert(err, IsNil) + require.NoError(t, err) // Retrieve the session with the empty ClusterName. - pack := s.authPack(c, "baz") - res, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions", sess1.ID.String()), url.Values{}) - c.Assert(err, IsNil) + pack := s.authPack(t, "baz") + res, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions", sess1.ID.String()), url.Values{}) + require.NoError(t, err) // Test that empty ClusterName and ServerHostname got set. var sessionResult *session.Session err = json.Unmarshal(res.Bytes(), &sessionResult) - c.Assert(err, IsNil) - c.Assert(sessionResult.ClusterName, Equals, s.server.ClusterName()) - c.Assert(sessionResult.ServerHostname, Equals, sess1.ServerID) + require.NoError(t, err) + require.Equal(t, s.server.ClusterName(), sessionResult.ClusterName) + require.Equal(t, sess1.ServerID, sessionResult.ServerHostname) // Create another session to test sessions list. sess2 := sess1 sess2.ID = session.NewID() sess2.ServerID = string(session.NewID()) err = nodeClient.CreateSession(sess2) - c.Assert(err, IsNil) + require.NoError(t, err) // Retrieve sessions list. - res, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) - c.Assert(err, IsNil) + res, err = pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) + require.NoError(t, err) var sessionList *siteSessionsGetResponse err = json.Unmarshal(res.Bytes(), &sessionList) - c.Assert(err, IsNil) + require.NoError(t, err) s1 := sessionList.Sessions[0] s2 := sessionList.Sessions[1] - c.Assert(s1.ClusterName, Equals, s.server.ClusterName()) - c.Assert(s2.ClusterName, Equals, s.server.ClusterName()) - c.Assert(s1.ServerHostname, Equals, s1.ServerID) - c.Assert(s2.ServerHostname, Equals, s2.ServerID) + require.Equal(t, s.server.ClusterName(), s1.ClusterName) + require.Equal(t, s.server.ClusterName(), s2.ClusterName) + require.Equal(t, s1.ServerID, s1.ServerHostname) + require.Equal(t, s2.ServerID, s2.ServerHostname) } -func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) { +func TestCloseConnectionsOnLogout(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() - pack := s.authPack(c, "foo") + pack := s.authPack(t, "foo") ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + require.NoError(t, err) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) // to make sure we have a session _, err = io.WriteString(stream, "expr 137 + 39\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) // make sure server has replied out := make([]byte, 100) _, err = stream.Read(out) - c.Assert(err, IsNil) + require.NoError(t, err) - _, err = pack.clt.Delete( - context.Background(), - pack.clt.Endpoint("webapi", "sessions")) - c.Assert(err, IsNil) + _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions")) + require.NoError(t, err) // wait until we timeout or detect that connection has been closed after := time.After(5 * time.Second) @@ -1364,21 +1401,23 @@ func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) { select { case <-after: - c.Fatalf("timeout") + t.Fatalf("timeout") case err := <-errC: - c.Assert(err, Equals, io.EOF) + require.ErrorIs(t, err, io.EOF) } } -func (s *WebSuite) TestCreateSession(c *C) { - pack := s.authPack(c, "foo") +func TestCreateSession(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // get site nodes re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "nodes"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) nodes := getSiteNodeResponse{} - c.Assert(json.Unmarshal(re.Bytes(), &nodes), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &nodes)) node := nodes.Items[0] sess := session.Session{ @@ -1393,12 +1432,12 @@ func (s *WebSuite) TestCreateSession(c *C) { pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), siteSessionGenerateReq{Session: sess}, ) - c.Assert(err, IsNil) + require.NoError(t, err) var created *siteSessionGenerateResponse - c.Assert(json.Unmarshal(re.Bytes(), &created), IsNil) - c.Assert(created.Session.ID, Not(Equals), "") - c.Assert(created.Session.ServerHostname, Equals, node.Hostname) + require.NoError(t, json.Unmarshal(re.Bytes(), &created)) + require.NotEmpty(t, created.Session.ID) + require.Equal(t, node.Hostname, created.Session.ServerHostname) // test empty serverID (older version does not supply serverID) sess.ServerID = "" @@ -1407,38 +1446,42 @@ func (s *WebSuite) TestCreateSession(c *C) { pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), siteSessionGenerateReq{Session: sess}, ) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestPlayback(c *C) { - pack := s.authPack(c, "foo") +func TestPlayback(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") sid := session.NewID() ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) } -func (s *WebSuite) TestLogin(c *C) { +func TestLogin(t *testing.T) { + t.Parallel() + s := newWebSuite(t) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOff, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) // create user - s.createUser(c, "user1", "root", "password", "") + s.createUser(t, "user1", "root", "password", "") loginReq, err := json.Marshal(CreateSessionReq{ User: "user1", Pass: "password", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(loginReq)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1448,91 +1491,94 @@ func (s *WebSuite) TestLogin(c *C) { re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) cookies := re.Cookies() - c.Assert(len(cookies), Equals, 1) + require.Len(t, cookies, 1) // now make sure we are logged in by calling authenticated method // we need to supply both session cookie and bearer token for // request to succeed jar, err := cookiejar.New(nil) - c.Assert(err, IsNil) + require.NoError(t, err) clt = s.client(roundtrip.BearerAuth(rawSess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) - re, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, IsNil) + re, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.NoError(t, err) var clusters []ui.Cluster - c.Assert(json.Unmarshal(re.Bytes(), &clusters), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &clusters)) // in absence of session cookie or bearer auth the same request fill fail // no session cookie: clt = s.client(roundtrip.BearerAuth(rawSess.Token)) - _, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) // no bearer token: clt = s.client(roundtrip.CookieJar(jar)) - _, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } -func (s *WebSuite) TestChangePasswordAndAddTOTPDeviceWithToken(c *C) { +func TestChangePasswordAndAddTOTPDeviceWithToken(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) // create user - s.createUser(c, "user1", "root", "password", "") + s.createUser(t, "user1", "root", "password", "") // create password change token token, err := s.server.Auth().CreateResetPasswordToken(context.TODO(), auth.CreateUserTokenRequest{ Name: "user1", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() re, err := clt.Get(context.Background(), clt.Endpoint("webapi", "users", "password", "token", token.GetName()), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var uiToken *ui.ResetPasswordToken - c.Assert(json.Unmarshal(re.Bytes(), &uiToken), IsNil) - c.Assert(uiToken.User, Equals, token.GetUser()) - c.Assert(uiToken.TokenID, Equals, token.GetName()) - c.Assert(uiToken.QRCode, NotNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &uiToken)) + require.Equal(t, token.GetUser(), uiToken.User) + require.Equal(t, token.GetName(), uiToken.TokenID) + require.NotNil(t, uiToken.QRCode) res, err := s.server.Auth().CreateRegisterChallenge(context.Background(), &apiProto.CreateRegisterChallengeRequest{ TokenID: token.GetName(), DeviceType: apiProto.DeviceType_DEVICE_TYPE_TOTP, }) - c.Assert(err, IsNil) + require.NoError(t, err) // Advance the clock to invalidate the TOTP token s.clock.Advance(1 * time.Minute) secondFactorToken, err := totp.GenerateCode(res.GetTOTP().GetSecret(), s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(auth.ChangePasswordWithTokenRequest{ TokenID: token.GetName(), Password: []byte("abc123"), SecondFactorToken: secondFactorToken, }) - c.Assert(err, IsNil) + require.NoError(t, err) req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(data)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1542,16 +1588,19 @@ func (s *WebSuite) TestChangePasswordAndAddTOTPDeviceWithToken(c *C) { re, err = clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) // Test that no recovery codes are returned b/c cloud feature isn't enabled. var response ui.RecoveryCodes - c.Assert(json.Unmarshal(re.Bytes(), &response), IsNil) - c.Assert(response.Codes, IsNil) - c.Assert(response.Created, IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &response)) + require.Nil(t, response.Codes) + require.Nil(t, response.Created) } -func (s *WebSuite) TestChangePasswordAndAddU2FDeviceWithToken(c *C) { +func TestChangePasswordAndAddU2FDeviceWithToken(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorU2F, @@ -1560,37 +1609,37 @@ func (s *WebSuite) TestChangePasswordAndAddU2FDeviceWithToken(c *C) { Facets: []string{"https://" + s.server.ClusterName()}, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) - s.createUser(c, "user2", "root", "password", "") + s.createUser(t, "user2", "root", "password", "") // create reset password token token, err := s.server.Auth().CreateResetPasswordToken(context.TODO(), auth.CreateUserTokenRequest{ Name: "user2", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() re, err := clt.Get(context.Background(), clt.Endpoint("webapi", "u2f", "signuptokens", token.GetName()), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var u2fRegReq u2f.RegisterChallenge - c.Assert(json.Unmarshal(re.Bytes(), &u2fRegReq), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &u2fRegReq)) u2fRegResp, err := s.mockU2F.RegisterResponse(&u2fRegReq) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(auth.ChangePasswordWithTokenRequest{ TokenID: token.GetName(), Password: []byte("qweQWE"), U2FRegisterResponse: u2fRegResp, }) - c.Assert(err, IsNil) + require.NoError(t, err) req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(data)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1600,77 +1649,81 @@ func (s *WebSuite) TestChangePasswordAndAddU2FDeviceWithToken(c *C) { re, err = clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) // Test that no recovery codes are returned b/c cloud is not turned on. var response ui.RecoveryCodes - c.Assert(json.Unmarshal(re.Bytes(), &response), IsNil) - c.Assert(response.Codes, IsNil) - c.Assert(response.Created, IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &response)) + require.Nil(t, response.Codes) + require.Nil(t, response.Created) } // TestEmptyMotD ensures that responses returned by both /webapi/ping and // /webapi/motd work when no MotD is set -func (s *WebSuite) TestEmptyMotD(c *C) { - ctx := context.Background() +func TestEmptyMotD(t *testing.T) { + t.Parallel() + s := newWebSuite(t) wc := s.client() // Given an auth server configured *not* to expose a Message Of The // Day... // When I issue a ping request... - re, err := wc.Get(ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + require.NoError(t, err) // Expect that the MotD flag in the ping response is *not* set var pingResponse *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &pingResponse), IsNil) - c.Assert(pingResponse.Auth.HasMessageOfTheDay, Equals, false) + require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) + require.False(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(ctx, wc.Endpoint("webapi", "motd"), url.Values{}) - c.Assert(err, IsNil) + re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + require.NoError(t, err) // Expect that an empty response returned var motdResponse *webclient.MotD - c.Assert(json.Unmarshal(re.Bytes(), &motdResponse), IsNil) - c.Assert(motdResponse.Text, Equals, "") + require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) + require.Empty(t, motdResponse.Text) } // TestMotD ensures that a response is returned by both /webapi/ping and /webapi/motd // and that that the response bodies contain their MOTD components -func (s *WebSuite) TestMotD(c *C) { +func TestMotD(t *testing.T) { + t.Parallel() const motd = "Hello. I'm a Teleport cluster!" - ctx := context.Background() + s := newWebSuite(t) wc := s.client() // Given an auth server configured to expose a Message Of The Day... prefs := types.DefaultAuthPreference() prefs.SetMessageOfTheDay(motd) - s.server.AuthServer.AuthServer.SetAuthPreference(ctx, prefs) + require.NoError(t, s.server.AuthServer.AuthServer.SetAuthPreference(s.ctx, prefs)) // When I issue a ping request... - re, err := wc.Get(ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + require.NoError(t, err) // Expect that the MotD flag in the ping response is set to indicate // a MotD var pingResponse *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &pingResponse), IsNil) - c.Assert(pingResponse.Auth.HasMessageOfTheDay, Equals, true) + require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) + require.True(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(ctx, wc.Endpoint("webapi", "motd"), url.Values{}) - c.Assert(err, IsNil) + re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + require.NoError(t, err) // Expect that the text returned is the configured value var motdResponse *webclient.MotD - c.Assert(json.Unmarshal(re.Bytes(), &motdResponse), IsNil) - c.Assert(motdResponse.Text, Equals, motd) + require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) + require.Equal(t, motd, motdResponse.Text) } -func (s *WebSuite) TestMultipleConnectors(c *C) { +func TestMultipleConnectors(t *testing.T) { + t.Parallel() + s := newWebSuite(t) wc := s.client() // create two oidc connectors, one named "foo" and another named "bar" @@ -1690,61 +1743,61 @@ func (s *WebSuite) TestMultipleConnectors(c *C) { }, } o, err := types.NewOIDCConnector("foo", oidcConnectorSpec) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o) - c.Assert(err, IsNil) + require.NoError(t, err) o2, err := types.NewOIDCConnector("bar", oidcConnectorSpec) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o2) - c.Assert(err, IsNil) + require.NoError(t, err) // set the auth preferences to oidc with no connector name authPreference, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) - c.Assert(err, IsNil) + require.NoError(t, err) // hit the ping endpoint to get the auth type and connector name re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var out *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &out), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector name we got back was the first connector // in the backend, in this case it's "bar" oidcConnectors, err := s.server.Auth().GetOIDCConnectors(s.ctx, false) - c.Assert(err, IsNil) - c.Assert(out.Auth.OIDC.Name, Equals, oidcConnectors[0].GetName()) + require.NoError(t, err) + require.Equal(t, oidcConnectors[0].GetName(), out.Auth.OIDC.Name) // update the auth preferences and this time specify the connector name authPreference, err = types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", ConnectorName: "foo", }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) - c.Assert(err, IsNil) + require.NoError(t, err) // hit the ping endpoing to get the auth type and connector name re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(re.Bytes(), &out), IsNil) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector we get back is "foo" - c.Assert(out.Auth.OIDC.Name, Equals, "foo") + require.Equal(t, "foo", out.Auth.OIDC.Name) } // TestConstructSSHResponse checks if the secret package uses AES-GCM to // encrypt and decrypt data that passes through the ConstructSSHResponse // function. -func (s *WebSuite) TestConstructSSHResponse(c *C) { +func TestConstructSSHResponse(t *testing.T) { key, err := secret.NewKey() - c.Assert(err, IsNil) + require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") - c.Assert(err, IsNil) + require.NoError(t, err) query := u.Query() query.Set("secret_key", key.String()) u.RawQuery = query.Encode() @@ -1755,35 +1808,35 @@ func (s *WebSuite) TestConstructSSHResponse(c *C) { TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) - c.Assert(err, IsNil) + require.NoError(t, err) - c.Assert(rawresp.Query().Get("secret"), Equals, "") - c.Assert(rawresp.Query().Get("secret_key"), Equals, "") - c.Assert(rawresp.Query().Get("response"), Not(Equals), "") + require.Empty(t, rawresp.Query().Get("secret")) + require.Empty(t, rawresp.Query().Get("secret_key")) + require.NotEmpty(t, rawresp.Query().Get("response")) plaintext, err := key.Open([]byte(rawresp.Query().Get("response"))) - c.Assert(err, IsNil) + require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) - c.Assert(err, IsNil) - c.Assert(resp.Username, Equals, "foo") - c.Assert(resp.Cert, DeepEquals, []byte{0x00}) - c.Assert(resp.TLSCert, DeepEquals, []byte{0x01}) + require.NoError(t, err) + require.Equal(t, "foo", resp.Username) + require.EqualValues(t, []byte{0x00}, resp.Cert) + require.EqualValues(t, []byte{0x01}, resp.TLSCert) } // TestConstructSSHResponseLegacy checks if the secret package uses NaCl to // encrypt and decrypt data that passes through the ConstructSSHResponse // function. -func (s *WebSuite) TestConstructSSHResponseLegacy(c *C) { +func TestConstructSSHResponseLegacy(t *testing.T) { key, err := lemma_secret.NewKey() - c.Assert(err, IsNil) + require.NoError(t, err) lemma, err := lemma_secret.New(&lemma_secret.Config{KeyBytes: key}) - c.Assert(err, IsNil) + require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") - c.Assert(err, IsNil) + require.NoError(t, err) query := u.Query() query.Set("secret", lemma_secret.KeyToEncodedString(key)) u.RawQuery = query.Encode() @@ -1794,25 +1847,25 @@ func (s *WebSuite) TestConstructSSHResponseLegacy(c *C) { TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) - c.Assert(err, IsNil) + require.NoError(t, err) - c.Assert(rawresp.Query().Get("secret"), Equals, "") - c.Assert(rawresp.Query().Get("secret_key"), Equals, "") - c.Assert(rawresp.Query().Get("response"), Not(Equals), "") + require.Empty(t, rawresp.Query().Get("secret")) + require.Empty(t, rawresp.Query().Get("secret_key")) + require.NotEmpty(t, rawresp.Query().Get("response")) var sealedData *lemma_secret.SealedBytes err = json.Unmarshal([]byte(rawresp.Query().Get("response")), &sealedData) - c.Assert(err, IsNil) + require.NoError(t, err) plaintext, err := lemma.Open(sealedData) - c.Assert(err, IsNil) + require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) - c.Assert(err, IsNil) - c.Assert(resp.Username, Equals, "foo") - c.Assert(resp.Cert, DeepEquals, []byte{0x00}) - c.Assert(resp.TLSCert, DeepEquals, []byte{0x01}) + require.NoError(t, err) + require.Equal(t, "foo", resp.Username) + require.EqualValues(t, []byte{0x00}, resp.Cert) + require.EqualValues(t, []byte{0x01}, resp.TLSCert) } type byTimeAndIndex []apievents.AuditEvent @@ -1835,11 +1888,13 @@ func (f byTimeAndIndex) Swap(i, j int) { } // TestSearchClusterEvents makes sure web API allows querying events by type. -func (s *WebSuite) TestSearchClusterEvents(c *C) { +func TestSearchClusterEvents(t *testing.T) { + t.Parallel() // We need a clock that uses the current time here to work around // the fact that filelog doesn't support emitting past events. clock := clockwork.NewRealClock() + s := newWebSuite(t) sessionEvents := events.GenerateTestSession(events.SessionParams{ PrintEvents: 3, Clock: clock, @@ -1847,7 +1902,7 @@ func (s *WebSuite) TestSearchClusterEvents(c *C) { }) for _, e := range sessionEvents { - c.Assert(s.proxyClient.EmitAuditEvent(context.TODO(), e), IsNil) + require.NoError(t, s.proxyClient.EmitAuditEvent(s.ctx, e)) } sort.Sort(sort.Reverse(byTimeAndIndex(sessionEvents))) @@ -1932,49 +1987,50 @@ func (s *WebSuite) TestSearchClusterEvents(c *C) { }, } - pack := s.authPack(c, "foo") - // var sessionStartKey string + pack := s.authPack(t, "foo") for _, tc := range testCases { - result := s.searchEvents(c, pack.clt, tc.Query, []string{sessionStart.GetType(), sessionPrint.GetType(), sessionEnd.GetType()}) - c.Assert(result.Events, HasLen, len(tc.Result), Commentf(tc.Comment)) - for i, resultEvent := range result.Events { - c.Assert(resultEvent.GetType(), Equals, tc.Result[i].GetType(), Commentf(tc.Comment)) - c.Assert(resultEvent.GetID(), Equals, tc.Result[i].GetID(), Commentf(tc.Comment)) - } + tc := tc + t.Run(tc.Comment, func(t *testing.T) { + t.Parallel() + response, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), tc.Query) + require.NoError(t, err) + var result eventsListGetResponse + require.NoError(t, json.Unmarshal(response.Bytes(), &result)) - // Session prints do not have ID's, only sessionStart and sessionEnd. - // When retrieving events for sessionStart and sessionEnd, sessionStart is returned first. - if tc.TestStartKey { - c.Assert(result.StartKey, Equals, tc.StartKeyValue, Commentf(tc.Comment)) - } - } -} + require.Len(t, result.Events, len(tc.Result)) + for i, resultEvent := range result.Events { + require.Equal(t, tc.Result[i].GetType(), resultEvent.GetType()) + require.Equal(t, tc.Result[i].GetID(), resultEvent.GetID()) + } -func (s *WebSuite) searchEvents(c *C, clt *client.WebClient, query url.Values, filter []string) eventsListGetResponse { - response, err := clt.Get(context.Background(), clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), query) - c.Assert(err, IsNil) - var out eventsListGetResponse - c.Assert(json.Unmarshal(response.Bytes(), &out), IsNil) - return out + // Session prints do not have ID's, only sessionStart and sessionEnd. + // When retrieving events for sessionStart and sessionEnd, sessionStart is returned first. + if tc.TestStartKey { + require.Equal(t, tc.StartKeyValue, result.StartKey) + } + }) + } } -func (s *WebSuite) TestGetClusterDetails(c *C) { +func TestGetClusterDetails(t *testing.T) { + t.Parallel() + s := newWebSuite(t) site, err := s.proxyTunnel.GetSite(s.server.ClusterName()) - c.Assert(err, IsNil) - c.Assert(site, NotNil) + require.NoError(t, err) + require.NotNil(t, site) cluster, err := ui.GetClusterDetails(s.ctx, site) - c.Assert(err, IsNil) - c.Assert(cluster.Name, Equals, s.server.ClusterName()) - c.Assert(cluster.ProxyVersion, Equals, teleport.Version) - c.Assert(cluster.PublicURL, Equals, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort)) - c.Assert(cluster.Status, Equals, teleport.RemoteClusterStatusOnline) - c.Assert(cluster.LastConnected, NotNil) - c.Assert(cluster.AuthVersion, Equals, teleport.Version) + require.NoError(t, err) + require.Equal(t, s.server.ClusterName(), cluster.Name) + require.Equal(t, teleport.Version, cluster.ProxyVersion) + require.Equal(t, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort), cluster.PublicURL) + require.Equal(t, teleport.RemoteClusterStatusOnline, cluster.Status) + require.NotNil(t, cluster.LastConnected) + require.Equal(t, teleport.Version, cluster.AuthVersion) nodes, err := s.proxyClient.GetNodes(s.ctx, apidefaults.Namespace) - c.Assert(err, IsNil) - c.Assert(nodes, HasLen, cluster.NodeCount) + require.NoError(t, err) + require.Len(t, nodes, cluster.NodeCount) } type testModules struct { @@ -2036,7 +2092,9 @@ func TestTokenGeneration(t *testing.T) { } for _, tc := range tt { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() env := newWebPack(t, 1) proxy := env.proxies[0] @@ -2189,7 +2247,7 @@ func TestApplicationAccessDisabled(t *testing.T) { PublicAddr: "panel.example.com", }) require.NoError(t, err) - server, err := types.NewAppServerV3FromApp(app, "host", uuid.New()) + server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String()) require.NoError(t, err) _, err = env.server.Auth().UpsertApplicationServer(context.Background(), server) require.NoError(t, err) @@ -2623,9 +2681,11 @@ func TestCreateRegisterChallenge(t *testing.T) { } // TestCreateAppSession verifies that an existing session to the Web UI can -// be exchanged for a application specific session. -func (s *WebSuite) TestCreateAppSession(c *C) { - pack := s.authPack(c, "foo@example.com") +// be exchanged for an application specific session. +func TestCreateAppSession(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo@example.com") // Register an application called "panel". app, err := types.NewAppV3(types.Metadata{ @@ -2634,130 +2694,134 @@ func (s *WebSuite) TestCreateAppSession(c *C) { URI: "http://127.0.0.1:8080", PublicAddr: "panel.example.com", }) - c.Assert(err, IsNil) - server, err := types.NewAppServerV3FromApp(app, "host", uuid.New()) - c.Assert(err, IsNil) - _, err = s.server.Auth().UpsertApplicationServer(context.Background(), server) - c.Assert(err, IsNil) + require.NoError(t, err) + server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String()) + require.NoError(t, err) + _, err = s.server.Auth().UpsertApplicationServer(s.ctx, server) + require.NoError(t, err) // Extract the session ID and bearer token for the current session. rawCookie := *pack.cookies[0] cookieBytes, err := hex.DecodeString(rawCookie.Value) - c.Assert(err, IsNil) + require.NoError(t, err) var sessionCookie SessionCookie err = json.Unmarshal(cookieBytes, &sessionCookie) - c.Assert(err, IsNil) + require.NoError(t, err) tests := []struct { - inComment CommentInterface + name string inCreateRequest *CreateAppSessionRequest - outError bool + outError require.ErrorAssertionFunc outFQDN string outUsername string }{ { - inComment: Commentf("Valid request: all fields."), + name: "Valid request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Valid request: without FQDN."), + name: "Valid request: without FQDN", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Valid request: only FQDN."), + name: "Valid request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Invalid request: only public address."), + name: "Invalid request: only public address", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid request: only cluster name."), + name: "Invalid request: only cluster name", inCreateRequest: &CreateAppSessionRequest{ ClusterName: "localhost", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid application."), + name: "Invalid application", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "invalid.example.com", ClusterName: "localhost", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid cluster name."), + name: "Invalid cluster name", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "example.com", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Malicious request: all fields."), + name: "Malicious request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Malicious request: only FQDN."), + name: "Malicious request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", }, - outError: true, + outError: require.Error, }, } for _, tt := range tests { - // Make a request to create an application session for "panel". - endpoint := pack.clt.Endpoint("webapi", "sessions", "app") - resp, err := pack.clt.PostJSON(context.Background(), endpoint, tt.inCreateRequest) - c.Assert(err != nil, Equals, tt.outError, tt.inComment) - if tt.outError { - continue - } + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Make a request to create an application session for "panel". + endpoint := pack.clt.Endpoint("webapi", "sessions", "app") + resp, err := pack.clt.PostJSON(s.ctx, endpoint, tt.inCreateRequest) + tt.outError(t, err) + if err != nil { + return + } - // Unmarshal the response. - var response *CreateAppSessionResponse - c.Assert(json.Unmarshal(resp.Bytes(), &response), IsNil, tt.inComment) - c.Assert(response.FQDN, Equals, tt.outFQDN, tt.inComment) + // Unmarshal the response. + var response *CreateAppSessionResponse + require.NoError(t, json.Unmarshal(resp.Bytes(), &response)) + require.Equal(t, tt.outFQDN, response.FQDN) - // Verify that the application session was created. - session, err := s.server.Auth().GetAppSession(context.Background(), types.GetAppSessionRequest{ - SessionID: response.CookieValue, + // Verify that the application session was created. + sess, err := s.server.Auth().GetAppSession(s.ctx, types.GetAppSessionRequest{ + SessionID: response.CookieValue, + }) + require.NoError(t, err) + require.Equal(t, tt.outUsername, sess.GetUser()) + require.Equal(t, response.CookieValue, sess.GetName()) }) - c.Assert(err, IsNil) - c.Assert(session.GetUser(), Equals, tt.outUsername, tt.inComment) - c.Assert(session.GetName(), Equals, response.CookieValue, tt.inComment) } } @@ -2769,7 +2833,7 @@ func TestNewSessionResponseWithRenewSession(t *testing.T) { duration := time.Duration(5) * time.Minute cfg := types.DefaultClusterNetworkingConfig() cfg.SetWebIdleTimeout(duration) - env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg) + require.NoError(t, env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg)) proxy := env.proxies[0] pack := proxy.authPack(t, "foo") @@ -2807,7 +2871,7 @@ func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) { env.clock.Advance(auth.BearerTokenTTL - delta) // Renew the session using the 1st proxy - resp := pack1.renewSession(context.TODO(), t) + resp := pack1.renewSession(context.Background(), t) // Expire the old session and make sure it has been removed. // The bearer token is also removed after this point, so we have to @@ -2816,7 +2880,7 @@ func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) { pack2 = proxy2.authPackFromResponse(t, resp) // Verify that access via the 2nd proxy also works for the same session - pack2.validateAPI(context.TODO(), t) + pack2.validateAPI(context.Background(), t) // Check whether the terminal session is still active validateTerminalStream(t, ws) @@ -2845,12 +2909,12 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // prevSessionCookie := *pack.cookies[0] prevBearerToken := pack.session.Token - resp := pack.renewSession(context.TODO(), t) + resp := pack.renewSession(context.Background(), t) newPack := proxy.authPackFromResponse(t, resp) // new session is functioning - newPack.validateAPI(context.TODO(), t) + newPack.validateAPI(context.Background(), t) sessionCookie := *newPack.cookies[0] bearerToken := newPack.session.Token @@ -2873,7 +2937,7 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // now expire the old session and make sure it has been removed env.clock.Advance(delta) - _, err = proxy.client.GetWebSession(context.TODO(), types.GetWebSessionRequest{ + _, err = proxy.client.GetWebSession(context.Background(), types.GetWebSessionRequest{ User: "foo", SessionID: prevSessionID, }) @@ -2924,7 +2988,7 @@ func TestChangeUserAuthentication_recoveryCodesReturnedForCloud(t *testing.T) { // Creaet a username that is not a valid email format for recovery. teleUser, err := types.NewUser("invalid-name-for-recovery") require.NoError(t, err) - env.server.Auth().CreateUser(ctx, teleUser) + require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser)) // Create a reset password token and secrets. resetToken, err := env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ @@ -2954,7 +3018,7 @@ func TestChangeUserAuthentication_recoveryCodesReturnedForCloud(t *testing.T) { // Create a user that is valid for recovery. teleUser, err = types.NewUser("valid-username@example.com") require.NoError(t, err) - env.server.Auth().CreateUser(ctx, teleUser) + require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser)) // Create a reset password token and secrets. resetToken, err = env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ @@ -2982,18 +3046,6 @@ func TestChangeUserAuthentication_recoveryCodesReturnedForCloud(t *testing.T) { require.NotEmpty(t, re.Recovery.Created) } -type authProviderMock struct { - server types.ServerV2 -} - -func (mock authProviderMock) GetNodes(ctx context.Context, n string, opts ...services.MarshalOption) ([]types.Server, error) { - return []types.Server{&mock.server}, nil -} - -func (mock authProviderMock) GetSessionEvents(n string, s session.ID, c int, p bool) ([]events.EventFields, error) { - return []events.EventFields{}, nil -} - func (s *WebSuite) makeTerminal(pack *authPack, opts ...session.ID) (*websocket.Conn, error) { var sessionID session.ID if len(opts) == 0 { @@ -3067,7 +3119,7 @@ func waitForOutput(stream *terminalStream, substr string) error { } func (s *WebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(context.Background(), timeout) + timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) defer timeoutCancel() done := make(chan error, 1) @@ -3106,7 +3158,7 @@ func (s *WebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) er } func (s *WebSuite) waitForResizeEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(context.Background(), timeout) + timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) defer timeoutCancel() done := make(chan error, 1) @@ -3391,6 +3443,15 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) + proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(proxyNodeWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -3404,6 +3465,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula DirectClusters: []reversetunnel.DirectCluster{{Name: authServer.ClusterName(), Client: client}}, DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, + NodeWatcher: proxyNodeWatcher, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) @@ -3425,6 +3487,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(clock), regular.SetLockWatcher(proxyLockWatcher), + regular.SetNodeWatcher(proxyNodeWatcher), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, proxyServer.Close()) }) @@ -3518,7 +3581,7 @@ func (r *proxy) authPack(t *testing.T, user string) *authPack { err = r.auth.Auth().SetAuthPreference(ctx, ap) require.NoError(t, err) - r.createUser(context.TODO(), t, user, loginUser, pass, otpSecret) + r.createUser(context.Background(), t, user, loginUser, pass, otpSecret) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, r.clock.Now()) @@ -3663,7 +3726,7 @@ func (r *proxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) ws, err := websocket.DialConfig(wscfg) require.NoError(t, err) - t.Cleanup(func() { ws.Close() }) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) return ws } diff --git a/lib/web/files.go b/lib/web/files.go index d603ecea0c6f0..4d568a6dd3fd8 100644 --- a/lib/web/files.go +++ b/lib/web/files.go @@ -21,7 +21,6 @@ import ( "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/sshutils/scp" @@ -57,17 +56,13 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou namespace: defaults.Namespace, } - clt, err := ctx.GetUserClient(site) - if err != nil { - return nil, trace.Wrap(err) - } - ft := fileTransfer{ ctx: ctx, - authClient: clt, + site: site, proxyHostPort: h.ProxyHostPort(), } + var err error isUpload := r.Method == http.MethodPost if isUpload { err = ft.upload(req, r) @@ -85,7 +80,7 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou type fileTransfer struct { // ctx is a web session context for the currently logged in user. ctx *SessionContext - authClient auth.ClientI + site reversetunnel.RemoteSite proxyHostPort string } @@ -145,12 +140,12 @@ func (f *fileTransfer) createClient(req fileTransferRequest, httpReq *http.Reque return nil, trace.BadParameter("missing login") } - servers, err := f.authClient.GetNodes(httpReq.Context(), req.namespace) + watcher, err := f.site.NodeWatcher() if err != nil { return nil, trace.Wrap(err) } - hostName, hostPort, err := resolveServerHostPort(req.server, servers) + hostName, hostPort, err := resolveServerHostPort(req.server, watcher) if err != nil { return nil, trace.BadParameter("invalid server name %q: %v", req.server, err) } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index f217271d77f2b..8720f81106c8e 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -26,6 +26,9 @@ import ( "sync" "time" + "github.com/gogo/protobuf/proto" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "golang.org/x/net/websocket" "golang.org/x/text/encoding" @@ -45,11 +48,6 @@ import ( "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - - "github.com/gravitational/trace" - - "github.com/gogo/protobuf/proto" - "github.com/sirupsen/logrus" ) // TerminalRequest describes a request to create a web-based terminal @@ -83,15 +81,9 @@ type TerminalRequest struct { KeepAliveInterval time.Duration } -// AuthProvider is a subset of the full Auth API. -type AuthProvider interface { - GetNodes(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.Server, error) - GetSessionEvents(namespace string, sid session.ID, after int, includePrintEvents bool) ([]events.EventFields, error) -} - // NewTerminal creates a web-based terminal based on WebSockets and returns a // new TerminalHandler. -func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProvider, sessCtx *SessionContext) (*TerminalHandler, error) { +func NewTerminal(req TerminalRequest, getter NodesGetter, sessCtx *SessionContext) (*TerminalHandler, error) { // Make sure whatever session is requested is a valid session. _, err := session.ParseID(string(req.SessionID)) if err != nil { @@ -105,16 +97,11 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv return nil, trace.BadParameter("term: bad term dimensions") } - servers, err := authProvider.GetNodes(ctx, req.Namespace) - if err != nil { - return nil, trace.Wrap(err) - } - // DELETE IN: 5.0 // // All proxies will support lookup by uuid, so host/port lookup // and fallback can be dropped entirely. - hostName, hostPort, err := resolveServerHostPort(req.Server, servers) + hostName, hostPort, err := resolveServerHostPort(req.Server, getter) if err != nil { return nil, trace.BadParameter("invalid server name %q: %v", req.Server, err) } @@ -123,14 +110,13 @@ func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProv log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentWebsocket, }), - params: req, - ctx: sessCtx, - hostName: hostName, - hostPort: hostPort, - hostUUID: req.Server, - authProvider: authProvider, - encoder: unicode.UTF8.NewEncoder(), - decoder: unicode.UTF8.NewDecoder(), + params: req, + ctx: sessCtx, + hostName: hostName, + hostPort: hostPort, + hostUUID: req.Server, + encoder: unicode.UTF8.NewEncoder(), + decoder: unicode.UTF8.NewDecoder(), }, nil } @@ -164,9 +150,6 @@ type TerminalHandler struct { // terminalCancel is used to signal when the terminal session is closing. terminalCancel context.CancelFunc - // authProvider is used to fetch nodes and sessions from the backend. - authProvider AuthProvider - // encoder is used to encode strings into UTF-8. encoder *encoding.Encoder @@ -599,9 +582,15 @@ func (t *TerminalHandler) writeError(err error, ws *websocket.Conn) error { return nil } +// NodesGetter is a function that retrieves a subset of nodes matching +// the filter criteria. +type NodesGetter interface { + GetNodes(fn func(n services.Node) bool) []types.Server +} + // resolveServerHostPort parses server name and attempts to resolve hostname // and port. -func resolveServerHostPort(servername string, existingServers []types.Server) (string, int, error) { +func resolveServerHostPort(servername string, getter NodesGetter) (string, int, error) { // If port is 0, client wants us to figure out which port to use. var defaultPort = 0 @@ -609,12 +598,20 @@ func resolveServerHostPort(servername string, existingServers []types.Server) (s return "", defaultPort, trace.BadParameter("empty server name") } + var hostname string // Check if servername is UUID. - for i := range existingServers { - node := existingServers[i] - if node.GetName() == servername { - return node.GetHostname(), defaultPort, nil + getter.GetNodes(func(n services.Node) bool { + if hostname != "" { + return false + } + if n.GetName() == servername { + hostname = n.GetHostname() } + return false + }) + + if hostname != "" { + return hostname, defaultPort, nil } if !strings.Contains(servername, ":") { diff --git a/lib/web/ui/cluster.go b/lib/web/ui/cluster.go index c216243272862..91982c589867d 100644 --- a/lib/web/ui/cluster.go +++ b/lib/web/ui/cluster.go @@ -21,7 +21,6 @@ import ( "sort" "time" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" @@ -92,7 +91,7 @@ func GetClusterDetails(ctx context.Context, site reversetunnel.RemoteSite, opts return nil, trace.Wrap(err) } - nodes, err := clt.GetNodes(ctx, apidefaults.Namespace, opts...) + watcher, err := site.NodeWatcher() if err != nil { return nil, trace.Wrap(err) } @@ -121,7 +120,7 @@ func GetClusterDetails(ctx context.Context, site reversetunnel.RemoteSite, opts Name: site.GetName(), LastConnected: site.GetLastConnected(), Status: site.GetStatus(), - NodeCount: len(nodes), + NodeCount: watcher.NodeCount(), PublicURL: proxyHost, AuthVersion: authVersion, ProxyVersion: proxyVersion, diff --git a/tool/tsh/proxy_test.go b/tool/tsh/proxy_test.go index 5fbc857cb6298..a6b7562335334 100644 --- a/tool/tsh/proxy_test.go +++ b/tool/tsh/proxy_test.go @@ -51,10 +51,8 @@ func TestTSHSSH(t *testing.T) { lib.SetInsecureDevMode(true) defer lib.SetInsecureDevMode(false) - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + require.NoError(t, os.RemoveAll(profile.FullProfilePath(""))) + t.Cleanup(func() { require.NoError(t, os.RemoveAll(profile.FullProfilePath(""))) }) s := newTestSuite(t, withRootConfigFunc(func(cfg *service.Config) { @@ -142,12 +140,14 @@ func testLeafClusterSSHAccess(t *testing.T, s *suite) { }) require.NoError(t, err) - err = Run([]string{ - "ssh", - s.leaf.Config.Hostname, - "echo", "hello", - }) - require.NoError(t, err) + require.Eventually(t, func() bool { + err = Run([]string{ + "ssh", + s.leaf.Config.Hostname, + "echo", "hello", + }) + return err == nil + }, 5*time.Second, time.Second) identityFile := path.Join(t.TempDir(), "identity.pem") err = Run([]string{ diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index b396491ab5339..23620278aec94 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -19,6 +19,7 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -294,6 +295,14 @@ func (c *CLIConf) Stderr() io.Writer { return os.Stderr } +type exitCodeError struct { + code int +} + +func (e *exitCodeError) Error() string { + return fmt.Sprintf("exit code %d", e.code) +} + func main() { cmdLineOrig := os.Args[1:] var cmdLine []string @@ -309,6 +318,10 @@ func main() { cmdLine = cmdLineOrig } if err := Run(cmdLine); err != nil { + var exitError *exitCodeError + if errors.As(err, &exitError) { + os.Exit(exitError.code) + } utils.FatalError(err) } } @@ -1174,7 +1187,7 @@ func onLogout(cf *CLIConf) error { if err != nil { if trace.IsNotFound(err) { fmt.Printf("User %v already logged out from %v.\n", cf.Username, proxyHost) - os.Exit(1) + return trace.Wrap(&exitCodeError{code: 1}) } return trace.Wrap(err) } @@ -1734,15 +1747,14 @@ func onSSH(cf *CLIConf) error { fmt.Fprintf(os.Stderr, "Hint: try addressing the node by unique id (ex: tsh ssh user@node-id)\n") fmt.Fprintf(os.Stderr, "Hint: use 'tsh ls -v' to list all nodes with their unique ids\n") fmt.Fprintf(os.Stderr, "\n") - os.Exit(1) + return trace.Wrap(&exitCodeError{code: 1}) } // exit with the same exit status as the failed command: if tc.ExitStatus != 0 { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(tc.ExitStatus) - } else { - return trace.Wrap(err) + return trace.Wrap(&exitCodeError{code: tc.ExitStatus}) } + return trace.Wrap(err) } return nil } @@ -1761,7 +1773,7 @@ func onBenchmark(cf *CLIConf) error { result, err := cnf.Benchmark(cf.Context, tc) if err != nil { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(255) + return trace.Wrap(&exitCodeError{code: 255}) } fmt.Printf("\n") fmt.Printf("* Requests originated: %v\n", result.RequestsOriginated) @@ -1829,7 +1841,7 @@ func onSCP(cf *CLIConf) error { // exit with the same exit status as the failed command: if tc.ExitStatus != 0 { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(tc.ExitStatus) + return trace.Wrap(&exitCodeError{code: tc.ExitStatus}) } return trace.Wrap(err) }