Skip to content

Commit

Permalink
force re-auth if cached token is invalid
Browse files Browse the repository at this point in the history
  • Loading branch information
ccapurso committed Jan 23, 2024
1 parent cd63071 commit 98eb903
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 112 deletions.
122 changes: 69 additions & 53 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ import (
"errors"
"flag"
"fmt"
"net/http"
"strings"

hcprmm "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/models"
hcpvsm "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/models"
"github.com/hashicorp/cli"
hcpis "github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/client/iam_service"
hcprmo "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/organization_service"
hcprmp "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/project_service"
hcprmm "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/models"
hcpvs "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/client/vault_service"
hcpis "github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/client/iam_service"
hcpvsm "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/models"
"github.com/hashicorp/hcp-sdk-go/config"
"github.com/hashicorp/hcp-sdk-go/httpclient"
)
Expand Down Expand Up @@ -72,36 +73,13 @@ func (c *HCPConnectCommand) Run(args []string) int {
return 1
}

hcpCfg, err := c.setupClients(nil)
err := c.setupClients()
if err != nil {
c.Ui.Error(err.Error())
return 1
}


resp, err := c.iamClient.IamServiceGetCallerIdentity(hcpis.NewIamServiceGetCallerIdentityParams().WithDefaults(), nil)

// force re-auth in case where cached token is invalid
if err != nil || resp == nil || resp.Payload == nil || resp.Payload.Principal == nil {
err = hcpCfg.Logout()
if err != nil {
c.Ui.Error(fmt.Sprintf("Failed to erase HCP credentials cache: %s", err))
return 1
}

// attempt to re-establish credentials now that cache has been cleared
c.rmOrgClient = nil
c.vsClient = nil
c.rmProjClient = nil
c.iamClient = nil

if _, err := c.setupClients(hcpCfg); err != nil {
c.Ui.Error(err.Error())
return 1
}
}

