Skip to content

Commit

Permalink
Force re-auth if cached credentials are invalid (#34)
Browse files Browse the repository at this point in the history
* force re-auth if cached creds are invalid

* add IAM service mocks

* force re-auth if cached token is invalid

* improve error message if cluster ID does not exist

* change errors.New to fmt.Errorf

* check cmd output in connect test
  • Loading branch information
ccapurso authored Jan 26, 2024
1 parent cfd994b commit 473e9a4
Show file tree
Hide file tree
Showing 4 changed files with 693 additions and 123 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ mocks:
go install github.com/vektra/mockery/[email protected]
mockery --srcpkg github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/organization_service --name=ClientService
mockery --srcpkg github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/project_service --name=ClientService
mockery --srcpkg github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/client/iam_service --name=ClientService
150 changes: 96 additions & 54 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,24 @@ import (
"errors"
"flag"
"fmt"
"net/http"
"strings"

hcprmm "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/models"
hcpvsm "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/models"

"github.com/hashicorp/cli"
hcpis "github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/client/iam_service"
hcprmo "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/organization_service"
hcprmp "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/project_service"
hcprmm "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/models"
hcpvs "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/client/vault_service"
hcpvsm "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/models"
"github.com/hashicorp/hcp-sdk-go/config"
"github.com/hashicorp/hcp-sdk-go/httpclient"
)

var (
_ cli.Command = (*HCPConnectCommand)(nil)

ErrorProxyDisabled = errors.New("proxy is disabled")
ErrorProxyDisabled = fmt.Errorf("proxy is disabled")
)

type HCPConnectCommand struct {
Expand All @@ -39,6 +40,7 @@ type HCPConnectCommand struct {
rmOrgClient hcprmo.ClientService
vsClient hcpvs.ClientService
rmProjClient hcprmp.ClientService
iamClient hcpis.ClientService
}

func (c *HCPConnectCommand) Help() string {
Expand Down Expand Up @@ -71,12 +73,13 @@ func (c *HCPConnectCommand) Run(args []string) int {
return 1
}

if err := c.setupClients(); err != nil {
err := c.setupClients()
if err != nil {
c.Ui.Error(err.Error())
return 1
}

proxyAddr, err := c.getProxyAddr(c.rmOrgClient, c.rmProjClient, c.vsClient)
proxyAddr, err := c.getProxyAddr()
if err != nil {
if errors.Is(err, ErrorProxyDisabled) {
c.Ui.Error("\nFailed to connect to HCP Vault Cluster: HTTP proxy feature not enabled.")
Expand All @@ -97,62 +100,101 @@ func (c *HCPConnectCommand) Run(args []string) int {
}

func (c *HCPConnectCommand) setupClients() error {
var opts []config.HCPConfigOption

if c.rmOrgClient == nil && c.rmProjClient == nil && c.vsClient == nil {
opts = []config.HCPConfigOption{config.FromEnv()}

if c.flagClientID != "" && c.flagSecretID == "" {
return errors.New("secret-id is required when client-id is provided")
} else if c.flagSecretID != "" && c.flagClientID == "" {
return errors.New("client-id is required when secret-id is provided")
} else if c.flagClientID != "" && c.flagSecretID != "" {
opts = append(opts, config.WithClientCredentials(c.flagClientID, c.flagSecretID))
opts = append(opts, config.WithoutBrowserLogin())
opts := []config.HCPConfigOption{config.FromEnv()}

if c.flagClientID != "" && c.flagSecretID == "" {
return fmt.Errorf("secret-id is required when client-id is provided")
} else if c.flagSecretID != "" && c.flagClientID == "" {
return fmt.Errorf("client-id is required when secret-id is provided")
} else if c.flagClientID != "" && c.flagSecretID != "" {
opts = append(opts, config.WithClientCredentials(c.flagClientID, c.flagSecretID))
opts = append(opts, config.WithoutBrowserLogin())
}

cfg, err := config.NewHCPConfig(opts...)
if err != nil {
return fmt.Errorf("failed to connect to HCP: %w", err)
}

hcpHttpClient, err := httpclient.New(httpclient.Config{HCPConfig: cfg})
if err != nil {
return fmt.Errorf("failed to connect to HCP: %w", err)
}

// client should only be pre-populated for testing
if c.iamClient == nil {
c.iamClient = hcpis.New(hcpHttpClient, nil)
}

// verify token is valid
resp, err := c.iamClient.IamServiceGetCallerIdentity(hcpis.NewIamServiceGetCallerIdentityParams().WithDefaults(), nil)
if err != nil {
if identErr, ok := err.(*hcpis.IamServiceGetCallerIdentityDefault); ok && !identErr.IsCode(http.StatusUnauthorized) {
return fmt.Errorf("failed to get HCP caller identity: %w", err)
}
}

cfg, err := config.NewHCPConfig(opts...)
// force re-auth in case where cached token is invalid
if resp == nil || resp.Payload == nil || resp.Payload.Principal == nil {
err = cfg.Logout()
if err != nil {
return fmt.Errorf("failed to erase HCP credentials cache: %w", err)
}
cfg, err = config.NewHCPConfig(opts...)
if err != nil {
return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
return fmt.Errorf("failed to connect to HCP: %w", err)
}

hcpHttpClient, err := httpclient.New(httpclient.Config{HCPConfig: cfg})
hcpHttpClient, err = httpclient.New(httpclient.Config{HCPConfig: cfg})
if err != nil {
return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
return fmt.Errorf("failed to connect to HCP: %w", err)
}
}

// clients should only be pre-populated for testing
if c.iamClient == nil {
c.iamClient = hcpis.New(hcpHttpClient, nil)
}

if c.rmOrgClient == nil {
c.rmOrgClient = hcprmo.New(hcpHttpClient, nil)
c.rmProjClient = hcprmp.New(hcpHttpClient, nil)
}

if c.vsClient == nil {
c.vsClient = hcpvs.New(hcpHttpClient, nil)
}

if c.rmProjClient == nil {
c.rmProjClient = hcprmp.New(hcpHttpClient, nil)
}

return nil
}

func (c *HCPConnectCommand) getProxyAddr(organizationClient hcprmo.ClientService, projectClient hcprmp.ClientService, clusterClient hcpvs.ClientService) (string, error) {
func (c *HCPConnectCommand) getProxyAddr() (string, error) {
var err error

var organizationID string
if c.flagOrganizationID != "" {
organizationID = c.flagOrganizationID
} else {
organizationID, err = c.getOrganization(organizationClient)
organizationID, err = c.getOrganization()
if err != nil {
return "", errors.New(fmt.Sprintf("Failed to get HCP organization information: %s", err))
return "", fmt.Errorf("failed to get HCP organization information: %w", err)
}
}

var projectID string
if c.flagProjectID != "" {
projectID = c.flagProjectID
} else {
projectID, err = c.getProject(organizationID, projectClient)
projectID, err = c.getProject(organizationID)
if err != nil {
return "", errors.New(fmt.Sprintf("Failed to get HCP project information: %s", err))
return "", fmt.Errorf("failed to get HCP project information: %w", err)
}
}

proxyAddr, err := c.getCluster(organizationID, projectID, c.flagClusterID, clusterClient)
proxyAddr, err := c.getCluster(organizationID, projectID, c.flagClusterID)
if err != nil {
return "", err
}
Expand All @@ -175,15 +217,15 @@ func (c *HCPConnectCommand) Flags() *flag.FlagSet {
return mainSet
}

func (c *HCPConnectCommand) getOrganization(rmOrgClient hcprmo.ClientService) (organizationID string, err error) {
organizationsResp, err := rmOrgClient.OrganizationServiceList(hcprmo.NewOrganizationServiceListParams().WithDefaults(), nil)
func (c *HCPConnectCommand) getOrganization() (organizationID string, err error) {
organizationsResp, err := c.rmOrgClient.OrganizationServiceList(hcprmo.NewOrganizationServiceListParams().WithDefaults(), nil)
switch {
case err != nil:
return "", err
case organizationsResp.GetPayload() == nil:
return "", errors.New("payload is nil")
return "", fmt.Errorf("payload is nil")
case len(organizationsResp.GetPayload().Organizations) < 1:
return "", errors.New("no organizations available")
return "", fmt.Errorf("no organizations available")
case len(organizationsResp.GetPayload().Organizations) > 1:
title := "Available organizations:"
u := strings.Repeat("-", len(title))
Expand All @@ -203,33 +245,33 @@ func (c *HCPConnectCommand) getOrganization(rmOrgClient hcprmo.ClientService) (o
}
chosenOrg, ok := orgs[userInput]
if !ok {
return "", errors.New(fmt.Sprintf("invalid HCP organization: %s", userInput))
return "", fmt.Errorf("invalid HCP organization: %s", userInput)
}
return chosenOrg.ID, nil
default:
organization := organizationsResp.GetPayload().Organizations[0]
if *organization.State != hcprmm.HashicorpCloudResourcemanagerOrganizationOrganizationStateACTIVE {
return "", errors.New("organization is not active")
return "", fmt.Errorf("organization is not active")
}
return organization.ID, nil
}
}

func (c *HCPConnectCommand) getProject(organizationID string, rmProjClient hcprmp.ClientService) (projectID string, err error) {
func (c *HCPConnectCommand) getProject(organizationID string) (projectID string, err error) {
scopeType := "ORGANIZATION"
projectListReq := hcprmp.
NewProjectServiceListParams().
WithDefaults().
WithScopeType(&scopeType).
WithScopeID(&organizationID)
projectResp, err := rmProjClient.ProjectServiceList(projectListReq, nil)
projectResp, err := c.rmProjClient.ProjectServiceList(projectListReq, nil)
switch {
case err != nil:
return "", err
case projectResp.GetPayload() == nil:
return "", errors.New("payload is nil")
return "", fmt.Errorf("payload is nil")
case len(projectResp.GetPayload().Projects) < 1:
return "", errors.New("no projects available")
return "", fmt.Errorf("no projects available")
case len(projectResp.GetPayload().Projects) > 1:
title := "Available projects:"
u := strings.Repeat("-", len(title))
Expand All @@ -249,34 +291,34 @@ func (c *HCPConnectCommand) getProject(organizationID string, rmProjClient hcprm
}
chosenProj, ok := projs[userInput]
if !ok {
return "", errors.New(fmt.Sprintf("invalid HCP project: %s", userInput))
return "", fmt.Errorf("invalid HCP project: %s", userInput)
}
return chosenProj.ID, nil
default:
project := projectResp.GetPayload().Projects[0]
if *project.State != hcprmm.HashicorpCloudResourcemanagerProjectProjectStateACTIVE {
return "", errors.New("project is not active")
return "", fmt.Errorf("project is not active")
}
return project.ID, nil
}
}

func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, clusterID string, vsClient hcpvs.ClientService) (proxyAddr string, err error) {
func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, clusterID string) (proxyAddr string, err error) {
if clusterID == "" {
return c.listClusters(organizationID, projectID, vsClient)
return c.listClusters(organizationID, projectID)
}

clusterGetReq := hcpvs.NewGetParams().
WithDefaults().
WithLocationOrganizationID(organizationID).
WithLocationProjectID(projectID).
WithClusterID(clusterID)
clusterResp, err := vsClient.Get(clusterGetReq, nil)
clusterResp, err := c.vsClient.Get(clusterGetReq, nil)
switch {
case err != nil:
return "", err
return "", fmt.Errorf("failed to get cluster %s: %s", clusterID, err)
case clusterResp.GetPayload() == nil:
return "", errors.New("payload is nil")
return "", fmt.Errorf("payload is nil")
default:
cluster := clusterResp.GetPayload().Cluster

Expand All @@ -293,21 +335,21 @@ func (c *HCPConnectCommand) getCluster(organizationID string, projectID string,
}
}

func (c *HCPConnectCommand) listClusters(organizationID string, projectID string, vsClient hcpvs.ClientService) (proxyAddr string, err error) {
func (c *HCPConnectCommand) listClusters(organizationID string, projectID string) (proxyAddr string, err error) {
clusterListReq := hcpvs.NewListParams().
WithDefaults().
WithLocationOrganizationID(organizationID).
WithLocationProjectID(projectID)

// Purposely calling List instead of ListAll because we are only interested in HVD clusters.
clustersResp, err := vsClient.List(clusterListReq, nil)
clustersResp, err := c.vsClient.List(clusterListReq, nil)
switch {
case err != nil:
return "", err
case clustersResp.GetPayload() == nil:
return "", errors.New("payload is nil")
return "", fmt.Errorf("payload is nil")
case len(clustersResp.GetPayload().Clusters) < 1:
return "", errors.New("no clusters available")
return "", fmt.Errorf("no clusters available")
case len(clustersResp.GetPayload().Clusters) > 1:
title := "Available clusters:"
u := strings.Repeat("-", len(title))
Expand All @@ -330,7 +372,7 @@ func (c *HCPConnectCommand) listClusters(organizationID string, projectID string
// set the cluster
cluster, ok := clusters[userInput]
if !ok {
return "", errors.New(fmt.Sprintf("invalid cluster: %s", userInput))
return "", fmt.Errorf("invalid cluster: %s", userInput)
}
if *cluster.Config.NetworkConfig.HTTPProxyOption == hcpvsm.HashicorpCloudVault20201125HTTPProxyOptionDISABLED {
return "", ErrorProxyDisabled
Expand All @@ -345,11 +387,11 @@ func (c *HCPConnectCommand) listClusters(organizationID string, projectID string
clusterState := *cluster.State

if clusterState == hcpvsm.HashicorpCloudVault20201125ClusterStateLOCKED || clusterState == hcpvsm.HashicorpCloudVault20201125ClusterStateLOCKING {
return "", errors.New("cluster is locked")
return "", fmt.Errorf("cluster is locked")
} else if clusterState == hcpvsm.HashicorpCloudVault20201125ClusterStateCREATING {
return "", errors.New("cluster is still being created")
return "", fmt.Errorf("cluster is still being created")
} else if clusterState != hcpvsm.HashicorpCloudVault20201125ClusterStateRUNNING {
return "", errors.New("cluster is not running")
return "", fmt.Errorf("cluster is not running")
}

if *cluster.Config.NetworkConfig.HTTPProxyOption == hcpvsm.HashicorpCloudVault20201125HTTPProxyOptionDISABLED {
Expand Down
Loading

0 comments on commit 473e9a4

Please sign in to comment.