Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tctl InitFunc to return an authclient.ClientI #51093

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions lib/auth/authclient/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package authclient

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
Expand All @@ -41,8 +42,10 @@ import (
"github.com/gravitational/teleport/api/client/usertask"
apidefaults "github.com/gravitational/teleport/api/defaults"
accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1"
autoupdatev1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1"
clusterconfigpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1"
dbobjectimportrulev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobjectimportrule/v1"
decisionv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/decision/v1alpha1"
devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1"
integrationv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1"
Expand All @@ -56,6 +59,7 @@ import (
trustpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/trust/v1"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
"github.com/gravitational/teleport/api/gen/proto/go/teleport/vnet/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
userpreferencesv1 "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -1826,6 +1830,9 @@ type ClientI interface {
// when calling this method, but all RPCs will return "not implemented" errors
// (as per the default gRPC behavior).
WorkloadIdentityServiceClient() machineidv1pb.WorkloadIdentityServiceClient
SPIFFEFederationServiceClient() machineidv1pb.SPIFFEFederationServiceClient
WorkloadIdentityResourceServiceClient() workloadidentityv1pb.WorkloadIdentityResourceServiceClient
WorkloadIdentityIssuanceClient() workloadidentityv1pb.WorkloadIdentityIssuanceServiceClient

// NotificationServiceClient returns a notification service client.
// Clients connecting to older Teleport versions, still get a client
Expand Down Expand Up @@ -1903,4 +1910,29 @@ type ClientI interface {

// GitServerReadOnlyClient returns the read-only client for Git servers.
GitServerReadOnlyClient() gitserver.ReadOnlyClient

DecisionClient() decisionv1.DecisionServiceClient

SetMFAPromptConstructor(pc mfa.PromptConstructor)

CreateAutoUpdateConfig(ctx context.Context, config *autoupdatev1pb.AutoUpdateConfig) (*autoupdatev1pb.AutoUpdateConfig, error)
UpdateAutoUpdateConfig(ctx context.Context, config *autoupdatev1pb.AutoUpdateConfig) (*autoupdatev1pb.AutoUpdateConfig, error)
UpsertAutoUpdateConfig(ctx context.Context, config *autoupdatev1pb.AutoUpdateConfig) (*autoupdatev1pb.AutoUpdateConfig, error)
DeleteAutoUpdateConfig(ctx context.Context) error

CreateAutoUpdateVersion(ctx context.Context, config *autoupdatev1pb.AutoUpdateVersion) (*autoupdatev1pb.AutoUpdateVersion, error)
UpdateAutoUpdateVersion(ctx context.Context, config *autoupdatev1pb.AutoUpdateVersion) (*autoupdatev1pb.AutoUpdateVersion, error)
UpsertAutoUpdateVersion(ctx context.Context, config *autoupdatev1pb.AutoUpdateVersion) (*autoupdatev1pb.AutoUpdateVersion, error)
DeleteAutoUpdateVersion(ctx context.Context) error

CreateAutoUpdateAgentRollout(ctx context.Context, config *autoupdatev1pb.AutoUpdateAgentRollout) (*autoupdatev1pb.AutoUpdateAgentRollout, error)
UpdateAutoUpdateAgentRollout(ctx context.Context, config *autoupdatev1pb.AutoUpdateAgentRollout) (*autoupdatev1pb.AutoUpdateAgentRollout, error)
UpsertAutoUpdateAgentRollout(ctx context.Context, config *autoupdatev1pb.AutoUpdateAgentRollout) (*autoupdatev1pb.AutoUpdateAgentRollout, error)
DeleteAutoUpdateAgentRollout(cxt context.Context) error

GetDesktopBootstrapScript(ctx context.Context) (string, error)

CrownJewelsClient() services.CrownJewels
UserTasksClient() services.UserTasks
Config() *tls.Config
}
22 changes: 11 additions & 11 deletions tool/tctl/common/access_request_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (c *AccessRequestCommand) Initialize(app *kingpin.Application, _ *tctlcfg.G

// TryRun takes the CLI command as an argument (like "access-request list") and executes it.
func (c *AccessRequestCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.requestList.FullCommand():
commandFunc = c.List
Expand Down Expand Up @@ -160,7 +160,7 @@ func (c *AccessRequestCommand) TryRun(ctx context.Context, cmd string, clientFun
return true, trace.Wrap(err)
}

func (c *AccessRequestCommand) List(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) List(ctx context.Context, client authclient.ClientI) error {
var index proto.AccessRequestSort
switch c.sortIndex {
case "created":
Expand Down Expand Up @@ -203,7 +203,7 @@ func (c *AccessRequestCommand) List(ctx context.Context, client *authclient.Clie
return nil
}

func (c *AccessRequestCommand) Get(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Get(ctx context.Context, client authclient.ClientI) error {
reqs := []types.AccessRequest{}
for _, reqID := range strings.Split(c.reqIDs, ",") {
req, err := client.GetAccessRequests(ctx, types.AccessRequestFilter{
Expand Down Expand Up @@ -258,7 +258,7 @@ func (c *AccessRequestCommand) splitRoles() []string {
return roles
}

func (c *AccessRequestCommand) Approve(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Approve(ctx context.Context, client authclient.ClientI) error {
if c.delegator != "" {
ctx = authz.WithDelegator(ctx, c.delegator)
}
Expand Down Expand Up @@ -289,7 +289,7 @@ func (c *AccessRequestCommand) Approve(ctx context.Context, client *authclient.C
return nil
}

func (c *AccessRequestCommand) Deny(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Deny(ctx context.Context, client authclient.ClientI) error {
if c.delegator != "" {
ctx = authz.WithDelegator(ctx, c.delegator)
}
Expand All @@ -310,7 +310,7 @@ func (c *AccessRequestCommand) Deny(ctx context.Context, client *authclient.Clie
return nil
}

func (c *AccessRequestCommand) Create(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Create(ctx context.Context, client authclient.ClientI) error {
if len(c.roles) == 0 && len(c.requestedResourceIDs) == 0 {
c.roles = "*"
}
Expand All @@ -326,10 +326,10 @@ func (c *AccessRequestCommand) Create(ctx context.Context, client *authclient.Cl

if c.dryRun {
users := &struct {
*authclient.Client
authclient.ClientI
services.UserLoginStatesGetter
}{
Client: client,
ClientI: client,
UserLoginStatesGetter: client.UserLoginStateClient(),
}
err = services.ValidateAccessRequestForUser(ctx, clockwork.NewRealClock(), users, req, tlsca.Identity{}, services.ExpandVars(true))
Expand All @@ -346,7 +346,7 @@ func (c *AccessRequestCommand) Create(ctx context.Context, client *authclient.Cl
return nil
}

func (c *AccessRequestCommand) Delete(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Delete(ctx context.Context, client authclient.ClientI) error {
var approvedTokens []string
for _, reqID := range strings.Split(c.reqIDs, ",") {
// Fetch the requests first to see if they were approved to provide the
Expand Down Expand Up @@ -386,7 +386,7 @@ func (c *AccessRequestCommand) Delete(ctx context.Context, client *authclient.Cl
return nil
}

func (c *AccessRequestCommand) Caps(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Caps(ctx context.Context, client authclient.ClientI) error {
caps, err := client.GetAccessCapabilities(ctx, types.AccessCapabilitiesRequest{
User: c.user,
RequestableRoles: true,
Expand Down Expand Up @@ -422,7 +422,7 @@ func (c *AccessRequestCommand) Caps(ctx context.Context, client *authclient.Clie
}
}

func (c *AccessRequestCommand) Review(ctx context.Context, client *authclient.Client) error {
func (c *AccessRequestCommand) Review(ctx context.Context, client authclient.ClientI) error {
if c.approve == c.deny {
return trace.BadParameter("must supply exactly one of '--approve' or '--deny'")
}
Expand Down
22 changes: 11 additions & 11 deletions tool/tctl/common/accessmonitoring/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (c *Command) initAuditReportsCommands(auditCmd *kingpin.CmdClause, cfg *ser
})
}

type runFunc func(context.Context, *authclient.Client) error
type runFunc func(context.Context, authclient.ClientI) error

func (c *Command) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
handler, ok := c.innerCmdMap[cmd]
Expand All @@ -136,7 +136,7 @@ func (c *Command) TryRun(ctx context.Context, cmd string, clientFunc commonclien
}
}

func (c *cmdHandler) onAuditQueryExec(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryExec(ctx context.Context, authClient authclient.ClientI) error {
if c.auditQuery == "" {
buff, err := io.ReadAll(os.Stdin)
if err != nil {
Expand All @@ -154,7 +154,7 @@ func (c *cmdHandler) onAuditQueryExec(ctx context.Context, authClient *authclien
return nil
}

func (c *cmdHandler) onAuditQueryGet(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryGet(ctx context.Context, authClient authclient.ClientI) error {
auditQuery, err := authClient.SecReportsClient().GetSecurityAuditQuery(ctx, c.name)
if err != nil {
return trace.Wrap(err)
Expand All @@ -165,7 +165,7 @@ func (c *cmdHandler) onAuditQueryGet(ctx context.Context, authClient *authclient
return nil
}

func (c *cmdHandler) onAuditQueryLs(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryLs(ctx context.Context, authClient authclient.ClientI) error {
auditQueries, err := authClient.SecReportsClient().GetSecurityAuditQueries(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -176,14 +176,14 @@ func (c *cmdHandler) onAuditQueryLs(ctx context.Context, authClient *authclient.
return nil
}

func (c *cmdHandler) onAuditQueryRm(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryRm(ctx context.Context, authClient authclient.ClientI) error {
if err := authClient.SecReportsClient().DeleteSecurityAuditQuery(ctx, c.name); err != nil {
return trace.Wrap(err)
}
return nil
}

func (c *cmdHandler) onAuditQuerySchema(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQuerySchema(ctx context.Context, authClient authclient.ClientI) error {
resp, err := authClient.SecReportsClient().GetSchema(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -201,7 +201,7 @@ func (c *cmdHandler) onAuditQuerySchema(ctx context.Context, authClient *authcli
return nil
}

func (c *cmdHandler) onAuditQueryCreate(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditQueryCreate(ctx context.Context, authClient authclient.ClientI) error {
if c.auditQuery == "" {
return trace.BadParameter("audit query required")
}
Expand All @@ -221,7 +221,7 @@ func (c *cmdHandler) onAuditQueryCreate(ctx context.Context, authClient *authcli
return nil
}

func (c *cmdHandler) onAuditReportLs(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportLs(ctx context.Context, authClient authclient.ClientI) error {
reports, err := authClient.SecReportsClient().GetSecurityReports(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -232,7 +232,7 @@ func (c *cmdHandler) onAuditReportLs(ctx context.Context, authClient *authclient
return trace.Wrap(err)
}

func (c *cmdHandler) onAuditReportGet(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportGet(ctx context.Context, authClient authclient.ClientI) error {
details, err := authClient.SecReportsClient().GetSecurityReportResult(ctx, c.name, c.days)
if err != nil {
return trace.Wrap(err)
Expand All @@ -243,15 +243,15 @@ func (c *cmdHandler) onAuditReportGet(ctx context.Context, authClient *authclien
return nil
}

func (c *cmdHandler) onAuditReportRun(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportRun(ctx context.Context, authClient authclient.ClientI) error {
err := authClient.SecReportsClient().RunSecurityReport(ctx, c.name, c.days)
if err != nil {
return trace.Wrap(err)
}
return nil
}

func (c *cmdHandler) onAuditReportState(ctx context.Context, authClient *authclient.Client) error {
func (c *cmdHandler) onAuditReportState(ctx context.Context, authClient authclient.ClientI) error {
state, err := authClient.SecReportsClient().GetSecurityReportExecutionState(ctx, c.name, int32(c.days))
if err != nil {
return trace.Wrap(err)
Expand Down
12 changes: 6 additions & 6 deletions tool/tctl/common/acl_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *ACLCommand) Initialize(app *kingpin.Application, _ *tctlcfg.GlobalCLIFl

// TryRun takes the CLI command as an argument (like "acl ls") and executes it.
func (c *ACLCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.ls.FullCommand():
commandFunc = c.List
Expand All @@ -122,7 +122,7 @@ func (c *ACLCommand) TryRun(ctx context.Context, cmd string, clientFunc commoncl
}

// List will list access lists visible to the user.
func (c *ACLCommand) List(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) List(ctx context.Context, client authclient.ClientI) error {
var accessLists []*accesslist.AccessList
var nextKey string
for {
Expand All @@ -149,7 +149,7 @@ func (c *ACLCommand) List(ctx context.Context, client *authclient.Client) error
}

// Get will display information about an access list visible to the user.
func (c *ACLCommand) Get(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) Get(ctx context.Context, client authclient.ClientI) error {
accessList, err := client.AccessListClient().GetAccessList(ctx, c.accessListName)
if err != nil {
return trace.Wrap(err)
Expand All @@ -159,7 +159,7 @@ func (c *ACLCommand) Get(ctx context.Context, client *authclient.Client) error {
}

// UsersAdd will add a user to an access list.
func (c *ACLCommand) UsersAdd(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) UsersAdd(ctx context.Context, client authclient.ClientI) error {
var expires time.Time
if c.expires != "" {
var err error
Expand Down Expand Up @@ -205,7 +205,7 @@ func (c *ACLCommand) UsersAdd(ctx context.Context, client *authclient.Client) er
}

// UsersRemove will remove a user to an access list.
func (c *ACLCommand) UsersRemove(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) UsersRemove(ctx context.Context, client authclient.ClientI) error {
err := client.AccessListClient().DeleteAccessListMember(ctx, c.accessListName, c.userName)
if err != nil {
return trace.Wrap(err)
Expand All @@ -217,7 +217,7 @@ func (c *ACLCommand) UsersRemove(ctx context.Context, client *authclient.Client)
}

// UsersList will list the users in an access list.
func (c *ACLCommand) UsersList(ctx context.Context, client *authclient.Client) error {
func (c *ACLCommand) UsersList(ctx context.Context, client authclient.ClientI) error {
var (
allMembers []*accesslist.AccessListMember
nextToken string
Expand Down
2 changes: 1 addition & 1 deletion tool/tctl/common/admin_action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ func runTestCase(t *testing.T, ctx context.Context, client *authclient.Client, t
commandName, err := app.Parse(args)
require.NoError(t, err)

match, err := tc.cliCommand.TryRun(ctx, commandName, func(context.Context) (*authclient.Client, func(context.Context), error) {
match, err := tc.cliCommand.TryRun(ctx, commandName, func(context.Context) (authclient.ClientI, func(context.Context), error) {
return client, func(context.Context) {}, nil
})
require.True(t, match)
Expand Down
12 changes: 6 additions & 6 deletions tool/tctl/common/alert_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *AlertCommand) Initialize(app *kingpin.Application, _ *tctlcfg.GlobalCLI

// TryRun takes the CLI command as an argument (like "alerts ls") and executes it.
func (c *AlertCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.alertList.FullCommand():
commandFunc = c.List
Expand All @@ -117,7 +117,7 @@ func (c *AlertCommand) TryRun(ctx context.Context, cmd string, clientFunc common
return true, trace.Wrap(err)
}

func (c *AlertCommand) ListAck(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) ListAck(ctx context.Context, client authclient.ClientI) error {
acks, err := client.GetAlertAcks(ctx)
if err != nil {
return trace.Wrap(err)
Expand All @@ -135,7 +135,7 @@ func (c *AlertCommand) ListAck(ctx context.Context, client *authclient.Client) e
return nil
}

func (c *AlertCommand) Ack(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) Ack(ctx context.Context, client authclient.ClientI) error {
if c.clear {
return c.ClearAck(ctx, client)
}
Expand Down Expand Up @@ -164,7 +164,7 @@ func (c *AlertCommand) Ack(ctx context.Context, client *authclient.Client) error
return nil
}

func (c *AlertCommand) ClearAck(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) ClearAck(ctx context.Context, client authclient.ClientI) error {
req := proto.ClearAlertAcksRequest{
AlertID: c.alertID,
}
Expand All @@ -178,7 +178,7 @@ func (c *AlertCommand) ClearAck(ctx context.Context, client *authclient.Client)
return nil
}

func (c *AlertCommand) List(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) List(ctx context.Context, client authclient.ClientI) error {
labels, err := libclient.ParseLabelSpec(c.labels)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -269,7 +269,7 @@ func displayAlertsJSON(alerts []types.ClusterAlert) error {
return nil
}

func (c *AlertCommand) Create(ctx context.Context, client *authclient.Client) error {
func (c *AlertCommand) Create(ctx context.Context, client authclient.ClientI) error {
labels, err := libclient.ParseLabelSpec(c.labels)
if err != nil {
return trace.Wrap(err)
Expand Down
4 changes: 2 additions & 2 deletions tool/tctl/common/app_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (c *AppsCommand) Initialize(app *kingpin.Application, _ *tctlcfg.GlobalCLIF

// TryRun attempts to run subcommands like "apps ls".
func (c *AppsCommand) TryRun(ctx context.Context, cmd string, clientFunc commonclient.InitFunc) (match bool, err error) {
var commandFunc func(ctx context.Context, client *authclient.Client) error
var commandFunc func(ctx context.Context, client authclient.ClientI) error
switch cmd {
case c.appsList.FullCommand():
commandFunc = c.ListApps
Expand All @@ -90,7 +90,7 @@ func (c *AppsCommand) TryRun(ctx context.Context, cmd string, clientFunc commonc

// ListApps prints the list of applications that have recently sent heartbeats
// to the cluster.
func (c *AppsCommand) ListApps(ctx context.Context, clt *authclient.Client) error {
func (c *AppsCommand) ListApps(ctx context.Context, clt authclient.ClientI) error {
labels, err := libclient.ParseLabelSpec(c.labels)
if err != nil {
return trace.Wrap(err)
Expand Down
Loading
Loading