diff --git a/cmd/setup-gh.go b/cmd/setup-gh.go index 3c5efeea..7dfbf2d9 100644 --- a/cmd/setup-gh.go +++ b/cmd/setup-gh.go @@ -1,19 +1,19 @@ package cmd import ( + "context" "errors" "fmt" - "strings" - "github.com/manifoldco/promptui" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "strings" "github.com/Azure/draft/pkg/providers" "github.com/Azure/draft/pkg/spinner" ) -func newSetUpCmd() *cobra.Command { +func newSetUpCmd(ctx context.Context) *cobra.Command { sc := &providers.SetUpCmd{} // setup-ghCmd represents the setup-gh command @@ -27,7 +27,7 @@ application and service principle, and will configure that application to trust s := spinner.CreateSpinner("--> Setting up Github OIDC...") s.Start() - err := runProviderSetUp(sc, s) + err := runProviderSetUp(ctx, sc, s) s.Stop() if err != nil { return err @@ -72,11 +72,11 @@ func fillSetUpConfig(sc *providers.SetUpCmd) { } } -func runProviderSetUp(sc *providers.SetUpCmd, s spinner.Spinner) error { +func runProviderSetUp(ctx context.Context, sc *providers.SetUpCmd, s spinner.Spinner) error { provider := strings.ToLower(sc.Provider) if provider == "azure" { // call azure provider logic - return providers.InitiateAzureOIDCFlow(sc, s) + return providers.InitiateAzureOIDCFlow(ctx, sc, s) } else { // call logic for user-submitted provider @@ -203,6 +203,6 @@ func GetAzSubscriptionId(subIds []string) string { } func init() { - rootCmd.AddCommand(newSetUpCmd()) - + ctx := context.Background() + rootCmd.AddCommand(newSetUpCmd(ctx)) } diff --git a/pkg/providers/azure.go b/pkg/providers/azure.go index 5b5c8a75..5c294173 100644 --- a/pkg/providers/azure.go +++ b/pkg/providers/azure.go @@ -26,10 +26,9 @@ type SetUpCmd struct { tenantId string appObjectId string spObjectId string - ctx context.Context } -func InitiateAzureOIDCFlow(sc *SetUpCmd, s spinner.Spinner) error { +func InitiateAzureOIDCFlow(ctx context.Context, sc *SetUpCmd, s spinner.Spinner) error { log.Debug("Commencing github connection with azure...") if !HasGhCli() || !IsLoggedInToGh() { @@ -54,7 +53,7 @@ func InitiateAzureOIDCFlow(sc *SetUpCmd, s spinner.Spinner) error { return err } - if err := sc.getTenantId(); err != nil { + if err := sc.getTenantId(ctx); err != nil { return err } @@ -180,10 +179,10 @@ func (sc *SetUpCmd) assignSpRole() error { return nil } -func (sc *SetUpCmd) getTenantId() error { +func (sc *SetUpCmd) getTenantId(ctx context.Context) error { log.Debug("getting Azure tenant ID") - tenants, err := ListTenants(sc.ctx) + tenants, err := ListTenants(ctx) if err != nil { return fmt.Errorf("listing tenants: %w", err) } @@ -224,7 +223,6 @@ func ListTenants(ctx context.Context) ([]armsubscription.TenantIDDescription, er if t == nil { return nil, errors.New("nil tenant") // this should never happen but it's good to check just in case } - tenants = append(tenants, *t) } }