diff --git a/integration/integration_test.go b/integration/integration_test.go
index 45227bdcd8c80..56dbbe8595a9e 100644
--- a/integration/integration_test.go
+++ b/integration/integration_test.go
@@ -1905,6 +1905,8 @@ func testMapRoles(t *testing.T, suite *integrationTestSuite) {
 // tryCreateTrustedCluster performs several attempts to create a trusted cluster,
 // retries on connection problems and access denied errors to let caches
 // propagate and services to start
+//
+// Duplicated in tool/tsh/tsh_test.go
 func tryCreateTrustedCluster(t *testing.T, authServer *auth.Server, trustedCluster types.TrustedCluster) {
 	ctx := context.TODO()
 	for i := 0; i < 10; i++ {
diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go
index e6fc731c8af5c..ef5fbb66f9856 100644
--- a/tool/tsh/tsh.go
+++ b/tool/tsh/tsh.go
@@ -1236,13 +1236,36 @@ func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error {
 		req.SetSuggestedReviewers(reviewers)
 	}
 
-	// Watch for resolution events on the given request. Start watcher before
-	// creating the request to avoid a potential race.
+	// Watch for resolution events on the given request. Start watcher and wait
+	// for it to be ready before creating the request to avoid a potential race.
 	errChan := make(chan error)
 	if !cf.NoWait {
+		log.Debug("Waiting for the access-request watcher to ready up...")
+		ready := make(chan struct{})
 		go func() {
-			errChan <- waitForRequestResolution(cf, tc, req)
+			var resolvedReq types.AccessRequest
+			err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error {
+				var err error
+				resolvedReq, err = waitForRequestResolution(cf, clt, req, ready)
+				return trace.Wrap(err)
+			})
+
+			if err != nil {
+				errChan <- trace.Wrap(err)
+			} else {
+				errChan <- trace.Wrap(onRequestResolution(cf, tc, resolvedReq))
+			}
 		}()
+
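+		// Block until the watcher has either readied up or failed; creating
+		// the request before the watcher is ready could miss the resolution
+		// event.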
+		select {
+		case err = <-errChan:
+			if err == nil {
+				return trace.Errorf("event watcher exited cleanly without readying up?")
+			}
+			return trace.Wrap(err)
+		case <-ready:
+			log.Debug("Access-request watcher is ready")
+		}
 	}
 
 	// Create request if it doesn't already exist
@@ -2161,26 +2184,24 @@ func host(in string) string {
 	return out
 }
 
-// waitForRequestResolution waits for an access request to be resolved.
-func waitForRequestResolution(cf *CLIConf, tc *client.TeleportClient, req types.AccessRequest) error {
+// waitForRequestResolution waits for an access request to be resolved. On
+// approval, returns the updated request. `clt` must be a client to the root
+// cluster, such as the one returned by
+// `(*TeleportClient).WithRootClusterClient`. `ready` will be closed when the
+// event watcher used to wait for the request updates is ready.
+func waitForRequestResolution(cf *CLIConf, clt auth.ClientI, req types.AccessRequest, ready chan<- struct{}) (types.AccessRequest, error) {
 	filter := types.AccessRequestFilter{
 		User: req.GetUser(),
 	}
-	var err error
-	var watcher types.Watcher
-	err = tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error {
-		watcher, err = tc.NewWatcher(cf.Context, types.Watch{
-			Name: "await-request-approval",
-			Kinds: []types.WatchKind{{
-				Kind:   types.KindAccessRequest,
-				Filter: filter.IntoMap(),
-			}},
-		})
-		return trace.Wrap(err)
+	watcher, err := clt.NewWatcher(cf.Context, types.Watch{
+		Name: "await-request-approval",
+		Kinds: []types.WatchKind{{
+			Kind:   types.KindAccessRequest,
+			Filter: filter.IntoMap(),
+		}},
 	})
 	if err != nil {
-		return trace.Wrap(err)
+		return nil, trace.Wrap(err)
 	}
 	defer watcher.Close()
 
Loop:
@@ -2190,28 +2211,29 @@
 			switch event.Type {
 			case types.OpInit:
 				log.Infof("Access-request watcher initialized...")
+				close(ready)
 				continue Loop
 			case types.OpPut:
 				r, ok := event.Resource.(*types.AccessRequestV3)
 				if !ok {
-					return trace.BadParameter("unexpected resource type %T", event.Resource)
+					return nil, trace.BadParameter("unexpected resource type %T", event.Resource)
 				}
 				if r.GetName() != req.GetName() || r.GetState().IsPending() {
 					log.Debugf("Skipping put event id=%s,state=%s.", r.GetName(), r.GetState())
 					continue Loop
 				}
-				return onRequestResolution(cf, tc, r)
+				return r, nil
 			case types.OpDelete:
 				if event.Resource.GetName() != req.GetName() {
 					log.Debugf("Skipping delete event id=%s", event.Resource.GetName())
 					continue Loop
 				}
-				return trace.Errorf("request %s has expired or been deleted...", event.Resource.GetName())
+				return nil, trace.Errorf("request %s has expired or been deleted...", event.Resource.GetName())
 			default:
 				log.Warnf("Skipping unknown event type %s", event.Type)
 			}
 		case <-watcher.Done():
-			return trace.Wrap(watcher.Error())
+			return nil, trace.Wrap(watcher.Error())
 		}
 	}
 }
diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go
index b2bd390ebce2e..98316be0b6b84 100644
--- a/tool/tsh/tsh_test.go
+++ b/tool/tsh/tsh_test.go
@@ -35,6 +35,7 @@ import (
 	"github.com/gravitational/teleport/api/constants"
 	apidefaults "github.com/gravitational/teleport/api/defaults"
 	"github.com/gravitational/teleport/api/types"
+	"github.com/gravitational/teleport/lib"
 	"github.com/gravitational/teleport/lib/auth"
 	"github.com/gravitational/teleport/lib/backend"
 	"github.com/gravitational/teleport/lib/client"
@@ -42,6 +43,7 @@ import (
 	"github.com/gravitational/teleport/lib/kube/kubeconfig"
 	"github.com/gravitational/teleport/lib/modules"
 	"github.com/gravitational/teleport/lib/service"
+	"github.com/gravitational/teleport/lib/services"
 	"github.com/gravitational/teleport/lib/srv"
 	"github.com/gravitational/teleport/lib/tlsca"
 	"github.com/gravitational/teleport/lib/utils"
@@ -432,6 +434,169 @@ func TestMakeClient(t *testing.T) {
 	require.Greater(t, len(agentKeys), 0)
 }
 
+func TestAccessRequestOnLeaf(t *testing.T) {
+	tmpHomePath := t.TempDir()
+
+	isInsecure := lib.IsInsecureDevMode()
+	lib.SetInsecureDevMode(true)
+	t.Cleanup(func() {
+		lib.SetInsecureDevMode(isInsecure)
+	})
+
+	requester, err := types.NewRole("requester", types.RoleSpecV4{
+		Allow: types.RoleConditions{
+			Request: &types.AccessRequestConditions{
+				Roles: []string{"access"},
+			},
+		},
+	})
+	require.NoError(t, err)
+
+	connector := mockConnector(t)
+
+	alice, err := types.NewUser("alice@example.com")
+	require.NoError(t, err)
+	alice.SetRoles([]string{"requester"})
+
+	rootAuth, rootProxy := makeTestServers(t,
+		withBootstrap(requester, connector, alice),
+	)
+
+	rootAuthServer := rootAuth.GetAuthServer()
+	require.NotNil(t, rootAuthServer)
+	rootProxyAddr, err := rootProxy.ProxyWebAddr()
+	require.NoError(t, err)
+	rootTunnelAddr, err := rootProxy.ProxyTunnelAddr()
+	require.NoError(t, err)
+
+	trustedCluster, err := types.NewTrustedCluster("localhost", types.TrustedClusterSpecV2{
+		Enabled:              true,
+		Roles:                []string{},
+		Token:                staticToken,
+		ProxyAddress:         rootProxyAddr.String(),
+		ReverseTunnelAddress: rootTunnelAddr.String(),
+		RoleMap: []types.RoleMapping{
+			{
+				Remote: "access",
+				Local:  []string{"access"},
+			},
+		},
+	})
+	require.NoError(t, err)
+
+	leafAuth, _ := makeTestServers(t, withClusterName(t, "leafcluster"))
+	tryCreateTrustedCluster(t, leafAuth.GetAuthServer(), trustedCluster)
+
+	err = Run([]string{
+		"login",
+		"--insecure",
+		"--debug",
+		"--auth", connector.GetName(),
+		"--proxy", rootProxyAddr.String(),
+	}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
+		cf.mockSSOLogin = mockSSOLogin(t, rootAuthServer, alice)
+		return nil
+	}))
+	require.NoError(t, err)
+
+	err = Run([]string{
+		"login",
+		"--insecure",
+		"--debug",
+		"--proxy", rootProxyAddr.String(),
+		"leafcluster",
+	}, setHomePath(tmpHomePath))
+	require.NoError(t, err)
+
+	err = Run([]string{
+		"login",
+		"--insecure",
+		"--debug",
+		"--proxy", rootProxyAddr.String(),
+		"localhost",
+	}, setHomePath(tmpHomePath))
+	require.NoError(t, err)
+
+	err = Run([]string{
+		"login",
+		"--insecure",
+		"--debug",
+		"--proxy", rootProxyAddr.String(),
+		"leafcluster",
+	}, setHomePath(tmpHomePath))
+	require.NoError(t, err)
+
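+	// Issue the access request in the background: the active profile now
+	// points at the leaf cluster, and Run blocks until the request is
+	// resolved.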
+	errChan := make(chan error)
+	go func() {
+		errChan <- Run([]string{
+			"request",
+			"new",
+			"--insecure",
+			"--debug",
+			"--proxy", rootProxyAddr.String(),
+			"--roles=access",
+		}, setHomePath(tmpHomePath))
+	}()
+
+	var request types.AccessRequest
+	for i := 0; i < 5; i++ {
+		log.Debugf("Waiting for access request %d", i)
+		requests, err := rootAuth.GetAuthServer().GetAccessRequests(rootAuth.ExitContext(), types.AccessRequestFilter{})
+		require.NoError(t, err)
+		require.LessOrEqual(t, len(requests), 1)
+		if len(requests) == 1 {
+			request = requests[0]
+			break
+		}
+		time.Sleep(1 * time.Second)
+	}
+	require.NotNil(t, request)
+
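+	// Approve the pending request directly on the root auth server; the
+	// backgrounded Run above should observe the approval and return.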
unexpected problem", "%v.", err) + } + require.FailNow(t, "Timeout creating trusted cluster") +} + func TestIdentityRead(t *testing.T) { // 3 different types of identities ids := []string{ @@ -942,6 +1107,17 @@ func withAuthConfig(fn func(cfg *service.AuthConfig)) testServerOptFunc { } } +func withClusterName(t *testing.T, n string) testServerOptFunc { + return withAuthConfig(func(cfg *service.AuthConfig) { + clusterName, err := services.NewClusterNameWithRandomID( + types.ClusterNameSpecV2{ + ClusterName: n, + }) + require.NoError(t, err) + cfg.ClusterName = clusterName + }) +} + func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.TeleportProcess, proxy *service.TeleportProcess) { var options testServersOpts for _, opt := range opts { @@ -962,7 +1138,7 @@ func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.Tel cfg.Auth.StorageConfig.Params = backend.Params{defaults.BackendPath: filepath.Join(cfg.DataDir, defaults.BackendDir)} cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ StaticTokens: []types.ProvisionTokenV1{{ - Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase}, + Roles: []types.SystemRole{types.RoleProxy, types.RoleDatabase, types.RoleTrustedCluster}, Expires: time.Now().Add(time.Minute), Token: staticToken, }},