proxyAddr, err := c.getProxyAddr(c.rmOrgClient, c.rmProjClient, c.vsClient)
proxyAddr, err := c.getProxyAddr()
if err != nil {
if errors.Is(err, ErrorProxyDisabled) {
c.Ui.Error("\nFailed to connect to HCP Vault Cluster: HTTP proxy feature not enabled.")
Expand All @@ -121,11 +99,7 @@ func (c *HCPConnectCommand) Run(args []string) int {
return 0
}

func (c *HCPConnectCommand) setupClients(cfg config.HCPConfig) (config.HCPConfig, error) {
if c.rmOrgClient != nil && c.rmProjClient != nil && c.vsClient != nil {
return cfg, nil
}

func (c *HCPConnectCommand) setupClients() error {
opts := []config.HCPConfigOption{config.FromEnv()}
if c.flagClientID != "" && c.flagSecretID != "" {
opts = append(opts, config.WithClientCredentials(c.flagClientID, c.flagSecretID))
Expand All @@ -134,30 +108,72 @@ func (c *HCPConnectCommand) setupClients(cfg config.HCPConfig) (config.HCPConfig

cfg, err := config.NewHCPConfig(opts...)
if err != nil {
return cfg, errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
}

hcpHttpClient, err := httpclient.New(httpclient.Config{HCPConfig: cfg})
if err != nil {
return cfg, errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
}

// client should only be pre-populated for testing
if c.iamClient == nil {
c.iamClient = hcpis.New(hcpHttpClient, nil)
}

c.rmOrgClient = hcprmo.New(hcpHttpClient, nil)
c.rmProjClient = hcprmp.New(hcpHttpClient, nil)
c.vsClient = hcpvs.New(hcpHttpClient, nil)
c.iamClient = hcpis.New(hcpHttpClient, nil)
// verify token is valid
resp, err := c.iamClient.IamServiceGetCallerIdentity(hcpis.NewIamServiceGetCallerIdentityParams().WithDefaults(), nil)
if err != nil {
if identErr, ok := err.(*hcpis.IamServiceGetCallerIdentityDefault); ok && !identErr.IsCode(http.StatusUnauthorized) {
return errors.New(fmt.Sprintf("Failed to get HCP caller identity: %s", err))
}
}

// force re-auth in case where cached token is invalid
if resp == nil || resp.Payload == nil || resp.Payload.Principal == nil {
err = cfg.Logout()
if err != nil {
return errors.New(fmt.Sprintf("Failed to erase HCP credentials cache: %s", err))
}
cfg, err = config.NewHCPConfig(opts...)
if err != nil {
return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
}

hcpHttpClient, err = httpclient.New(httpclient.Config{HCPConfig: cfg})
if err != nil {
return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err))
}
}

// clients should only be pre-populated for testing
if c.iamClient == nil {
c.iamClient = hcpis.New(hcpHttpClient, nil)
}

if c.rmOrgClient == nil {
c.rmOrgClient = hcprmo.New(hcpHttpClient, nil)
}

if c.vsClient == nil {
c.vsClient = hcpvs.New(hcpHttpClient, nil)
}

if c.rmProjClient == nil {
c.rmProjClient = hcprmp.New(hcpHttpClient, nil)
}

return cfg, nil
return nil
}

func (c *HCPConnectCommand) getProxyAddr(organizationClient hcprmo.ClientService, projectClient hcprmp.ClientService, clusterClient hcpvs.ClientService) (string, error) {
func (c *HCPConnectCommand) getProxyAddr() (string, error) {
var err error

var organizationID string
if c.flagOrganizationID != "" {
organizationID = c.flagOrganizationID
} else {
organizationID, err = c.getOrganization(organizationClient)
organizationID, err = c.getOrganization()
if err != nil {
return "", errors.New(fmt.Sprintf("Failed to get HCP organization information: %s", err))
}
Expand All @@ -167,13 +183,13 @@ func (c *HCPConnectCommand) getProxyAddr(organizationClient hcprmo.ClientService
if c.flagProjectID != "" {
projectID = c.flagProjectID
} else {
projectID, err = c.getProject(organizationID, projectClient)
projectID, err = c.getProject(organizationID)
if err != nil {
return "", errors.New(fmt.Sprintf("Failed to get HCP project information: %s", err))
}
}

proxyAddr, err := c.getCluster(organizationID, projectID, c.flagClusterID, clusterClient)
proxyAddr, err := c.getCluster(organizationID, projectID, c.flagClusterID)
if err != nil {
return "", err
}
Expand All @@ -196,8 +212,8 @@ func (c *HCPConnectCommand) Flags() *flag.FlagSet {
return mainSet
}

func (c *HCPConnectCommand) getOrganization(rmOrgClient hcprmo.ClientService) (organizationID string, err error) {
organizationsResp, err := rmOrgClient.OrganizationServiceList(hcprmo.NewOrganizationServiceListParams().WithDefaults(), nil)
func (c *HCPConnectCommand) getOrganization() (organizationID string, err error) {
organizationsResp, err := c.rmOrgClient.OrganizationServiceList(hcprmo.NewOrganizationServiceListParams().WithDefaults(), nil)
switch {
case err != nil:
return "", err
Expand Down Expand Up @@ -236,14 +252,14 @@ func (c *HCPConnectCommand) getOrganization(rmOrgClient hcprmo.ClientService) (o
}
}

func (c *HCPConnectCommand) getProject(organizationID string, rmProjClient hcprmp.ClientService) (projectID string, err error) {
func (c *HCPConnectCommand) getProject(organizationID string) (projectID string, err error) {
scopeType := "ORGANIZATION"
projectListReq := hcprmp.
NewProjectServiceListParams().
WithDefaults().
WithScopeType(&scopeType).
WithScopeID(&organizationID)
projectResp, err := rmProjClient.ProjectServiceList(projectListReq, nil)
projectResp, err := c.rmProjClient.ProjectServiceList(projectListReq, nil)
switch {
case err != nil:
return "", err
Expand Down Expand Up @@ -282,17 +298,17 @@ func (c *HCPConnectCommand) getProject(organizationID string, rmProjClient hcprm
}
}

func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, clusterID string, vsClient hcpvs.ClientService) (proxyAddr string, err error) {
func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, clusterID string) (proxyAddr string, err error) {
if clusterID == "" {
return c.listClusters(organizationID, projectID, vsClient)
return c.listClusters(organizationID, projectID)
}

clusterGetReq := hcpvs.NewGetParams().
WithDefaults().
WithLocationOrganizationID(organizationID).
WithLocationProjectID(projectID).
WithClusterID(clusterID)
clusterResp, err := vsClient.Get(clusterGetReq, nil)
clusterResp, err := c.vsClient.Get(clusterGetReq, nil)
switch {
case err != nil:
return "", err
Expand All @@ -314,14 +330,14 @@ func (c *HCPConnectCommand) getCluster(organizationID string, projectID string,
}
}

func (c *HCPConnectCommand) listClusters(organizationID string, projectID string, vsClient hcpvs.ClientService) (proxyAddr string, err error) {
func (c *HCPConnectCommand) listClusters(organizationID string, projectID string) (proxyAddr string, err error) {
clusterListReq := hcpvs.NewListParams().
WithDefaults().
WithLocationOrganizationID(organizationID).
WithLocationProjectID(projectID)

// Purposely calling List instead of ListAll because we are only interested in HVD clusters.
clustersResp, err := vsClient.List(clusterListReq, nil)
clustersResp, err := c.vsClient.List(clusterListReq, nil)
switch {
case err != nil:
return "", err
Expand Down
Loading

0 comments on commit 98eb903

Please sign in to comment.