diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 38da060501603..9a0cd8196a22c 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" ) @@ -95,6 +96,8 @@ type RemoteSite interface { // CachingAccessPoint returns access point that is lightweight // but is resilient to auth server crashes CachingAccessPoint() (auth.AccessPoint, error) + // NodeWatcher returns the node watcher that maintains the node set for the site + NodeWatcher() (*services.NodeWatcher, error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 3040d6ffd016e..9500f4ef62e23 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -22,8 +22,6 @@ import ( "sync" "time" - "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -34,11 +32,12 @@ import ( "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/proxy" - "github.com/prometheus/client_golang/prometheus" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" ) func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) { @@ -129,6 +128,11 @@ func (s *localSite) CachingAccessPoint() (auth.AccessPoint, error) { return s.accessPoint, nil } +// NodeWatcher returns a services.NodeWatcher for this cluster. +func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) { + return s.srv.NodeWatcher, nil +} + // GetClient returns a client to the full Auth Server API. func (s *localSite) GetClient() (auth.ClientI, error) { return s.client, nil @@ -518,14 +522,7 @@ func (s *localSite) periodicFunctions() { // sshTunnelStats reports SSH tunnel statistics for the cluster. func (s *localSite) sshTunnelStats() error { - servers, err := s.accessPoint.GetNodes(s.srv.ctx, apidefaults.Namespace) - if err != nil { - return trace.Wrap(err) - } - - var missing []string - - for _, server := range servers { + missing := s.srv.NodeWatcher.GetNodes(func(server services.Node) bool { // Skip over any servers that that have a TTL larger than announce TTL (10 // minutes) and are non-IoT SSH servers (they won't have tunnels). // @@ -534,10 +531,10 @@ func (s *localSite) sshTunnelStats() error { // their TTL value. ttl := s.clock.Now().Add(-1 * apidefaults.ServerAnnounceTTL) if server.Expiry().Before(ttl) { - continue + return false } if !server.GetUseTunnel() { - continue + return false } // Check if the tunnel actually exists. @@ -545,12 +542,9 @@ func (s *localSite) sshTunnelStats() error { ServerID: fmt.Sprintf("%v.%v", server.GetName(), s.domainName), ConnType: types.NodeTunnel, }) - if err == nil { - continue - } - missing = append(missing, server.GetName()) - } + return err != nil + }) // Update Prometheus metrics and also log if any tunnels are missing. missingSSHTunnels.Set(float64(len(missing))) diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 35f803a689cfd..4716035f98962 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -87,6 +87,14 @@ func (p *clusterPeers) CachingAccessPoint() (auth.AccessPoint, error) { return peer.CachingAccessPoint() } +func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) { + peer, err := p.pickPeer() + if err != nil { + return nil, trace.Wrap(err) + } + return peer.NodeWatcher() +} + func (p *clusterPeers) GetClient() (auth.ClientI, error) { peer, err := p.pickPeer() if err != nil { @@ -191,6 +199,10 @@ func (s *clusterPeer) CachingAccessPoint() (auth.AccessPoint, error) { return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) } +func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) { + return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) +} + func (s *clusterPeer) GetClient() (auth.ClientI, error) { return nil, trace.ConnectionProblem(nil, "unable to fetch client, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 4c6df631085f7..cc635c9586f9b 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -33,13 +33,14 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) -// remoteSite is a remote site that established the inbound connecton to +// remoteSite is a remote site that established the inbound connection to // the local reverse tunnel server, and now it can provide access to the // cluster behind it. type remoteSite struct { @@ -76,6 +77,9 @@ type remoteSite struct { // the remote cluster this site belongs to. remoteAccessPoint auth.AccessPoint + // nodeWatcher provides access the node set for the remote site + nodeWatcher *services.NodeWatcher + // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation // state has been changed, the tunnel will reconnect to re-create the client @@ -137,6 +141,11 @@ func (s *remoteSite) CachingAccessPoint() (auth.AccessPoint, error) { return s.remoteAccessPoint, nil } +// NodeWatcher returns the services.NodeWatcher for the remote cluster. +func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) { + return s.nodeWatcher, nil +} + func (s *remoteSite) GetClient() (auth.ClientI, error) { return s.remoteClient, nil } @@ -379,7 +388,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch } else { s.WithFields(log.Fields{"nodeID": conn.nodeID}).Debugf("Ping <- %v", conn.conn.RemoteAddr()) } - tm := time.Now().UTC() + tm := s.clock.Now().UTC() conn.setLastHeartbeat(tm) go s.registerHeartbeat(tm) // Note that time.After is re-created everytime a request is processed. diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 0aba941727cd7..af89c620caf0c 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -202,6 +202,9 @@ type Config struct { // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher + + // NodeWatcher is a node watcher. + NodeWatcher *services.NodeWatcher } // CheckAndSetDefaults checks parameters and sets default values @@ -253,6 +256,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.LockWatcher == nil { return trace.BadParameter("missing parameter LockWatcher") } + if cfg.NodeWatcher == nil { + return trace.BadParameter("missing parameter NodeWatcher") + } return nil } @@ -889,7 +895,7 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r // treat first connection as a registered heartbeat, // otherwise the connection information will appear after initial // heartbeat delay - go site.registerHeartbeat(time.Now()) + go site.registerHeartbeat(s.Clock.Now()) return site, remoteConn, nil } @@ -1022,7 +1028,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, types.TunnelConnectionSpecV2{ ClusterName: domainName, ProxyName: srv.ID, - LastHeartbeat: time.Now().UTC(), + LastHeartbeat: srv.Clock.Now().UTC(), }, ) if err != nil { @@ -1054,27 +1060,42 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, clt, _, err := remoteSite.getRemoteClient() if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.remoteClient = clt remoteVersion, err := getRemoteAuthVersion(closeContext, sconn) if err != nil { + cancel() return nil, trace.Wrap(err) } accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName) if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.remoteAccessPoint = accessPoint - + nodeWatcher, err := services.NewNodeWatcher(closeContext, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: srv.Component, + Client: accessPoint, + Log: srv.Log, + }, + }) + if err != nil { + cancel() + return nil, trace.Wrap(err) + } + remoteSite.nodeWatcher = nodeWatcher // instantiate a cache of host certificates for the forwarding server. the // certificate cache is created in each site (instead of creating it in // reversetunnel.server and passing it along) so that the host certificate // is signed by the correct certificate authority. certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient) if err != nil { + cancel() return nil, trace.Wrap(err) } remoteSite.certificateCache = certificateCache @@ -1087,7 +1108,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - return nil, err + cancel() + return nil, trace.Wrap(err) } go remoteSite.updateCertAuthorities(caRetry) @@ -1100,7 +1122,8 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, Clock: srv.Clock, }) if err != nil { - return nil, err + cancel() + return nil, trace.Wrap(err) } go remoteSite.updateLocks(lockRetry) diff --git a/lib/service/service.go b/lib/service/service.go index 0f5b899e2b0d0..9173e4ae55dbc 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2487,6 +2487,17 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + nodeWatcher, err := services.NewNodeWatcher(process.ExitContext(), services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Log: process.log.WithField(trace.Component, teleport.ComponentProxy), + Client: conn.Client, + }, + }) + if err != nil { + return trace.Wrap(err) + } + // register SSH reverse tunnel server that accepts connections // from remote teleport nodes var tsrv reversetunnel.Server @@ -2520,6 +2531,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Emitter: streamEmitter, Log: process.log, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) if err != nil { return trace.Wrap(err) @@ -2750,6 +2762,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { }), regular.SetEmitter(streamEmitter), regular.SetLockWatcher(lockWatcher), + regular.SetNodeWatcher(nodeWatcher), ) if err != nil { return trace.Wrap(err) diff --git a/lib/services/presence.go b/lib/services/presence.go index aec26906d1dc9..9bef879871706 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -29,6 +29,12 @@ type ProxyGetter interface { GetProxies() ([]types.Server, error) } +// NodesGetter is a service that gets nodes. +type NodesGetter interface { + // GetNodes returns a list of registered servers. + GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error) +} + // Presence records and reports the presence of all components // of the cluster - Nodes, Proxies and SSH nodes type Presence interface { @@ -43,13 +49,12 @@ type Presence interface { // GetNode returns a node by name and namespace. GetNode(ctx context.Context, namespace, name string) (types.Server, error) - - // GetNodes returns a list of registered servers. - GetNodes(ctx context.Context, namespace string, opts ...MarshalOption) ([]types.Server, error) - // ListNodes returns a paginated list of registered servers. ListNodes(ctx context.Context, req proto.ListNodesRequest) (nodes []types.Server, nextKey string, err error) + // NodesGetter gets nodes + NodesGetter + // DeleteAllNodes deletes all nodes in a namespace. DeleteAllNodes(ctx context.Context, namespace string) error diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 3d11524cc0fb4..73e083f2f1360 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -22,6 +22,7 @@ import ( "time" "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" @@ -819,6 +820,60 @@ func (c *caCollector) GetCurrent() []types.CertAuthority { func (c *caCollector) notifyStale() {} +// NodeWatcherConfig is a NodeWatcher configuration. +type NodeWatcherConfig struct { + ResourceWatcherConfig + // NodesGetter is used to directly fetch the list of active nodes. + NodesGetter +} + +// CheckAndSetDefaults checks parameters and sets default values. +func (cfg *NodeWatcherConfig) CheckAndSetDefaults() error { + if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if cfg.NodesGetter == nil { + getter, ok := cfg.Client.(NodesGetter) + if !ok { + return trace.BadParameter("missing parameter NodesGetter and Client not usable as NodesGetter") + } + cfg.NodesGetter = getter + } + return nil +} + +// NewNodeWatcher returns a new instance of NodeWatcher. +func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*NodeWatcher, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + collector := &nodeCollector{ + NodeWatcherConfig: cfg, + current: map[string]types.Server{}, + } + watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) + if err != nil { + return nil, trace.Wrap(err) + } + return &NodeWatcher{watcher, collector}, nil +} + +// NodeWatcher is built on top of resourceWatcher to monitor additions +// and deletions to the set of nodes. +type NodeWatcher struct { + *resourceWatcher + *nodeCollector +} + +// nodeCollector accompanies resourceWatcher when monitoring nodes. +type nodeCollector struct { + NodeWatcherConfig + // current holds a map of the currently known nodes (keyed by server name, + // RWMutex protected). + current map[string]types.Server + rw sync.RWMutex +} + func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertAuthority) []types.CertAuthority { slice := make([]types.CertAuthority, 0, len(host)+len(user)) for _, ca := range host { @@ -829,3 +884,106 @@ func casToSlice(host map[string]types.CertAuthority, user map[string]types.CertA } return slice } + +// Node is a readonly subset of the types.Server interface which +// users may filter by in GetNodes. +type Node interface { + // Resource provides common resource headers + types.Resource + // GetTeleportVersion returns the teleport version the server is running on + GetTeleportVersion() string + // GetAddr return server address + GetAddr() string + // GetHostname returns server hostname + GetHostname() string + // GetNamespace returns server namespace + GetNamespace() string + // GetLabels returns server's static label key pairs + GetLabels() map[string]string + // GetCmdLabels gets command labels + GetCmdLabels() map[string]types.CommandLabel + // GetPublicAddr is an optional field that returns the public address this cluster can be reached at. + GetPublicAddr() string + // GetRotation gets the state of certificate authority rotation. + GetRotation() types.Rotation + // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. + GetUseTunnel() bool + // GetAllLabels returns all resource's labels. + GetAllLabels() map[string]string +} + +// GetNodes allows callers to retrieve a subset of nodes that match the filter provided. The +// returned servers are a copy and can be safely modified. It is intentionally hard to retrieve +// the full set of nodes to reduce the number of copies needed since the number of nodes can get +// quite large and doing so can be expensive. +func (n *nodeCollector) GetNodes(fn func(n Node) bool) []types.Server { + n.rw.RLock() + defer n.rw.RUnlock() + + var matched []types.Server + for _, server := range n.current { + if fn(server) { + matched = append(matched, server.DeepCopy()) + } + } + + return matched +} + +func (n *nodeCollector) NodeCount() int { + n.rw.RLock() + defer n.rw.RUnlock() + return len(n.current) +} + +// resourceKind specifies the resource kind to watch. +func (n *nodeCollector) resourceKind() string { + return types.KindNode +} + +// getResourcesAndUpdateCurrent is called when the resources should be +// (re-)fetched directly. +func (n *nodeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { + nodes, err := n.NodesGetter.GetNodes(ctx, apidefaults.Namespace) + if err != nil { + return trace.Wrap(err) + } + if len(nodes) == 0 { + return nil + } + newCurrent := make(map[string]types.Server, len(nodes)) + for _, node := range nodes { + newCurrent[node.GetName()] = node + } + n.rw.Lock() + defer n.rw.Unlock() + n.current = newCurrent + return nil +} + +// processEventAndUpdateCurrent is called when a watcher event is received. +func (n *nodeCollector) processEventAndUpdateCurrent(ctx context.Context, event types.Event) { + if event.Resource == nil || event.Resource.GetKind() != types.KindNode { + n.Log.Warningf("Unexpected event: %v.", event) + return + } + + n.rw.Lock() + defer n.rw.Unlock() + + switch event.Type { + case types.OpDelete: + delete(n.current, event.Resource.GetName()) + case types.OpPut: + server, ok := event.Resource.(types.Server) + if !ok { + n.Log.Warningf("Unexpected type %T.", event.Resource) + return + } + n.current[server.GetName()] = server + default: + n.Log.Warningf("Skipping unsupported event type %s.", event.Type) + } +} + +func (n *nodeCollector) notifyStale() {} diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index e401426f29b37..3bfecf88c67e8 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -20,13 +20,19 @@ import ( "context" "crypto/x509/pkix" "errors" + "fmt" "sync" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/backend/lite" @@ -34,9 +40,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" ) var _ types.Events = (*errorWatcher)(nil) @@ -657,3 +660,72 @@ func newCertAuthority(t *testing.T, name string, caType types.CertAuthType) type require.NoError(t, err) return ca } + +func TestNodeWatcher(t *testing.T) { + t.Parallel() + ctx := context.Background() + + bk, err := lite.NewWithConfig(ctx, lite.Config{ + Path: t.TempDir(), + PollStreamPeriod: 200 * time.Millisecond, + }) + require.NoError(t, err) + + type client struct { + services.Presence + types.Events + } + + presence := local.NewPresenceService(bk) + w, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: &client{ + Presence: presence, + Events: local.NewEventsService(bk, nil), + }, + }, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + + // Add some node servers. + nodes := make([]types.Server, 0, 5) + for i := 0; i < 5; i++ { + node := newNodeServer(t, fmt.Sprintf("node%d", i), "127.0.0.1:2023", i%2 == 0) + _, err = presence.UpsertNode(ctx, node) + require.NoError(t, err) + nodes = append(nodes, node) + } + + require.Eventually(t, func() bool { + filtered := w.GetNodes(func(n services.Node) bool { + return true + }) + return len(filtered) == len(nodes) + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + require.Len(t, w.GetNodes(func(n services.Node) bool { return n.GetUseTunnel() }), 3) + + require.NoError(t, presence.DeleteNode(ctx, apidefaults.Namespace, nodes[0].GetName())) + + require.Eventually(t, func() bool { + filtered := w.GetNodes(func(n services.Node) bool { + return true + }) + return len(filtered) == len(nodes)-1 + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + require.Empty(t, w.GetNodes(func(n services.Node) bool { return n.GetName() == nodes[0].GetName() })) + +} + +func newNodeServer(t *testing.T, name, addr string, tunnel bool) types.Server { + s, err := types.NewServer(name, types.KindNode, types.ServerSpecV2{ + Addr: addr, + PublicAddr: addr, + UseTunnel: tunnel, + }) + require.NoError(t, err) + return s +} diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 31f8f3503a309..d408c9813ddc4 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -34,6 +34,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -314,18 +315,15 @@ func (t *proxySubsys) proxyToHost( // network resolution (by IP or DNS) // var ( - strategy types.RoutingStrategy - servers []types.Server - err error + strategy types.RoutingStrategy + nodeWatcher *services.NodeWatcher + err error ) localCluster, _ := t.srv.authService.GetClusterName() // going to "local" CA? let's use the caching 'auth service' directly and avoid // hitting the reverse tunnel link (it can be offline if the CA is down) if site.GetName() == localCluster.GetName() { - servers, err = t.srv.authService.GetNodes(ctx.CancelContext(), t.namespace) - if err != nil { - t.log.Warn(err) - } + nodeWatcher = t.srv.nodeWatcher cfg, err := t.srv.authService.GetClusterNetworkingConfig(ctx.CancelContext()) if err != nil { @@ -339,9 +337,11 @@ func (t *proxySubsys) proxyToHost( if err != nil { t.log.Warn(err) } else { - servers, err = siteClient.GetNodes(ctx.CancelContext(), t.namespace) + watcher, err := site.NodeWatcher() if err != nil { t.log.Warn(err) + } else { + nodeWatcher = watcher } cfg, err := siteClient.GetClusterNetworkingConfig(ctx.CancelContext()) @@ -358,7 +358,7 @@ func (t *proxySubsys) proxyToHost( t.log.Debugf("proxy connecting to host=%v port=%v, exact port=%v, strategy=%s", t.host, t.port, t.SpecifiedPort(), strategy) // determine which server to connect to - server, err := t.getMatchingServer(servers, strategy) + server, err := t.getMatchingServer(nodeWatcher, strategy) if err != nil { return trace.Wrap(err) } @@ -444,47 +444,58 @@ func (t *proxySubsys) proxyToHost( return nil } +// NodesGetter is a function that retrieves a subset of nodes matching +// the filter criteria. +type NodesGetter interface { + GetNodes(fn func(n services.Node) bool) []types.Server +} + // getMatchingServer determines the server to connect to from the provided servers. Duplicate entries are treated // differently based on strategy. Legacy behavior of returning an ambiguous error occurs if the strategy // is types.RoutingStrategy_UNAMBIGUOUS_MATCH. When the strategy is types.RoutingStrategy_MOST_RECENT then // the server that has heartbeated most recently will be returned instead of an error. If no matches are found then // both the types.Server and error returned will be nil. -func (t *proxySubsys) getMatchingServer(servers []types.Server, strategy types.RoutingStrategy) (types.Server, error) { +func (t *proxySubsys) getMatchingServer(watcher NodesGetter, strategy types.RoutingStrategy) (types.Server, error) { + if watcher == nil { + return nil, trace.NotFound("unable to retrieve nodes matching host %s", t.host) + } + // check if hostname is a valid uuid. If it is, we will preferentially match // by node ID over node hostname. hostIsUUID := uuid.Parse(t.host) != nil ips, _ := net.LookupHost(t.host) + var unambiguousIDMatch bool // enumerate and try to find a server with self-registered with a matching name/IP: - var matches []types.Server - for _, server := range servers { - // If the host parameter is a UUID, and it matches the Node ID, - // treat this as an unambiguous match. + matches := watcher.GetNodes(func(server services.Node) bool { + if unambiguousIDMatch { + return false + } + + // If the host parameter is a UUID or EC2 node ID, and it matches the + // Node ID, treat this as an unambiguous match. if hostIsUUID && server.GetName() == t.host { - matches = []types.Server{server} - break + unambiguousIDMatch = true + return true } // If the server has connected over a reverse tunnel, match only on hostname. if server.GetUseTunnel() { - if t.host == server.GetHostname() { - matches = append(matches, server) - } - continue + return t.host == server.GetHostname() } ip, port, err := net.SplitHostPort(server.GetAddr()) if err != nil { t.log.Errorf("Failed to parse address %q: %v.", server.GetAddr(), err) - continue + return false } if t.host == ip || t.host == server.GetHostname() || apiutils.SliceContainsStr(ips, ip) { if !t.SpecifiedPort() || t.port == port { - matches = append(matches, server) - continue + return true } } - } + return false + }) var server types.Server switch { diff --git a/lib/srv/regular/proxy_test.go b/lib/srv/regular/proxy_test.go index 15b3667ed7769..aca83cf880318 100644 --- a/lib/srv/regular/proxy_test.go +++ b/lib/srv/regular/proxy_test.go @@ -21,10 +21,12 @@ import ( "time" "github.com/google/uuid" + "github.com/stretchr/testify/require" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" - "github.com/stretchr/testify/require" ) func TestParseProxyRequest(t *testing.T) { @@ -129,6 +131,21 @@ func TestParseBadRequests(t *testing.T) { } } +type nodeGetter struct { + servers []types.Server +} + +func (n nodeGetter) GetNodes(fn func(n services.Node) bool) []types.Server { + var servers []types.Server + for _, s := range n.servers { + if fn(s) { + servers = append(servers, s) + } + } + + return servers +} + func TestProxySubsys_getMatchingServer(t *testing.T) { t.Parallel() @@ -295,7 +312,7 @@ func TestProxySubsys_getMatchingServer(t *testing.T) { srv: &Server{}, } - server, err := subsystem.getMatchingServer(tt.servers, tt.strategy) + server, err := subsystem.getMatchingServer(nodeGetter{tt.servers}, tt.strategy) tt.expectError(t, err) if tt.expectServer != nil { require.Equal(t, tt.expectServer(tt.servers), server) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 488d09f7c7958..1aafd69b52aed 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -183,6 +183,9 @@ type Server struct { // lockWatcher is the server's lock watcher. lockWatcher *services.LockWatcher + + // nodeWatcher is the server's node watcher. + nodeWatcher *services.NodeWatcher } // GetClock returns server clock implementation @@ -549,6 +552,14 @@ func SetLockWatcher(lockWatcher *services.LockWatcher) ServerOption { } } +// SetNodeWatcher sets the server's node watcher. +func SetNodeWatcher(nodeWatcher *services.NodeWatcher) ServerOption { + return func(s *Server) error { + s.nodeWatcher = nodeWatcher + return nil + } +} + // New returns an unstarted server func New(addr utils.NetAddr, hostname string, diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index e0bc6d3b5a6a6..c96dea7a41753 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -948,6 +948,8 @@ func TestProxyReverseTunnel(t *testing.T) { defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) + nodeWatcher := newNodeWatcher(ctx, t, proxyClient) + reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), ID: hostID, @@ -964,6 +966,7 @@ func TestProxyReverseTunnel(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) require.NoError(t, err) require.NoError(t, reverseTunnelServer.Start()) @@ -1137,6 +1140,7 @@ func TestProxyRoundRobin(t *testing.T) { listener, reverseTunnelAddress := mustListen(t) defer listener.Close() lockWatcher := newLockWatcher(ctx, t, proxyClient) + nodeWatcher := newNodeWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClusterName: f.testSrv.ClusterName(), @@ -1153,6 +1157,7 @@ func TestProxyRoundRobin(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) require.NoError(t, err) logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.") @@ -1177,6 +1182,7 @@ func TestProxyRoundRobin(t *testing.T) { SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), SetLockWatcher(lockWatcher), + SetNodeWatcher(nodeWatcher), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1259,6 +1265,7 @@ func TestProxyDirectAccess(t *testing.T) { logger := logrus.WithField("test", "TestProxyDirectAccess") proxyClient, _ := newProxyClient(t, f.testSrv) lockWatcher := newLockWatcher(ctx, t, proxyClient) + nodeWatcher := newNodeWatcher(ctx, t, proxyClient) reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ ClientTLS: proxyClient.TLSConfig(), @@ -1275,6 +1282,7 @@ func TestProxyDirectAccess(t *testing.T) { Emitter: proxyClient, Log: logger, LockWatcher: lockWatcher, + NodeWatcher: nodeWatcher, }) require.NoError(t, err) @@ -1300,6 +1308,7 @@ func TestProxyDirectAccess(t *testing.T) { SetRestrictedSessionManager(&restricted.NOP{}), SetClock(f.clock), SetLockWatcher(lockWatcher), + SetNodeWatcher(nodeWatcher), ) require.NoError(t, err) require.NoError(t, proxy.Start()) @@ -1967,6 +1976,18 @@ func newLockWatcher(ctx context.Context, t *testing.T, client types.Events) *ser return lockWatcher } +func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *services.NodeWatcher { + nodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(nodeWatcher.Close) + return nodeWatcher +} + // maxPipeSize is one larger than the maximum pipe size for most operating // systems which appears to be 65536 bytes. // diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 8abe5ff3256d2..dbc0b5d958b63 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -27,7 +27,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/cookiejar" @@ -41,9 +40,21 @@ import ( "testing" "time" + "github.com/beevik/etree" + "github.com/gogo/protobuf/proto" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/gravitational/roundtrip" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + lemma_secret "github.com/mailgun/lemma/secret" + "github.com/pquerna/otp/totp" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "golang.org/x/net/websocket" "golang.org/x/text/encoding/unicode" + kyaml "k8s.io/apimachinery/pkg/util/yaml" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/webclient" @@ -74,29 +85,10 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/ui" - - "github.com/gravitational/roundtrip" - "github.com/gravitational/trace" - - "github.com/beevik/etree" - "github.com/gogo/protobuf/proto" - "github.com/google/go-cmp/cmp" - "github.com/jonboulle/clockwork" - lemma_secret "github.com/mailgun/lemma/secret" - "github.com/pborman/uuid" - "github.com/pquerna/otp/totp" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" - . "gopkg.in/check.v1" - kyaml "k8s.io/apimachinery/pkg/util/yaml" ) const hostID = "00000000-0000-0000-0000-000000000000" -func TestWeb(t *testing.T) { - TestingT(t) -} - type WebSuite struct { ctx context.Context cancel context.CancelFunc @@ -115,8 +107,6 @@ type WebSuite struct { clock clockwork.FakeClock } -var _ = Suite(&WebSuite{}) - // TestMain will re-execute Teleport to run a command if "exec" is passed to // it as an argument. Otherwise it will run tests as normal. func TestMain(m *testing.M) { @@ -134,31 +124,37 @@ func TestMain(m *testing.M) { os.Exit(code) } -func (s *WebSuite) SetUpSuite(c *C) { - os.Unsetenv(teleport.DebugEnvVar) +func newWebSuite(t *testing.T) *WebSuite { + mockU2F, err := mocku2f.Create() + require.NoError(t, err) + require.NotNil(t, mockU2F) - var err error - s.mockU2F, err = mocku2f.Create() - c.Assert(err, IsNil) - c.Assert(s.mockU2F, NotNil) -} + u, err := user.Current() + require.NoError(t, err) -func (s *WebSuite) SetUpTest(c *C) { - s.ctx, s.cancel = context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + s := &WebSuite{ + mockU2F: mockU2F, + clock: clockwork.NewFakeClock(), + user: u.Username, + ctx: ctx, + cancel: cancel, + } - u, err := user.Current() - c.Assert(err, IsNil) - s.user = u.Username - s.clock = clockwork.NewFakeClock() + networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ + KeepAliveInterval: types.Duration(10 * time.Second), + }) + require.NoError(t, err) s.server, err = auth.NewTestServer(auth.TestServerConfig{ Auth: auth.TestAuthServerConfig{ - ClusterName: "localhost", - Dir: c.MkDir(), - Clock: s.clock, + ClusterName: "localhost", + Dir: t.TempDir(), + Clock: s.clock, + ClusterNetworkingConfig: networkingConfig, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) // Register the auth server, since test auth server doesn't start its own // heartbeat. @@ -175,7 +171,7 @@ func (s *WebSuite) SetUpTest(c *C) { Version: teleport.Version, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) // start node certs, err := s.server.Auth().GenerateServerKeys(auth.GenerateServerKeysRequest{ @@ -183,10 +179,10 @@ func (s *WebSuite) SetUpTest(c *C) { NodeName: s.server.ClusterName(), Roles: types.SystemRoles{types.RoleNode}, }) - c.Assert(err, IsNil) + require.NoError(t, err) signer, err := sshutils.NewSigner(certs.Key, certs.Cert) - c.Assert(err, IsNil) + require.NoError(t, err) nodeID := "node" nodeClient, err := s.server.NewClient(auth.TestIdentity{ @@ -195,7 +191,8 @@ func (s *WebSuite) SetUpTest(c *C) { Username: nodeID, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) + require.NoError(t, err) nodeLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -203,10 +200,10 @@ func (s *WebSuite) SetUpTest(c *C) { Client: nodeClient, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) // create SSH service: - nodeDataDir := c.MkDir() + nodeDataDir := t.TempDir() node, err := regular.New( utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), @@ -226,12 +223,11 @@ func (s *WebSuite) SetUpTest(c *C) { regular.SetClock(s.clock), regular.SetLockWatcher(nodeLockWatcher), ) - c.Assert(err, IsNil) + require.NoError(t, err) s.node = node s.srvID = node.ID() - c.Assert(s.node.Start(), IsNil) - - c.Assert(auth.CreateUploaderDir(nodeDataDir), IsNil) + require.NoError(t, s.node.Start()) + require.NoError(t, auth.CreateUploaderDir(nodeDataDir)) // create reverse tunnel service: proxyID := "proxy" @@ -241,10 +237,10 @@ func (s *WebSuite) SetUpTest(c *C) { Username: proxyID, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) revTunListener, err := net.Listen("tcp", fmt.Sprintf("%v:0", s.server.ClusterName())) - c.Assert(err, IsNil) + require.NoError(t, err) proxyLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -252,7 +248,15 @@ func (s *WebSuite) SetUpTest(c *C) { Client: s.proxyClient, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) + + proxyNodeWatcher, err := services.NewNodeWatcher(s.ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + }) + require.NoError(t, err) revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), @@ -265,10 +269,11 @@ func (s *WebSuite) SetUpTest(c *C) { Emitter: s.proxyClient, NewCachingAccessPoint: auth.NoCache, DirectClusters: []reversetunnel.DirectCluster{{Name: s.server.ClusterName(), Client: s.proxyClient}}, - DataDir: c.MkDir(), + DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, + NodeWatcher: proxyNodeWatcher, }) - c.Assert(err, IsNil) + require.NoError(t, err) s.proxyTunnel = revTunServer // proxy server: @@ -277,7 +282,7 @@ func (s *WebSuite) SetUpTest(c *C) { s.server.ClusterName(), []ssh.Signer{signer}, s.proxyClient, - c.MkDir(), + t.TempDir(), "", utils.NetAddr{}, regular.SetUUID(proxyID), @@ -289,13 +294,14 @@ func (s *WebSuite) SetUpTest(c *C) { regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), regular.SetLockWatcher(proxyLockWatcher), + regular.SetNodeWatcher(proxyNodeWatcher), ) - c.Assert(err, IsNil) + require.NoError(t, err) // Expired sessions are purged immediately var sessionLingeringThreshold time.Duration = 0 fs, err := NewDebugFileSystem("../../webassets/teleport") - c.Assert(err, IsNil) + require.NoError(t, err) handler, err := NewHandler(Config{ Proxy: revTunServer, AuthServers: utils.FromAddr(s.server.TLS.Addr()), @@ -309,22 +315,22 @@ func (s *WebSuite) SetUpTest(c *C) { StaticFS: fs, cachedSessionLingeringThreshold: &sessionLingeringThreshold, }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(s.clock)) - c.Assert(err, IsNil) + require.NoError(t, err) s.webServer = httptest.NewUnstartedServer(handler) s.webServer.StartTLS() err = s.proxy.Start() - c.Assert(err, IsNil) + require.NoError(t, err) // Wait for proxy to fully register before starting the test. for start := time.Now(); ; { proxies, err := s.proxyClient.GetProxies() - c.Assert(err, IsNil) + require.NoError(t, err) if len(proxies) != 0 { break } if time.Since(start) > 5*time.Second { - c.Fatal("proxy didn't register within 5s after startup") + t.Fatal("proxy didn't register within 5s after startup") } } @@ -334,27 +340,33 @@ func (s *WebSuite) SetUpTest(c *C) { handler.handler.cfg.ProxyWebAddr = *addr handler.handler.cfg.ProxySSHAddr = *proxyAddr _, sshPort, err := net.SplitHostPort(proxyAddr.String()) - c.Assert(err, IsNil) + require.NoError(t, err) handler.handler.sshPort = sshPort -} -func (s *WebSuite) TearDownTest(c *C) { - // In particular close the lock watchers by cancelling the context. - s.cancel() + t.Cleanup(func() { + // In particular close the lock watchers by cancelling the context. + s.cancel() - var errors []error - s.proxyTunnel.Close() - if err := s.node.Close(); err != nil { - errors = append(errors, err) - } - s.webServer.Close() - if err := s.proxy.Close(); err != nil { - errors = append(errors, err) - } - if err := s.server.Shutdown(context.Background()); err != nil { - errors = append(errors, err) - } - c.Assert(errors, HasLen, 0) + s.webServer.Close() + + var errors []error + if err := s.proxyTunnel.Close(); err != nil { + errors = append(errors, err) + } + if err := s.node.Close(); err != nil { + errors = append(errors, err) + } + s.webServer.Close() + if err := s.proxy.Close(); err != nil { + errors = append(errors, err) + } + if err := s.server.Shutdown(context.Background()); err != nil { + errors = append(errors, err) + } + require.Empty(t, errors) + }) + + return s } func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response { @@ -379,7 +391,7 @@ type authPack struct { // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. -func (s *WebSuite) authPack(c *C, user string) *authPack { +func (s *WebSuite) authPack(t *testing.T, user string) *authPack { login := s.user pass := "abc123" rawSecret := "def456" @@ -389,15 +401,15 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) - s.createUser(c, user, login, pass, otpSecret) + s.createUser(t, user, login, pass, otpSecret) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() req := CreateSessionReq{ @@ -408,16 +420,16 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" re, err := s.login(clt, csrfToken, csrfToken, req) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) sess, err := rawSess.response() - c.Assert(err, IsNil) + require.NoError(t, err) jar, err := cookiejar.New(nil) - c.Assert(err, IsNil) + require.NoError(t, err) clt = s.client(roundtrip.BearerAuth(sess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) @@ -432,32 +444,32 @@ func (s *WebSuite) authPack(c *C, user string) *authPack { } } -func (s *WebSuite) createUser(c *C, user string, login string, pass string, otpSecret string) { +func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string) { teleUser, err := types.NewUser(user) - c.Assert(err, IsNil) + require.NoError(t, err) role := services.RoleForUser(teleUser) role.SetLogins(services.Allow, []string{login}) options := role.GetOptions() options.ForwardAgent = types.NewBool(true) role.SetOptions(options) err = s.server.Auth().UpsertRole(s.ctx, role) - c.Assert(err, IsNil) + require.NoError(t, err) teleUser.AddRole(role.GetName()) teleUser.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "some-auth-user"}, }) err = s.server.Auth().CreateUser(s.ctx, teleUser) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertPassword(user, []byte(pass)) - c.Assert(err, IsNil) + require.NoError(t, err) if otpSecret != "" { dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) - c.Assert(err, IsNil) + require.NoError(t, err) } } @@ -480,18 +492,20 @@ func TestValidRedirectURL(t *testing.T) { } } -func (s *WebSuite) TestSAMLSuccess(c *C) { +func TestSAMLSuccess(t *testing.T) { + t.Parallel() + s := newWebSuite(t) input := fixtures.SAMLOktaConnectorV2 decoder := kyaml.NewYAMLOrJSONDecoder(strings.NewReader(input), defaults.LookaheadBufSize) var raw services.UnknownResource err := decoder.Decode(&raw) - c.Assert(err, IsNil) + require.NoError(t, err) connector, err := services.UnmarshalSAMLConnector(raw.Raw) - c.Assert(err, IsNil) + require.NoError(t, err) err = services.ValidateSAMLConnector(connector) - c.Assert(err, IsNil) + require.NoError(t, err) role, err := types.NewRole(connector.GetAttributesToRoles()[0].Roles[0], types.RoleSpecV4{ Options: types.RoleOptions{ @@ -505,64 +519,64 @@ func (s *WebSuite) TestSAMLSuccess(c *C) { }, }, }) - c.Assert(err, IsNil) - role.SetLogins(services.Allow, []string{s.user}) + require.NoError(t, err) + role.SetLogins(types.Allow, []string{s.user}) err = s.server.Auth().UpsertRole(s.ctx, role) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().CreateSAMLConnector(connector) - c.Assert(err, IsNil) - s.server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC))) + require.NoError(t, err) + s.server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 5, 10, 18, 53, 0, 0, time.UTC))) clt := s.clientNoRedirects() csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" - baseURL, err := url.Parse(clt.Endpoint("webapi", "saml", "sso") + `?redirect_url=http://localhost/after;connector_id=` + connector.GetName()) - c.Assert(err, IsNil) + baseURL, err := url.Parse(clt.Endpoint("webapi", "saml", "sso") + `?connector_id=` + connector.GetName() + `&redirect_url=http://localhost/after`) + require.NoError(t, err) req, err := http.NewRequest("GET", baseURL.String(), nil) - c.Assert(err, IsNil) + require.NoError(t, err) addCSRFCookieToReq(req, csrfToken) re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) // we got a redirect urlPattern := regexp.MustCompile(`URL='([^']*)'`) locationURL := urlPattern.FindStringSubmatch(string(re.Bytes()))[1] u, err := url.Parse(locationURL) - c.Assert(err, IsNil) - c.Assert(u.Scheme+"://"+u.Host+u.Path, Equals, fixtures.SAMLOktaSSO) + require.NoError(t, err) + require.Equal(t, fixtures.SAMLOktaSSO, u.Scheme+"://"+u.Host+u.Path) data, err := base64.StdEncoding.DecodeString(u.Query().Get("SAMLRequest")) - c.Assert(err, IsNil) - buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(data))) - c.Assert(err, IsNil) + require.NoError(t, err) + buf, err := io.ReadAll(flate.NewReader(bytes.NewReader(data))) + require.NoError(t, err) doc := etree.NewDocument() err = doc.ReadFromBytes(buf) - c.Assert(err, IsNil) + require.NoError(t, err) id := doc.Root().SelectAttr("ID") - c.Assert(id, NotNil) + require.NotNil(t, id) authRequest, err := s.server.Auth().GetSAMLAuthRequest(id.Value) - c.Assert(err, IsNil) + require.NoError(t, err) // now swap the request id to the hardcoded one in fixtures authRequest.ID = fixtures.SAMLOktaAuthRequestID authRequest.CSRFToken = csrfToken err = s.server.Auth().Identity.CreateSAMLAuthRequest(*authRequest, backend.Forever) - c.Assert(err, IsNil) + require.NoError(t, err) // now respond with pre-recorded request to the POST url in := &bytes.Buffer{} fw, err := flate.NewWriter(in, flate.DefaultCompression) - c.Assert(err, IsNil) + require.NoError(t, err) _, err = fw.Write([]byte(fixtures.SAMLOktaAuthnResponseXML)) - c.Assert(err, IsNil) + require.NoError(t, err) err = fw.Close() - c.Assert(err, IsNil) + require.NoError(t, err) encodedResponse := base64.StdEncoding.EncodeToString(in.Bytes()) - c.Assert(encodedResponse, NotNil) + require.NotNil(t, encodedResponse) // now send the response to the server to exchange it for auth session form := url.Values{} @@ -570,53 +584,57 @@ func (s *WebSuite) TestSAMLSuccess(c *C) { req, err = http.NewRequest("POST", clt.Endpoint("webapi", "saml", "acs"), strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") addCSRFCookieToReq(req, csrfToken) - c.Assert(err, IsNil) + require.NoError(t, err) authRe, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) - comment := Commentf("Response: %v", string(authRe.Bytes())) - c.Assert(authRe.Code(), Equals, http.StatusFound, comment) + require.NoError(t, err) + require.Equal(t, http.StatusFound, authRe.Code(), "Response: %v", string(authRe.Bytes())) // we have got valid session - c.Assert(authRe.Headers().Get("Set-Cookie"), Not(Equals), "") - // we are being redirected to orignal URL - c.Assert(authRe.Headers().Get("Location"), Equals, "/after") + require.NotEmpty(t, authRe.Headers().Get("Set-Cookie")) + // we are being redirected to original URL + require.Equal(t, "/after", authRe.Headers().Get("Location")) } -func (s *WebSuite) TestWebSessionsCRUD(c *C) { - pack := s.authPack(c, "foo") +func TestWebSessionsCRUD(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // make sure we can use client to make authenticated requests re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var clusters []ui.Cluster - c.Assert(json.Unmarshal(re.Bytes(), &clusters), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &clusters)) // now delete session _, err = pack.clt.Delete( context.Background(), pack.clt.Endpoint("webapi", "sessions")) - c.Assert(err, IsNil) + require.NoError(t, err) // subsequent requests trying to use this session will fail _, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } -func (s *WebSuite) TestNamespace(c *C) { - pack := s.authPack(c, "foo") +func TestNamespace(t *testing.T) { + s := newWebSuite(t) + pack := s.authPack(t, "foo") _, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "namespaces", "..%252fevents%3f", "nodes"), url.Values{}) - c.Assert(err, NotNil) + require.Error(t, err) _, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "namespaces", "default", "nodes"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestCSRF(c *C) { +func TestCSRF(t *testing.T) { + t.Parallel() + s := newWebSuite(t) type input struct { reqToken string cookieToken string @@ -626,11 +644,11 @@ func (s *WebSuite) TestCSRF(c *C) { user := "csrfuser" pass := "abc123" otpSecret := base32.StdEncoding.EncodeToString([]byte("def456")) - s.createUser(c, user, user, pass, otpSecret) + s.createUser(t, user, user, pass, otpSecret) // create a valid login form request validToken, err := totp.GenerateCode(otpSecret, time.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) loginForm := CreateSessionReq{ User: user, Pass: pass, @@ -650,23 +668,25 @@ func (s *WebSuite) TestCSRF(c *C) { // valid _, err = s.login(clt, encodedToken1, encodedToken1, loginForm) - c.Assert(err, IsNil) + require.NoError(t, err) // invalid for i := range invalid { _, err := s.login(clt, invalid[i].cookieToken, invalid[i].reqToken, loginForm) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } } -func (s *WebSuite) TestPasswordChange(c *C) { - pack := s.authPack(c, "foo") +func TestPasswordChange(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // invalidate the token s.clock.Advance(1 * time.Minute) validToken, err := totp.GenerateCode(pack.otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) req := changePasswordReq{ OldPassword: []byte("abc123"), @@ -675,26 +695,28 @@ func (s *WebSuite) TestPasswordChange(c *C) { } _, err = pack.clt.PutJSON(context.Background(), pack.clt.Endpoint("webapi", "users", "password"), req) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestWebSessionsBadInput(c *C) { +func TestWebSessionsBadInput(t *testing.T) { + t.Parallel() + s := newWebSuite(t) user := "bob" pass := "abc123" rawSecret := "def456" otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) err := s.server.Auth().UpsertPassword(user, []byte(pass)) - c.Assert(err, IsNil) + require.NoError(t, err) dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) - c.Assert(err, IsNil) + require.NoError(t, err) // create valid token validToken, err := totp.GenerateCode(otpSecret, time.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() @@ -730,9 +752,11 @@ func (s *WebSuite) TestWebSessionsBadInput(c *C) { }, } for i, req := range reqs { - _, err = clt.PostJSON(context.Background(), clt.Endpoint("webapi", "sessions"), req) - c.Assert(err, NotNil, Commentf("tc %v", i)) - c.Assert(trace.IsAccessDenied(err), Equals, true, Commentf("tc %v %T is not access denied", i, err)) + t.Run(fmt.Sprintf("tc %v", i), func(t *testing.T) { + _, err := clt.PostJSON(s.ctx, clt.Endpoint("webapi", "sessions"), req) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }) } } @@ -740,33 +764,37 @@ type getSiteNodeResponse struct { Items []ui.Server `json:"items"` } -func (s *WebSuite) TestGetSiteNodes(c *C) { - pack := s.authPack(c, "foo") +func TestGetSiteNodes(t *testing.T) { + s := newWebSuite(t) + pack := s.authPack(t, "foo") // get site nodes re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "nodes"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) nodes := getSiteNodeResponse{} - c.Assert(json.Unmarshal(re.Bytes(), &nodes), IsNil) - c.Assert(len(nodes.Items), Equals, 1) + require.NoError(t, json.Unmarshal(re.Bytes(), &nodes)) + require.Len(t, nodes.Items, 1) // get site nodes using shortcut re, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", currentSiteShortcut, "nodes"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) nodes2 := getSiteNodeResponse{} - c.Assert(json.Unmarshal(re.Bytes(), &nodes2), IsNil) - c.Assert(len(nodes.Items), Equals, 1) - c.Assert(nodes2, DeepEquals, nodes) + require.NoError(t, json.Unmarshal(re.Bytes(), &nodes2)) + require.Len(t, nodes.Items, 1) + require.Empty(t, cmp.Diff(nodes, nodes2)) } -func (s *WebSuite) TestSiteNodeConnectInvalidSessionID(c *C) { - _, err := s.makeTerminal(s.authPack(c, "foo"), session.ID("/../../../foo")) - c.Assert(err, NotNil) +func TestSiteNodeConnectInvalidSessionID(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + _, err := s.makeTerminal(s.authPack(t, "foo"), session.ID("/../../../foo")) + require.Error(t, err) } -func (s *WebSuite) TestResolveServerHostPort(c *C) { +func TestResolveServerHostPort(t *testing.T) { + t.Parallel() sampleNode := types.ServerV2{} sampleNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") sampleNode.Spec.Hostname = "nodehostname" @@ -825,20 +853,20 @@ func (s *WebSuite) TestResolveServerHostPort(c *C) { for _, testCase := range validCases { host, port, err := resolveServerHostPort(testCase.server, testCase.nodes) - c.Assert(err, IsNil, Commentf(testCase.server)) - c.Assert(host, Equals, testCase.expectedHost) - c.Assert(port, Equals, testCase.expectedPort) + require.NoError(t, err, testCase.server) + require.Equal(t, testCase.expectedHost, host, testCase.server) + require.Equal(t, testCase.expectedPort, port, testCase.server) } for _, testCase := range invalidCases { _, _, err := resolveServerHostPort(testCase.server, nil) - c.Assert(err, NotNil, Commentf(testCase.expectedErr)) - c.Assert(err, ErrorMatches, ".*"+testCase.expectedErr+".*") + require.Error(t, err, testCase.server) + require.Regexp(t, ".*"+testCase.expectedErr+".*", err.Error(), testCase.server) } } -func (s *WebSuite) TestNewTerminalHandler(c *C) { +func TestNewTerminalHandler(t *testing.T) { validNode := types.ServerV2{} validNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") validNode.Spec.Hostname = "nodehostname" @@ -929,129 +957,132 @@ func (s *WebSuite) TestNewTerminalHandler(c *C) { }, } + ctx := context.Background() for _, testCase := range validCases { - term, err := NewTerminal(s.ctx, testCase.req, testCase.authProvider, nil) - c.Assert(err, IsNil) - c.Assert(term.params, DeepEquals, testCase.req) - c.Assert(term.hostName, Equals, testCase.expectedHost) - c.Assert(term.hostPort, Equals, testCase.expectedPort) + term, err := NewTerminal(ctx, testCase.req, testCase.authProvider, nil) + require.NoError(t, err) + require.Empty(t, cmp.Diff(testCase.req, term.params)) + require.Equal(t, testCase.expectedHost, testCase.expectedHost) + require.Equal(t, testCase.expectedPort, testCase.expectedPort) } for _, testCase := range invalidCases { - _, err := NewTerminal(s.ctx, testCase.req, testCase.authProvider, nil) - c.Assert(err, ErrorMatches, ".*"+testCase.expectedErr+".*") + _, err := NewTerminal(ctx, testCase.req, testCase.authProvider, nil) + require.Regexp(t, ".*"+testCase.expectedErr+".*", err.Error()) } } -func (s *WebSuite) TestResizeTerminal(c *C) { +func TestResizeTerminal(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() // Create a new user "foo", open a terminal to a new session, and wait for // it to be ready. - pack1 := s.authPack(c, "foo") + pack1 := s.authPack(t, "foo") ws1, err := s.makeTerminal(pack1, sid) - c.Assert(err, IsNil) - defer ws1.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws1.Close()) }) err = s.waitForRawEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Create a new user "bar", open a terminal to the session created above, // and wait for it to be ready. - pack2 := s.authPack(c, "bar") + pack2 := s.authPack(t, "bar") ws2, err := s.makeTerminal(pack2, sid) - c.Assert(err, IsNil) - defer ws2.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws2.Close()) }) err = s.waitForRawEvent(ws2, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Look at the audit events for the first terminal. It should have two // resize events from the second terminal (80x25 default then 100x100). Only // the second terminal will get these because resize events are not sent // back to the originator. err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // Look at the stream events for the second terminal. We don't expect to see // any resize events yet. It will timeout. err = s.waitForResizeEvent(ws2, 1*time.Second) - c.Assert(err, NotNil) + require.Error(t, err) // Resize the second terminal. This should be reflected on the first terminal // because resize events are not sent back to the originator. params, err := session.NewTerminalParamsFromInt(300, 120) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(events.EventFields{ events.EventType: events.ResizeEvent, events.EventNamespace: apidefaults.Namespace, events.SessionEventID: sid.String(), events.TerminalSize: params.Serialize(), }) - c.Assert(err, IsNil) + require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketResize, Payload: string(data), } envelopeBytes, err := proto.Marshal(envelope) - c.Assert(err, IsNil) + require.NoError(t, err) err = websocket.Message.Send(ws2, envelopeBytes) - c.Assert(err, IsNil) + require.NoError(t, err) // This time the first terminal will see the resize event. err = s.waitForResizeEvent(ws1, 5*time.Second) - c.Assert(err, IsNil) + require.NoError(t, err) // The second terminal will not see any resize event. It will timeout. err = s.waitForResizeEvent(ws2, 1*time.Second) - c.Assert(err, NotNil) + require.Error(t, err) } -func (s *WebSuite) TestTerminal(c *C) { - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) - defer ws.Close() +func TestTerminal(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ws, err := s.makeTerminal(s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) _, err = io.WriteString(stream, "echo vinsong\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) err = waitForOutput(stream, "vinsong") - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestWebsocketPingLoop(c *C) { +func TestWebsocketPingLoop(t *testing.T) { + t.Parallel() + s := newWebSuite(t) // Change cluster networking config for keep alive interval to be run faster. netConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ KeepAliveInterval: types.NewDuration(250 * time.Millisecond), }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetClusterNetworkingConfig(s.ctx, netConfig) - c.Assert(err, IsNil) + require.NoError(t, err) recConfig, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ Mode: types.RecordAtNode, ProxyChecksHostKeys: types.NewBoolOption(true), }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetSessionRecordingConfig(s.ctx, recConfig) - c.Assert(err, IsNil) - - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) + require.NoError(t, err) - // flush out raw event (pty texts) - err = s.waitForRawEvent(ws, 5*time.Second) - c.Assert(err, IsNil) + ws, err := s.makeTerminal(s.authPack(t, "foo")) + require.NoError(t, err) var numPings int start := time.Now() for { frame, err := ws.NewFrameReader() - c.Assert(err, IsNil) + require.NoError(t, err) // We should get a mix of output (binary) and ping frames. Count only // the ping frames. if int(frame.PayloadType()) == websocket.PingFrame { @@ -1061,86 +1092,91 @@ func (s *WebSuite) TestWebsocketPingLoop(c *C) { break } if time.Since(start) > 5*time.Second { - c.Fatalf("received %d ping frames within 5s of opening a socket, expected at least 2", numPings) + t.Fatalf("received %d ping frames within 5s of opening a socket, expected at least 2", numPings) } } - err = ws.Close() - c.Assert(err, IsNil) + require.NoError(t, ws.Close()) } -func (s *WebSuite) TestWebAgentForward(c *C) { - ws, err := s.makeTerminal(s.authPack(c, "foo")) - c.Assert(err, IsNil) - defer ws.Close() +func TestWebAgentForward(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ws, err := s.makeTerminal(s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) _, err = io.WriteString(stream, "echo $SSH_AUTH_SOCK\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) err = waitForOutput(stream, "/") - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestActiveSessions(c *C) { +func TestActiveSessions(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() - pack := s.authPack(c, "foo") + pack := s.authPack(t, "foo") ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) // To make sure we have a session. _, err = io.WriteString(stream, "echo vinsong\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) // Make sure server has replied. err = waitForOutput(stream, "vinsong") - c.Assert(err, IsNil) + require.NoError(t, err) // Make sure this session appears in the list of active sessions. var sessResp *siteSessionsGetResponse for i := 0; i < 10; i++ { // Get site nodes and make sure the node has our active party. - re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) - c.Assert(err, IsNil) + re, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) + require.NoError(t, err) - c.Assert(json.Unmarshal(re.Bytes(), &sessResp), IsNil) - c.Assert(len(sessResp.Sessions), Equals, 1) + require.NoError(t, json.Unmarshal(re.Bytes(), &sessResp)) + require.Len(t, sessResp.Sessions, 1) // Sessions do not appear momentarily as there's async heartbeat // procedure. time.Sleep(250 * time.Millisecond) } - c.Assert(len(sessResp.Sessions), Equals, 1) + require.Len(t, sessResp.Sessions, 1) sess := sessResp.Sessions[0] - c.Assert(sess.ID, Equals, sid) - c.Assert(sess.Namespace, Equals, s.node.GetNamespace()) - c.Assert(sess.Parties, NotNil) - c.Assert(sess.TerminalParams.H > 0, Equals, true) - c.Assert(sess.TerminalParams.W > 0, Equals, true) - c.Assert(sess.Login, Equals, pack.login) - c.Assert(sess.Created.IsZero(), Equals, false) - c.Assert(sess.LastActive.IsZero(), Equals, false) - c.Assert(sess.ServerID, Equals, s.srvID) - c.Assert(sess.ServerHostname, Equals, s.node.GetInfo().GetHostname()) - c.Assert(sess.ServerAddr, Equals, s.node.GetInfo().GetAddr()) - c.Assert(sess.ClusterName, Equals, s.server.ClusterName()) + require.Equal(t, sid, sess.ID) + require.Equal(t, s.node.GetNamespace(), sess.Namespace) + require.NotNil(t, sess.Parties) + require.Greater(t, sess.TerminalParams.H, 0) + require.Greater(t, sess.TerminalParams.W, 0) + require.Equal(t, pack.login, sess.Login) + require.False(t, sess.Created.IsZero()) + require.False(t, sess.LastActive.IsZero()) + require.Equal(t, s.srvID, sess.ServerID) + require.Equal(t, s.node.GetInfo().GetHostname(), sess.ServerHostname) + require.Equal(t, s.node.GetInfo().GetAddr(), sess.ServerAddr) + require.Equal(t, s.server.ClusterName(), sess.ClusterName) } // DELETE IN: 5.0.0 // Tests the code snippet from apiserver.(*Handler).siteSessionGet/siteSessionsGet // that tests empty ClusterName and ServerHostname gets set. -func (s *WebSuite) TestEmptySessionClusterHostnameIsSet(c *C) { +func TestEmptySessionClusterHostnameIsSet(t *testing.T) { + t.Parallel() + s := newWebSuite(t) nodeClient, err := s.server.NewClient(auth.TestBuiltin(types.RoleNode)) - c.Assert(err, IsNil) + require.NoError(t, err) // Create a session with empty ClusterName. sess1 := session.Session{ @@ -1154,68 +1190,67 @@ func (s *WebSuite) TestEmptySessionClusterHostnameIsSet(c *C) { TerminalParams: session.TerminalParams{W: 100, H: 100}, } err = nodeClient.CreateSession(sess1) - c.Assert(err, IsNil) + require.NoError(t, err) // Retrieve the session with the empty ClusterName. - pack := s.authPack(c, "baz") - res, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "namespaces", "default", "sessions", sess1.ID.String()), url.Values{}) - c.Assert(err, IsNil) + pack := s.authPack(t, "baz") + res, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions", sess1.ID.String()), url.Values{}) + require.NoError(t, err) // Test that empty ClusterName and ServerHostname got set. var sessionResult *session.Session err = json.Unmarshal(res.Bytes(), &sessionResult) - c.Assert(err, IsNil) - c.Assert(sessionResult.ClusterName, Equals, s.server.ClusterName()) - c.Assert(sessionResult.ServerHostname, Equals, sess1.ServerID) + require.NoError(t, err) + require.Equal(t, s.server.ClusterName(), sessionResult.ClusterName) + require.Equal(t, sess1.ServerID, sessionResult.ServerHostname) // Create another session to test sessions list. sess2 := sess1 sess2.ID = session.NewID() sess2.ServerID = string(session.NewID()) err = nodeClient.CreateSession(sess2) - c.Assert(err, IsNil) + require.NoError(t, err) // Retrieve sessions list. - res, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "namespaces", "default", "sessions"), url.Values{}) - c.Assert(err, IsNil) + res, err = pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) + require.NoError(t, err) var sessionList *siteSessionsGetResponse err = json.Unmarshal(res.Bytes(), &sessionList) - c.Assert(err, IsNil) + require.NoError(t, err) s1 := sessionList.Sessions[0] s2 := sessionList.Sessions[1] - c.Assert(s1.ClusterName, Equals, s.server.ClusterName()) - c.Assert(s2.ClusterName, Equals, s.server.ClusterName()) - c.Assert(s1.ServerHostname, Equals, s1.ServerID) - c.Assert(s2.ServerHostname, Equals, s2.ServerID) + require.Equal(t, s.server.ClusterName(), s1.ClusterName) + require.Equal(t, s.server.ClusterName(), s2.ClusterName) + require.Equal(t, s1.ServerID, s1.ServerHostname) + require.Equal(t, s2.ServerID, s2.ServerHostname) } -func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) { +func TestCloseConnectionsOnLogout(t *testing.T) { + t.Parallel() + s := newWebSuite(t) sid := session.NewID() - pack := s.authPack(c, "foo") + pack := s.authPack(t, "foo") ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + require.NoError(t, err) termHandler := newTerminalHandler() stream := termHandler.asTerminalStream(ws) // to make sure we have a session _, err = io.WriteString(stream, "expr 137 + 39\r\n") - c.Assert(err, IsNil) + require.NoError(t, err) // make sure server has replied out := make([]byte, 100) _, err = stream.Read(out) - c.Assert(err, IsNil) + require.NoError(t, err) - _, err = pack.clt.Delete( - context.Background(), - pack.clt.Endpoint("webapi", "sessions")) - c.Assert(err, IsNil) + _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions")) + require.NoError(t, err) // wait until we timeout or detect that connection has been closed after := time.After(5 * time.Second) @@ -1231,21 +1266,23 @@ func (s *WebSuite) TestCloseConnectionsOnLogout(c *C) { select { case <-after: - c.Fatalf("timeout") + t.Fatalf("timeout") case err := <-errC: - c.Assert(err, Equals, io.EOF) + require.ErrorIs(t, err, io.EOF) } } -func (s *WebSuite) TestCreateSession(c *C) { - pack := s.authPack(c, "foo") +func TestCreateSession(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") // get site nodes re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "nodes"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) nodes := getSiteNodeResponse{} - c.Assert(json.Unmarshal(re.Bytes(), &nodes), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &nodes)) node := nodes.Items[0] sess := session.Session{ @@ -1260,12 +1297,12 @@ func (s *WebSuite) TestCreateSession(c *C) { pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), siteSessionGenerateReq{Session: sess}, ) - c.Assert(err, IsNil) + require.NoError(t, err) var created *siteSessionGenerateResponse - c.Assert(json.Unmarshal(re.Bytes(), &created), IsNil) - c.Assert(created.Session.ID, Not(Equals), "") - c.Assert(created.Session.ServerHostname, Equals, node.Hostname) + require.NoError(t, json.Unmarshal(re.Bytes(), &created)) + require.NotEmpty(t, created.Session.ID) + require.Equal(t, node.Hostname, created.Session.ServerHostname) // test empty serverID (older version does not supply serverID) sess.ServerID = "" @@ -1274,38 +1311,42 @@ func (s *WebSuite) TestCreateSession(c *C) { pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), siteSessionGenerateReq{Session: sess}, ) - c.Assert(err, IsNil) + require.NoError(t, err) } -func (s *WebSuite) TestPlayback(c *C) { - pack := s.authPack(c, "foo") +func TestPlayback(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo") sid := session.NewID() ws, err := s.makeTerminal(pack, sid) - c.Assert(err, IsNil) - defer ws.Close() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) } -func (s *WebSuite) TestLogin(c *C) { +func TestLogin(t *testing.T) { + t.Parallel() + s := newWebSuite(t) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOff, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) // create user - s.createUser(c, "user1", "root", "password", "") + s.createUser(t, "user1", "root", "password", "") loginReq, err := json.Marshal(CreateSessionReq{ User: "user1", Pass: "password", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(loginReq)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1315,87 +1356,89 @@ func (s *WebSuite) TestLogin(c *C) { re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) cookies := re.Cookies() - c.Assert(len(cookies), Equals, 1) + require.Len(t, cookies, 1) // now make sure we are logged in by calling authenticated method // we need to supply both session cookie and bearer token for // request to succeed jar, err := cookiejar.New(nil) - c.Assert(err, IsNil) + require.NoError(t, err) clt = s.client(roundtrip.BearerAuth(rawSess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) - re, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, IsNil) + re, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.NoError(t, err) var clusters []ui.Cluster - c.Assert(json.Unmarshal(re.Bytes(), &clusters), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &clusters)) // in absence of session cookie or bearer auth the same request fill fail // no session cookie: clt = s.client(roundtrip.BearerAuth(rawSess.Token)) - _, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) // no bearer token: clt = s.client(roundtrip.CookieJar(jar)) - _, err = clt.Get(context.Background(), clt.Endpoint("webapi", "sites"), url.Values{}) - c.Assert(err, NotNil) - c.Assert(trace.IsAccessDenied(err), Equals, true) + _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) } -func (s *WebSuite) TestChangePasswordWithTokenOTP(c *C) { +func TestChangePasswordWithTokenOTP(t *testing.T) { + t.Parallel() + s := newWebSuite(t) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) // create user - s.createUser(c, "user1", "root", "password", "") + s.createUser(t, "user1", "root", "password", "") // create password change token token, err := s.server.Auth().CreateResetPasswordToken(context.TODO(), auth.CreateResetPasswordTokenRequest{ Name: "user1", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() re, err := clt.Get(context.Background(), clt.Endpoint("webapi", "users", "password", "token", token.GetName()), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var uiToken *ui.ResetPasswordToken - c.Assert(json.Unmarshal(re.Bytes(), &uiToken), IsNil) - c.Assert(uiToken.User, Equals, token.GetUser()) - c.Assert(uiToken.TokenID, Equals, token.GetName()) + require.NoError(t, json.Unmarshal(re.Bytes(), &uiToken)) + require.Equal(t, token.GetUser(), uiToken.User) + require.Equal(t, token.GetName(), uiToken.TokenID) secrets, err := s.server.Auth().RotateResetPasswordTokenSecrets(context.TODO(), token.GetName()) - c.Assert(err, IsNil) + require.NoError(t, err) // Advance the clock to invalidate the TOTP token s.clock.Advance(1 * time.Minute) secondFactorToken, err := totp.GenerateCode(secrets.GetOTPKey(), s.clock.Now()) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(auth.ChangePasswordWithTokenRequest{ TokenID: token.GetName(), Password: []byte("abc123"), SecondFactorToken: secondFactorToken, }) - c.Assert(err, IsNil) + require.NoError(t, err) req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(data)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1405,14 +1448,17 @@ func (s *WebSuite) TestChangePasswordWithTokenOTP(c *C) { re, err = clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), IsNil) - c.Assert(rawSess.Token != "", Equals, true) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) + require.NotEmpty(t, rawSess.Token) } -func (s *WebSuite) TestChangePasswordWithTokenU2F(c *C) { +func TestChangePasswordWithTokenU2F(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorU2F, @@ -1421,37 +1467,37 @@ func (s *WebSuite) TestChangePasswordWithTokenU2F(c *C) { Facets: []string{"https://" + s.server.ClusterName()}, }, }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) - c.Assert(err, IsNil) + require.NoError(t, err) - s.createUser(c, "user2", "root", "password", "") + s.createUser(t, "user2", "root", "password", "") // create reset password token token, err := s.server.Auth().CreateResetPasswordToken(context.TODO(), auth.CreateResetPasswordTokenRequest{ Name: "user2", }) - c.Assert(err, IsNil) + require.NoError(t, err) clt := s.client() re, err := clt.Get(context.Background(), clt.Endpoint("webapi", "u2f", "signuptokens", token.GetName()), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var u2fRegReq u2f.RegisterChallenge - c.Assert(json.Unmarshal(re.Bytes(), &u2fRegReq), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &u2fRegReq)) u2fRegResp, err := s.mockU2F.RegisterResponse(&u2fRegReq) - c.Assert(err, IsNil) + require.NoError(t, err) data, err := json.Marshal(auth.ChangePasswordWithTokenRequest{ TokenID: token.GetName(), Password: []byte("qweQWE"), U2FRegisterResponse: u2fRegResp, }) - c.Assert(err, IsNil) + require.NoError(t, err) req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(data)) - c.Assert(err, IsNil) + require.NoError(t, err) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) @@ -1461,14 +1507,15 @@ func (s *WebSuite) TestChangePasswordWithTokenU2F(c *C) { re, err = clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) - c.Assert(err, IsNil) + require.NoError(t, err) var rawSess *CreateSessionResponse - c.Assert(json.Unmarshal(re.Bytes(), &rawSess), IsNil) - c.Assert(rawSess.Token != "", Equals, true) + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) + require.NotEmpty(t, rawSess.Token) } func TestU2FLogin(t *testing.T) { + t.Parallel() for _, sf := range []constants.SecondFactorType{ constants.SecondFactorU2F, constants.SecondFactorOptional, @@ -1606,84 +1653,91 @@ func testU2FLogin(t *testing.T, secondFactor constants.SecondFactorType) { // TestPing ensures that a response is returned by /webapi/ping // and that that response body contains authentication information. -func (s *WebSuite) TestPing(c *C) { +func TestPing(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + wc := s.client() re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var out *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &out), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &out)) preference, err := s.server.Auth().GetAuthPreference(s.ctx) - c.Assert(err, IsNil) + require.NoError(t, err) - c.Assert(out.Auth.Type, Equals, preference.GetType()) - c.Assert(out.Auth.SecondFactor, Equals, preference.GetSecondFactor()) + require.Equal(t, out.Auth.Type, preference.GetType()) + require.Equal(t, out.Auth.SecondFactor, preference.GetSecondFactor()) } // TestEmptyMotD ensures that responses returned by both /webapi/ping and // /webapi/motd work when no MotD is set -func (s *WebSuite) TestEmptyMotD(c *C) { - ctx := context.Background() +func TestEmptyMotD(t *testing.T) { + t.Parallel() + s := newWebSuite(t) wc := s.client() // Given an auth server configured *not* to expose a Message Of The // Day... // When I issue a ping request... - re, err := wc.Get(ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + require.NoError(t, err) // Expect that the MotD flag in the ping response is *not* set var pingResponse *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &pingResponse), IsNil) - c.Assert(pingResponse.Auth.HasMessageOfTheDay, Equals, false) + require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) + require.False(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(ctx, wc.Endpoint("webapi", "motd"), url.Values{}) - c.Assert(err, IsNil) + re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + require.NoError(t, err) // Expect that an empty response returned var motdResponse *webclient.MotD - c.Assert(json.Unmarshal(re.Bytes(), &motdResponse), IsNil) - c.Assert(motdResponse.Text, Equals, "") + require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) + require.Empty(t, motdResponse.Text) } // TestMotD ensures that a response is returned by both /webapi/ping and /webapi/motd // and that that the response bodies contain their MOTD components -func (s *WebSuite) TestMotD(c *C) { +func TestMotD(t *testing.T) { + t.Parallel() const motd = "Hello. I'm a Teleport cluster!" - ctx := context.Background() + s := newWebSuite(t) wc := s.client() // Given an auth server configured to expose a Message Of The Day... prefs := types.DefaultAuthPreference() prefs.SetMessageOfTheDay(motd) - s.server.AuthServer.AuthServer.SetAuthPreference(ctx, prefs) + require.NoError(t, s.server.AuthServer.AuthServer.SetAuthPreference(s.ctx, prefs)) // When I issue a ping request... - re, err := wc.Get(ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + require.NoError(t, err) // Expect that the MotD flag in the ping response is set to indicate // a MotD var pingResponse *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &pingResponse), IsNil) - c.Assert(pingResponse.Auth.HasMessageOfTheDay, Equals, true) + require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) + require.True(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(ctx, wc.Endpoint("webapi", "motd"), url.Values{}) - c.Assert(err, IsNil) + re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + require.NoError(t, err) // Expect that the text returned is the configured value var motdResponse *webclient.MotD - c.Assert(json.Unmarshal(re.Bytes(), &motdResponse), IsNil) - c.Assert(motdResponse.Text, Equals, motd) + require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) + require.Equal(t, motd, motdResponse.Text) } -func (s *WebSuite) TestMultipleConnectors(c *C) { +func TestMultipleConnectors(t *testing.T) { + t.Parallel() + s := newWebSuite(t) wc := s.client() // create two oidc connectors, one named "foo" and another named "bar" @@ -1703,61 +1757,61 @@ func (s *WebSuite) TestMultipleConnectors(c *C) { }, } o, err := types.NewOIDCConnector("foo", oidcConnectorSpec) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o) - c.Assert(err, IsNil) + require.NoError(t, err) o2, err := types.NewOIDCConnector("bar", oidcConnectorSpec) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o2) - c.Assert(err, IsNil) + require.NoError(t, err) // set the auth preferences to oidc with no connector name authPreference, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) - c.Assert(err, IsNil) + require.NoError(t, err) // hit the ping endpoint to get the auth type and connector name re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) + require.NoError(t, err) var out *webclient.PingResponse - c.Assert(json.Unmarshal(re.Bytes(), &out), IsNil) + require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector name we got back was the first connector // in the backend, in this case it's "bar" oidcConnectors, err := s.server.Auth().GetOIDCConnectors(s.ctx, false) - c.Assert(err, IsNil) - c.Assert(out.Auth.OIDC.Name, Equals, oidcConnectors[0].GetName()) + require.NoError(t, err) + require.Equal(t, oidcConnectors[0].GetName(), out.Auth.OIDC.Name) // update the auth preferences and this time specify the connector name authPreference, err = types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", ConnectorName: "foo", }) - c.Assert(err, IsNil) + require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) - c.Assert(err, IsNil) + require.NoError(t, err) // hit the ping endpoing to get the auth type and connector name re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) - c.Assert(err, IsNil) - c.Assert(json.Unmarshal(re.Bytes(), &out), IsNil) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector we get back is "foo" - c.Assert(out.Auth.OIDC.Name, Equals, "foo") + require.Equal(t, "foo", out.Auth.OIDC.Name) } // TestConstructSSHResponse checks if the secret package uses AES-GCM to // encrypt and decrypt data that passes through the ConstructSSHResponse // function. -func (s *WebSuite) TestConstructSSHResponse(c *C) { +func TestConstructSSHResponse(t *testing.T) { key, err := secret.NewKey() - c.Assert(err, IsNil) + require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") - c.Assert(err, IsNil) + require.NoError(t, err) query := u.Query() query.Set("secret_key", key.String()) u.RawQuery = query.Encode() @@ -1768,35 +1822,35 @@ func (s *WebSuite) TestConstructSSHResponse(c *C) { TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) - c.Assert(err, IsNil) + require.NoError(t, err) - c.Assert(rawresp.Query().Get("secret"), Equals, "") - c.Assert(rawresp.Query().Get("secret_key"), Equals, "") - c.Assert(rawresp.Query().Get("response"), Not(Equals), "") + require.Empty(t, rawresp.Query().Get("secret")) + require.Empty(t, rawresp.Query().Get("secret_key")) + require.NotEmpty(t, rawresp.Query().Get("response")) plaintext, err := key.Open([]byte(rawresp.Query().Get("response"))) - c.Assert(err, IsNil) + require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) - c.Assert(err, IsNil) - c.Assert(resp.Username, Equals, "foo") - c.Assert(resp.Cert, DeepEquals, []byte{0x00}) - c.Assert(resp.TLSCert, DeepEquals, []byte{0x01}) + require.NoError(t, err) + require.Equal(t, "foo", resp.Username) + require.EqualValues(t, []byte{0x00}, resp.Cert) + require.EqualValues(t, []byte{0x01}, resp.TLSCert) } // TestConstructSSHResponseLegacy checks if the secret package uses NaCl to // encrypt and decrypt data that passes through the ConstructSSHResponse // function. -func (s *WebSuite) TestConstructSSHResponseLegacy(c *C) { +func TestConstructSSHResponseLegacy(t *testing.T) { key, err := lemma_secret.NewKey() - c.Assert(err, IsNil) + require.NoError(t, err) lemma, err := lemma_secret.New(&lemma_secret.Config{KeyBytes: key}) - c.Assert(err, IsNil) + require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") - c.Assert(err, IsNil) + require.NoError(t, err) query := u.Query() query.Set("secret", lemma_secret.KeyToEncodedString(key)) u.RawQuery = query.Encode() @@ -1807,25 +1861,25 @@ func (s *WebSuite) TestConstructSSHResponseLegacy(c *C) { TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) - c.Assert(err, IsNil) + require.NoError(t, err) - c.Assert(rawresp.Query().Get("secret"), Equals, "") - c.Assert(rawresp.Query().Get("secret_key"), Equals, "") - c.Assert(rawresp.Query().Get("response"), Not(Equals), "") + require.Empty(t, rawresp.Query().Get("secret")) + require.Empty(t, rawresp.Query().Get("secret_key")) + require.NotEmpty(t, rawresp.Query().Get("response")) var sealedData *lemma_secret.SealedBytes err = json.Unmarshal([]byte(rawresp.Query().Get("response")), &sealedData) - c.Assert(err, IsNil) + require.NoError(t, err) plaintext, err := lemma.Open(sealedData) - c.Assert(err, IsNil) + require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) - c.Assert(err, IsNil) - c.Assert(resp.Username, Equals, "foo") - c.Assert(resp.Cert, DeepEquals, []byte{0x00}) - c.Assert(resp.TLSCert, DeepEquals, []byte{0x01}) + require.NoError(t, err) + require.Equal(t, "foo", resp.Username) + require.EqualValues(t, []byte{0x00}, resp.Cert) + require.EqualValues(t, []byte{0x01}, resp.TLSCert) } type byTimeAndIndex []apievents.AuditEvent @@ -1848,11 +1902,13 @@ func (f byTimeAndIndex) Swap(i, j int) { } // TestSearchClusterEvents makes sure web API allows querying events by type. -func (s *WebSuite) TestSearchClusterEvents(c *C) { +func TestSearchClusterEvents(t *testing.T) { + t.Parallel() // We need a clock that uses the current time here to work around // the fact that filelog doesn't support emitting past events. clock := clockwork.NewRealClock() + s := newWebSuite(t) sessionEvents := events.GenerateTestSession(events.SessionParams{ PrintEvents: 3, Clock: clock, @@ -1860,7 +1916,7 @@ func (s *WebSuite) TestSearchClusterEvents(c *C) { }) for _, e := range sessionEvents { - c.Assert(s.proxyClient.EmitAuditEvent(context.TODO(), e), IsNil) + require.NoError(t, s.proxyClient.EmitAuditEvent(s.ctx, e)) } sort.Sort(sort.Reverse(byTimeAndIndex(sessionEvents))) @@ -1945,49 +2001,50 @@ func (s *WebSuite) TestSearchClusterEvents(c *C) { }, } - pack := s.authPack(c, "foo") - // var sessionStartKey string + pack := s.authPack(t, "foo") for _, tc := range testCases { - result := s.searchEvents(c, pack.clt, tc.Query, []string{sessionStart.GetType(), sessionPrint.GetType(), sessionEnd.GetType()}) - c.Assert(result.Events, HasLen, len(tc.Result), Commentf(tc.Comment)) - for i, resultEvent := range result.Events { - c.Assert(resultEvent.GetType(), Equals, tc.Result[i].GetType(), Commentf(tc.Comment)) - c.Assert(resultEvent.GetID(), Equals, tc.Result[i].GetID(), Commentf(tc.Comment)) - } + tc := tc + t.Run(tc.Comment, func(t *testing.T) { + t.Parallel() + response, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), tc.Query) + require.NoError(t, err) + var result eventsListGetResponse + require.NoError(t, json.Unmarshal(response.Bytes(), &result)) + + require.Len(t, result.Events, len(tc.Result)) + for i, resultEvent := range result.Events { + require.Equal(t, tc.Result[i].GetType(), resultEvent.GetType()) + require.Equal(t, tc.Result[i].GetID(), resultEvent.GetID()) + } - // Session prints do not have ID's, only sessionStart and sessionEnd. - // When retrieving events for sessionStart and sessionEnd, sessionStart is returned first. - if tc.TestStartKey { - c.Assert(result.StartKey, Equals, tc.StartKeyValue, Commentf(tc.Comment)) - } + // Session prints do not have ID's, only sessionStart and sessionEnd. + // When retrieving events for sessionStart and sessionEnd, sessionStart is returned first. + if tc.TestStartKey { + require.Equal(t, tc.StartKeyValue, result.StartKey) + } + }) } } -func (s *WebSuite) searchEvents(c *C, clt *client.WebClient, query url.Values, filter []string) eventsListGetResponse { - response, err := clt.Get(context.Background(), clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), query) - c.Assert(err, IsNil) - var out eventsListGetResponse - c.Assert(json.Unmarshal(response.Bytes(), &out), IsNil) - return out -} - -func (s *WebSuite) TestGetClusterDetails(c *C) { +func TestGetClusterDetails(t *testing.T) { + t.Parallel() + s := newWebSuite(t) site, err := s.proxyTunnel.GetSite(s.server.ClusterName()) - c.Assert(err, IsNil) - c.Assert(site, NotNil) + require.NoError(t, err) + require.NotNil(t, site) cluster, err := ui.GetClusterDetails(s.ctx, site) - c.Assert(err, IsNil) - c.Assert(cluster.Name, Equals, s.server.ClusterName()) - c.Assert(cluster.ProxyVersion, Equals, teleport.Version) - c.Assert(cluster.PublicURL, Equals, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort)) - c.Assert(cluster.Status, Equals, teleport.RemoteClusterStatusOnline) - c.Assert(cluster.LastConnected, NotNil) - c.Assert(cluster.AuthVersion, Equals, teleport.Version) + require.NoError(t, err) + require.Equal(t, s.server.ClusterName(), cluster.Name) + require.Equal(t, teleport.Version, cluster.ProxyVersion) + require.Equal(t, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort), cluster.PublicURL) + require.Equal(t, teleport.RemoteClusterStatusOnline, cluster.Status) + require.NotNil(t, cluster.LastConnected) + require.Equal(t, teleport.Version, cluster.AuthVersion) nodes, err := s.proxyClient.GetNodes(s.ctx, apidefaults.Namespace) - c.Assert(err, IsNil) - c.Assert(nodes, HasLen, cluster.NodeCount) + require.NoError(t, err) + require.Len(t, nodes, cluster.NodeCount) } type testModules struct { @@ -2108,7 +2165,7 @@ func TestApplicationAccessDisabled(t *testing.T) { Version: types.V2, Metadata: types.Metadata{ Namespace: apidefaults.Namespace, - Name: uuid.New(), + Name: uuid.New().String(), }, Spec: types.ServerSpecV2{ Version: teleport.Version, @@ -2135,9 +2192,11 @@ func TestApplicationAccessDisabled(t *testing.T) { } // TestCreateAppSession verifies that an existing session to the Web UI can -// be exchanged for a application specific session. -func (s *WebSuite) TestCreateAppSession(c *C) { - pack := s.authPack(c, "foo@example.com") +// be exchanged for an application specific session. +func TestCreateAppSession(t *testing.T) { + t.Parallel() + s := newWebSuite(t) + pack := s.authPack(t, "foo@example.com") // Register an application called "panel". server := &types.ServerV2{ @@ -2145,7 +2204,7 @@ func (s *WebSuite) TestCreateAppSession(c *C) { Version: types.V2, Metadata: types.Metadata{ Namespace: apidefaults.Namespace, - Name: uuid.New(), + Name: uuid.New().String(), }, Spec: types.ServerSpecV2{ Version: teleport.Version, @@ -2159,126 +2218,130 @@ func (s *WebSuite) TestCreateAppSession(c *C) { }, } _, err := s.server.Auth().UpsertAppServer(context.Background(), server) - c.Assert(err, IsNil) + require.NoError(t, err) // Extract the session ID and bearer token for the current session. rawCookie := *pack.cookies[0] cookieBytes, err := hex.DecodeString(rawCookie.Value) - c.Assert(err, IsNil) + require.NoError(t, err) var sessionCookie SessionCookie err = json.Unmarshal(cookieBytes, &sessionCookie) - c.Assert(err, IsNil) + require.NoError(t, err) - var tests = []struct { - inComment CommentInterface + tests := []struct { + name string inCreateRequest *CreateAppSessionRequest - outError bool + outError require.ErrorAssertionFunc outFQDN string outUsername string }{ { - inComment: Commentf("Valid request: all fields."), + name: "Valid request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Valid request: without FQDN."), + name: "Valid request: without FQDN", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Valid request: only FQDN."), + name: "Valid request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Invalid request: only public address."), + name: "Invalid request: only public address", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid request: only cluster name."), + name: "Invalid request: only cluster name", inCreateRequest: &CreateAppSessionRequest{ ClusterName: "localhost", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid application."), + name: "Invalid application", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "invalid.example.com", ClusterName: "localhost", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Invalid cluster name."), + name: "Invalid cluster name", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "example.com", }, - outError: true, + outError: require.Error, }, { - inComment: Commentf("Malicious request: all fields."), + name: "Malicious request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, - outError: false, + outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { - inComment: Commentf("Malicious request: only FQDN."), + name: "Malicious request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", }, - outError: true, + outError: require.Error, }, } for _, tt := range tests { - // Make a request to create an application session for "panel". - endpoint := pack.clt.Endpoint("webapi", "sessions", "app") - resp, err := pack.clt.PostJSON(context.Background(), endpoint, tt.inCreateRequest) - c.Assert(err != nil, Equals, tt.outError, tt.inComment) - if tt.outError { - continue - } - - // Unmarshal the response. - var response *CreateAppSessionResponse - c.Assert(json.Unmarshal(resp.Bytes(), &response), IsNil, tt.inComment) - c.Assert(response.FQDN, Equals, tt.outFQDN, tt.inComment) + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Make a request to create an application session for "panel". + endpoint := pack.clt.Endpoint("webapi", "sessions", "app") + resp, err := pack.clt.PostJSON(s.ctx, endpoint, tt.inCreateRequest) + tt.outError(t, err) + if err != nil { + return + } - // Verify that the application session was created. - session, err := s.server.Auth().GetAppSession(context.Background(), types.GetAppSessionRequest{ - SessionID: response.CookieValue, + // Unmarshal the response. + var response *CreateAppSessionResponse + require.NoError(t, json.Unmarshal(resp.Bytes(), &response)) + require.Equal(t, tt.outFQDN, response.FQDN) + + // Verify that the application session was created. + sess, err := s.server.Auth().GetAppSession(s.ctx, types.GetAppSessionRequest{ + SessionID: response.CookieValue, + }) + require.NoError(t, err) + require.Equal(t, tt.outUsername, sess.GetUser()) + require.Equal(t, response.CookieValue, sess.GetName()) }) - c.Assert(err, IsNil) - c.Assert(session.GetUser(), Equals, tt.outUsername, tt.inComment) - c.Assert(session.GetName(), Equals, response.CookieValue, tt.inComment) } } @@ -2290,7 +2353,7 @@ func TestNewSessionResponseWithRenewSession(t *testing.T) { duration := time.Duration(5) * time.Minute cfg := types.DefaultClusterNetworkingConfig() cfg.SetWebIdleTimeout(duration) - env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg) + require.NoError(t, env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg)) proxy := env.proxies[0] pack := proxy.authPack(t, "foo") @@ -2328,7 +2391,7 @@ func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) { env.clock.Advance(auth.BearerTokenTTL - delta) // Renew the session using the 1st proxy - resp := pack1.renewSession(context.TODO(), t) + resp := pack1.renewSession(context.Background(), t) // Expire the old session and make sure it has been removed. // The bearer token is also removed after this point, so we have to @@ -2337,7 +2400,7 @@ func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) { pack2 = proxy2.authPackFromResponse(t, resp) // Verify that access via the 2nd proxy also works for the same session - pack2.validateAPI(context.TODO(), t) + pack2.validateAPI(context.Background(), t) // Check whether the terminal session is still active validateTerminalStream(t, ws) @@ -2366,12 +2429,12 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // prevSessionCookie := *pack.cookies[0] prevBearerToken := pack.session.Token - resp := pack.renewSession(context.TODO(), t) + resp := pack.renewSession(context.Background(), t) newPack := proxy.authPackFromResponse(t, resp) // new session is functioning - newPack.validateAPI(context.TODO(), t) + newPack.validateAPI(context.Background(), t) sessionCookie := *newPack.cookies[0] bearerToken := newPack.session.Token @@ -2394,7 +2457,7 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // now expire the old session and make sure it has been removed env.clock.Advance(delta) - _, err = proxy.client.GetWebSession(context.TODO(), types.GetWebSessionRequest{ + _, err = proxy.client.GetWebSession(context.Background(), types.GetWebSessionRequest{ User: "foo", SessionID: prevSessionID, }) @@ -2496,7 +2559,7 @@ func waitForOutput(stream *terminalStream, substr string) error { } func (s *WebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(context.Background(), timeout) + timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) defer timeoutCancel() done := make(chan error, 1) @@ -2535,7 +2598,7 @@ func (s *WebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) er } func (s *WebSuite) waitForResizeEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(context.Background(), timeout) + timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) defer timeoutCancel() done := make(chan error, 1) @@ -2811,6 +2874,15 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) + proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + }) + require.NoError(t, err) + t.Cleanup(proxyNodeWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -2824,6 +2896,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula DirectClusters: []reversetunnel.DirectCluster{{Name: authServer.ClusterName(), Client: client}}, DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, + NodeWatcher: proxyNodeWatcher, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) @@ -2845,6 +2918,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(clock), regular.SetLockWatcher(proxyLockWatcher), + regular.SetNodeWatcher(proxyNodeWatcher), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, proxyServer.Close()) }) @@ -2937,7 +3011,7 @@ func (r *proxy) authPack(t *testing.T, user string) *authPack { err = r.auth.Auth().SetAuthPreference(ctx, ap) require.NoError(t, err) - r.createUser(context.TODO(), t, user, loginUser, pass, otpSecret) + r.createUser(context.Background(), t, user, loginUser, pass, otpSecret) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, r.clock.Now()) @@ -3081,7 +3155,7 @@ func (r *proxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) ws, err := websocket.DialConfig(wscfg) require.NoError(t, err) - t.Cleanup(func() { ws.Close() }) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) return ws } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index fbbdb779f632d..3045a59608cc5 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -26,6 +26,9 @@ import ( "sync" "time" + "github.com/gogo/protobuf/proto" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "golang.org/x/net/websocket" "golang.org/x/text/encoding" @@ -43,11 +46,6 @@ import ( "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - - "github.com/gravitational/trace" - - "github.com/gogo/protobuf/proto" - "github.com/sirupsen/logrus" ) // TerminalRequest describes a request to create a web-based terminal @@ -141,7 +139,7 @@ type TerminalHandler struct { // params is the initial PTY size. params TerminalRequest - // ctx is a web session context for the currently logged in user. + // ctx is a web session context for the currently logged zin user. ctx *SessionContext // hostName is the hostname of the server. @@ -156,7 +154,7 @@ type TerminalHandler struct { // sshSession holds the "shell" SSH channel to the node. sshSession *ssh.Session - // terminalContext is used to signal when the terminal sesson is closing. + // terminalContext is used to signal when the terminal session is closing. terminalContext context.Context // terminalCancel is used to signal when the terminal session is closing. diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 4338de5c0931d..0617bcfedf9d9 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -19,6 +19,7 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -260,6 +261,14 @@ func (c *CLIConf) Stderr() io.Writer { return os.Stderr } +type exitCodeError struct { + code int +} + +func (e *exitCodeError) Error() string { + return fmt.Sprintf("exit code %d", e.code) +} + func main() { cmdLineOrig := os.Args[1:] var cmdLine []string @@ -275,6 +284,10 @@ func main() { cmdLine = cmdLineOrig } if err := Run(cmdLine); err != nil { + var exitError *exitCodeError + if errors.As(err, &exitError) { + os.Exit(exitError.code) + } utils.FatalError(err) } } @@ -1111,7 +1124,7 @@ func onLogout(cf *CLIConf) error { if err != nil { if trace.IsNotFound(err) { fmt.Printf("User %v already logged out from %v.\n", cf.Username, proxyHost) - os.Exit(1) + return trace.Wrap(&exitCodeError{code: 1}) } return trace.Wrap(err) } @@ -1589,15 +1602,14 @@ func onSSH(cf *CLIConf) error { fmt.Fprintf(os.Stderr, "Hint: try addressing the node by unique id (ex: tsh ssh user@node-id)\n") fmt.Fprintf(os.Stderr, "Hint: use 'tsh ls -v' to list all nodes with their unique ids\n") fmt.Fprintf(os.Stderr, "\n") - os.Exit(1) + return trace.Wrap(&exitCodeError{code: 1}) } // exit with the same exit status as the failed command: if tc.ExitStatus != 0 { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(tc.ExitStatus) - } else { - return trace.Wrap(err) + return trace.Wrap(&exitCodeError{code: tc.ExitStatus}) } + return trace.Wrap(err) } return nil } @@ -1616,7 +1628,7 @@ func onBenchmark(cf *CLIConf) error { result, err := cnf.Benchmark(cf.Context, tc) if err != nil { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(255) + return trace.Wrap(&exitCodeError{code: 255}) } fmt.Printf("\n") fmt.Printf("* Requests originated: %v\n", result.RequestsOriginated) @@ -1684,7 +1696,7 @@ func onSCP(cf *CLIConf) error { // exit with the same exit status as the failed command: if tc.ExitStatus != 0 { fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(tc.ExitStatus) + return trace.Wrap(&exitCodeError{code: tc.ExitStatus}) } return trace.Wrap(err) }