From 98eb90327b772d98d01588ab144104dd79da085e Mon Sep 17 00:00:00 2001 From: Chris Capurso <1036769+ccapurso@users.noreply.github.com> Date: Tue, 23 Jan 2024 15:27:17 -0500 Subject: [PATCH] force re-auth if cached token is invalid --- connect.go | 122 ++++++++++++++++------------- connect_test.go | 199 ++++++++++++++++++++++++++++++++++-------------- 2 files changed, 209 insertions(+), 112 deletions(-) diff --git a/connect.go b/connect.go index b7b29e5..edee475 100644 --- a/connect.go +++ b/connect.go @@ -7,15 +7,16 @@ import ( "errors" "flag" "fmt" + "net/http" "strings" - hcprmm "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/models" - hcpvsm "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/models" "github.com/hashicorp/cli" + hcpis "github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/client/iam_service" hcprmo "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/organization_service" hcprmp "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/project_service" + hcprmm "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/models" hcpvs "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/client/vault_service" - hcpis "github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/client/iam_service" + hcpvsm "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/models" "github.com/hashicorp/hcp-sdk-go/config" "github.com/hashicorp/hcp-sdk-go/httpclient" ) @@ -72,36 +73,13 @@ func (c *HCPConnectCommand) Run(args []string) int { return 1 } - hcpCfg, err := c.setupClients(nil) + err := c.setupClients() if err != nil { c.Ui.Error(err.Error()) return 1 } - - resp, err := c.iamClient.IamServiceGetCallerIdentity(hcpis.NewIamServiceGetCallerIdentityParams().WithDefaults(), nil) - - // force re-auth in case where cached token is invalid - if err != nil || resp == nil || resp.Payload == nil || resp.Payload.Principal == nil { - err = hcpCfg.Logout() - if err != nil { - c.Ui.Error(fmt.Sprintf("Failed to erase HCP credentials cache: %s", err)) - return 1 - } - - // attempt to re-establish credentials now that cache has been cleared - c.rmOrgClient = nil - c.vsClient = nil - c.rmProjClient = nil - c.iamClient = nil - - if _, err := c.setupClients(hcpCfg); err != nil { - c.Ui.Error(err.Error()) - return 1 - } - } - - proxyAddr, err := c.getProxyAddr(c.rmOrgClient, c.rmProjClient, c.vsClient) + proxyAddr, err := c.getProxyAddr() if err != nil { if errors.Is(err, ErrorProxyDisabled) { c.Ui.Error("\nFailed to connect to HCP Vault Cluster: HTTP proxy feature not enabled.") @@ -121,11 +99,7 @@ func (c *HCPConnectCommand) Run(args []string) int { return 0 } -func (c *HCPConnectCommand) setupClients(cfg config.HCPConfig) (config.HCPConfig, error) { - if c.rmOrgClient != nil && c.rmProjClient != nil && c.vsClient != nil { - return cfg, nil - } - +func (c *HCPConnectCommand) setupClients() error { opts := []config.HCPConfigOption{config.FromEnv()} if c.flagClientID != "" && c.flagSecretID != "" { opts = append(opts, config.WithClientCredentials(c.flagClientID, c.flagSecretID)) @@ -134,30 +108,72 @@ func (c *HCPConnectCommand) setupClients(cfg config.HCPConfig) (config.HCPConfig cfg, err := config.NewHCPConfig(opts...) if err != nil { - return cfg, errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err)) + return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err)) } hcpHttpClient, err := httpclient.New(httpclient.Config{HCPConfig: cfg}) if err != nil { - return cfg, errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err)) + return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err)) + } + + // client should only be pre-populated for testing + if c.iamClient == nil { + c.iamClient = hcpis.New(hcpHttpClient, nil) } - c.rmOrgClient = hcprmo.New(hcpHttpClient, nil) - c.rmProjClient = hcprmp.New(hcpHttpClient, nil) - c.vsClient = hcpvs.New(hcpHttpClient, nil) - c.iamClient = hcpis.New(hcpHttpClient, nil) + // verify token is valid + resp, err := c.iamClient.IamServiceGetCallerIdentity(hcpis.NewIamServiceGetCallerIdentityParams().WithDefaults(), nil) + if err != nil { + if identErr, ok := err.(*hcpis.IamServiceGetCallerIdentityDefault); ok && !identErr.IsCode(http.StatusUnauthorized) { + return errors.New(fmt.Sprintf("Failed to get HCP caller identity: %s", err)) + } + } + + // force re-auth in case where cached token is invalid + if resp == nil || resp.Payload == nil || resp.Payload.Principal == nil { + err = cfg.Logout() + if err != nil { + return errors.New(fmt.Sprintf("Failed to erase HCP credentials cache: %s", err)) + } + cfg, err = config.NewHCPConfig(opts...) + if err != nil { + return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err)) + } + + hcpHttpClient, err = httpclient.New(httpclient.Config{HCPConfig: cfg}) + if err != nil { + return errors.New(fmt.Sprintf("Failed to connect to HCP: %s", err)) + } + } + + // clients should only be pre-populated for testing + if c.iamClient == nil { + c.iamClient = hcpis.New(hcpHttpClient, nil) + } + + if c.rmOrgClient == nil { + c.rmOrgClient = hcprmo.New(hcpHttpClient, nil) + } + + if c.vsClient == nil { + c.vsClient = hcpvs.New(hcpHttpClient, nil) + } + + if c.rmProjClient == nil { + c.rmProjClient = hcprmp.New(hcpHttpClient, nil) + } - return cfg, nil + return nil } -func (c *HCPConnectCommand) getProxyAddr(organizationClient hcprmo.ClientService, projectClient hcprmp.ClientService, clusterClient hcpvs.ClientService) (string, error) { +func (c *HCPConnectCommand) getProxyAddr() (string, error) { var err error var organizationID string if c.flagOrganizationID != "" { organizationID = c.flagOrganizationID } else { - organizationID, err = c.getOrganization(organizationClient) + organizationID, err = c.getOrganization() if err != nil { return "", errors.New(fmt.Sprintf("Failed to get HCP organization information: %s", err)) } @@ -167,13 +183,13 @@ func (c *HCPConnectCommand) getProxyAddr(organizationClient hcprmo.ClientService if c.flagProjectID != "" { projectID = c.flagProjectID } else { - projectID, err = c.getProject(organizationID, projectClient) + projectID, err = c.getProject(organizationID) if err != nil { return "", errors.New(fmt.Sprintf("Failed to get HCP project information: %s", err)) } } - proxyAddr, err := c.getCluster(organizationID, projectID, c.flagClusterID, clusterClient) + proxyAddr, err := c.getCluster(organizationID, projectID, c.flagClusterID) if err != nil { return "", err } @@ -196,8 +212,8 @@ func (c *HCPConnectCommand) Flags() *flag.FlagSet { return mainSet } -func (c *HCPConnectCommand) getOrganization(rmOrgClient hcprmo.ClientService) (organizationID string, err error) { - organizationsResp, err := rmOrgClient.OrganizationServiceList(hcprmo.NewOrganizationServiceListParams().WithDefaults(), nil) +func (c *HCPConnectCommand) getOrganization() (organizationID string, err error) { + organizationsResp, err := c.rmOrgClient.OrganizationServiceList(hcprmo.NewOrganizationServiceListParams().WithDefaults(), nil) switch { case err != nil: return "", err @@ -236,14 +252,14 @@ func (c *HCPConnectCommand) getOrganization(rmOrgClient hcprmo.ClientService) (o } } -func (c *HCPConnectCommand) getProject(organizationID string, rmProjClient hcprmp.ClientService) (projectID string, err error) { +func (c *HCPConnectCommand) getProject(organizationID string) (projectID string, err error) { scopeType := "ORGANIZATION" projectListReq := hcprmp. NewProjectServiceListParams(). WithDefaults(). WithScopeType(&scopeType). WithScopeID(&organizationID) - projectResp, err := rmProjClient.ProjectServiceList(projectListReq, nil) + projectResp, err := c.rmProjClient.ProjectServiceList(projectListReq, nil) switch { case err != nil: return "", err @@ -282,9 +298,9 @@ func (c *HCPConnectCommand) getProject(organizationID string, rmProjClient hcprm } } -func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, clusterID string, vsClient hcpvs.ClientService) (proxyAddr string, err error) { +func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, clusterID string) (proxyAddr string, err error) { if clusterID == "" { - return c.listClusters(organizationID, projectID, vsClient) + return c.listClusters(organizationID, projectID) } clusterGetReq := hcpvs.NewGetParams(). @@ -292,7 +308,7 @@ func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, WithLocationOrganizationID(organizationID). WithLocationProjectID(projectID). WithClusterID(clusterID) - clusterResp, err := vsClient.Get(clusterGetReq, nil) + clusterResp, err := c.vsClient.Get(clusterGetReq, nil) switch { case err != nil: return "", err @@ -314,14 +330,14 @@ func (c *HCPConnectCommand) getCluster(organizationID string, projectID string, } } -func (c *HCPConnectCommand) listClusters(organizationID string, projectID string, vsClient hcpvs.ClientService) (proxyAddr string, err error) { +func (c *HCPConnectCommand) listClusters(organizationID string, projectID string) (proxyAddr string, err error) { clusterListReq := hcpvs.NewListParams(). WithDefaults(). WithLocationOrganizationID(organizationID). WithLocationProjectID(projectID) // Purposely calling List instead of ListAll because we are only interested in HVD clusters. - clustersResp, err := vsClient.List(clusterListReq, nil) + clustersResp, err := c.vsClient.List(clusterListReq, nil) switch { case err != nil: return "", err diff --git a/connect_test.go b/connect_test.go index 41501b8..3c44088 100644 --- a/connect_test.go +++ b/connect_test.go @@ -6,15 +6,19 @@ package vaulthcplib import ( "errors" "io" + "net/http" "testing" "github.com/hashicorp/cli" clustermocks "github.com/hashicorp/vault-hcp-lib/mocks/cluster" + iammocks "github.com/hashicorp/vault-hcp-lib/mocks/iam" orgmocks "github.com/hashicorp/vault-hcp-lib/mocks/organization" projmocks "github.com/hashicorp/vault-hcp-lib/mocks/project" "github.com/google/uuid" + hcpis "github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/client/iam_service" + iam_models "github.com/hashicorp/hcp-sdk-go/clients/cloud-iam/stable/2019-12-10/models" hcprmo "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/organization_service" hcprmp "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/client/project_service" "github.com/hashicorp/hcp-sdk-go/clients/cloud-resource-manager/stable/2019-12-10/models" @@ -30,68 +34,136 @@ func testHCPConnectCommand() (*cli.MockUi, *HCPConnectCommand) { } func Test_HCPConnectCommand(t *testing.T) { - _, cmd := testHCPConnectCommand() - - mockRmOrgClient := orgmocks.NewClientService(t) - mockRmOrgClient. - On("OrganizationServiceList", mock.Anything, nil). - Return(&hcprmo.OrganizationServiceListOK{ - Payload: &models.HashicorpCloudResourcemanagerOrganizationListResponse{ - Organizations: []*models.HashicorpCloudResourcemanagerOrganization{ - { - ID: uuid.New().String(), - Name: "mock-organization-1", - State: models.NewHashicorpCloudResourcemanagerOrganizationOrganizationState( - models.HashicorpCloudResourcemanagerOrganizationOrganizationStateACTIVE, - ), + tests := map[string]struct{ + getCallerIdentityResp *hcpis.IamServiceGetCallerIdentityOK + getCallerIdentityErr error + expectedResult int + }{ + "OK resp": { + getCallerIdentityResp: &hcpis.IamServiceGetCallerIdentityOK{ + Payload: &iam_models.HashicorpCloudIamGetCallerIdentityResponse{ + Principal: &iam_models.HashicorpCloudIamPrincipal{ + User: &iam_models.HashicorpCloudIamUserPrincipal{ + Email: "test@test.com", + FullName: "HCP Test", + ID: "test", + Subject: "test", + }, }, }, }, - }, nil) - - mockRmProjClient := projmocks.NewClientService(t) - mockRmProjClient. - On("ProjectServiceList", mock.Anything, nil). - Return(&hcprmp.ProjectServiceListOK{ - Payload: &models.HashicorpCloudResourcemanagerProjectListResponse{ - Projects: []*models.HashicorpCloudResourcemanagerProject{ - { - ID: uuid.New().String(), - Name: "mock-project-1", - State: models.NewHashicorpCloudResourcemanagerProjectProjectState( - models.HashicorpCloudResourcemanagerProjectProjectStateACTIVE, - ), - }, - }, + getCallerIdentityErr: nil, + expectedResult: 0, + }, + "no resp or error": { + getCallerIdentityResp: nil, + getCallerIdentityErr: nil, + expectedResult: 0, + }, + "error - unauthorized": { + getCallerIdentityResp: nil, + getCallerIdentityErr: hcpis.NewIamServiceGetCallerIdentityDefault(http.StatusUnauthorized), + expectedResult: 0, + }, + "error - server error": { + getCallerIdentityResp: nil, + getCallerIdentityErr: hcpis.NewIamServiceGetCallerIdentityDefault(http.StatusInternalServerError), + expectedResult: 1, + }, + "nil payload": { + getCallerIdentityResp: &hcpis.IamServiceGetCallerIdentityOK{ + Payload: nil, }, - }, nil) - - mockVsClient := clustermocks.NewClientService(t) - mockVsClient. - On("Get", mock.Anything, nil). - Return(&hcpvs.GetOK{ - Payload: &hcpvsm.HashicorpCloudVault20201125GetResponse{ - Cluster: &hcpvsm.HashicorpCloudVault20201125Cluster{ - ID: "cluster-1", - DNSNames: &hcpvsm.HashicorpCloudVault20201125ClusterDNSNames{Proxy: "hcp-proxy-cluster-1.addr:8200"}, - State: hcpvsm.NewHashicorpCloudVault20201125ClusterState( - hcpvsm.HashicorpCloudVault20201125ClusterStateRUNNING, - ), - Config: &hcpvsm.HashicorpCloudVault20201125ClusterConfig{ - NetworkConfig: &hcpvsm.HashicorpCloudVault20201125NetworkConfig{ - HTTPProxyOption: hcpvsm.NewHashicorpCloudVault20201125HTTPProxyOption(hcpvsm.HashicorpCloudVault20201125HTTPProxyOptionENABLED), - }, - }, + getCallerIdentityErr: nil, + expectedResult: 0, + }, + "nil principal": { + getCallerIdentityResp: &hcpis.IamServiceGetCallerIdentityOK{ + Payload: &iam_models.HashicorpCloudIamGetCallerIdentityResponse{ + Principal: nil, }, }, - }, nil) + getCallerIdentityErr: nil, + expectedResult: 0, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + _, cmd := testHCPConnectCommand() + + mockIamClient := iammocks.NewClientService(t) + mockIamClient. + On("IamServiceGetCallerIdentity", mock.Anything, nil). + Return(test.getCallerIdentityResp, test.getCallerIdentityErr) + + cmd.iamClient = mockIamClient - cmd.rmOrgClient = mockRmOrgClient - cmd.rmProjClient = mockRmProjClient - cmd.vsClient = mockVsClient + // we will only call these if the caller identity call succeeds + if test.expectedResult == 0 { + mockRmOrgClient := orgmocks.NewClientService(t) + mockRmOrgClient. + On("OrganizationServiceList", mock.Anything, nil). + Return(&hcprmo.OrganizationServiceListOK{ + Payload: &models.HashicorpCloudResourcemanagerOrganizationListResponse{ + Organizations: []*models.HashicorpCloudResourcemanagerOrganization{ + { + ID: uuid.New().String(), + Name: "mock-organization-1", + State: models.NewHashicorpCloudResourcemanagerOrganizationOrganizationState( + models.HashicorpCloudResourcemanagerOrganizationOrganizationStateACTIVE, + ), + }, + }, + }, + }, nil) - result := cmd.Run([]string{"-cluster-id", "cluster-1"}) - assert.Equal(t, 0, result) + mockRmProjClient := projmocks.NewClientService(t) + mockRmProjClient. + On("ProjectServiceList", mock.Anything, nil). + Return(&hcprmp.ProjectServiceListOK{ + Payload: &models.HashicorpCloudResourcemanagerProjectListResponse{ + Projects: []*models.HashicorpCloudResourcemanagerProject{ + { + ID: uuid.New().String(), + Name: "mock-project-1", + State: models.NewHashicorpCloudResourcemanagerProjectProjectState( + models.HashicorpCloudResourcemanagerProjectProjectStateACTIVE, + ), + }, + }, + }, + }, nil) + + mockVsClient := clustermocks.NewClientService(t) + mockVsClient. + On("Get", mock.Anything, nil). + Return(&hcpvs.GetOK{ + Payload: &hcpvsm.HashicorpCloudVault20201125GetResponse{ + Cluster: &hcpvsm.HashicorpCloudVault20201125Cluster{ + ID: "cluster-1", + DNSNames: &hcpvsm.HashicorpCloudVault20201125ClusterDNSNames{Proxy: "hcp-proxy-cluster-1.addr:8200"}, + State: hcpvsm.NewHashicorpCloudVault20201125ClusterState( + hcpvsm.HashicorpCloudVault20201125ClusterStateRUNNING, + ), + Config: &hcpvsm.HashicorpCloudVault20201125ClusterConfig{ + NetworkConfig: &hcpvsm.HashicorpCloudVault20201125NetworkConfig{ + HTTPProxyOption: hcpvsm.NewHashicorpCloudVault20201125HTTPProxyOption(hcpvsm.HashicorpCloudVault20201125HTTPProxyOptionENABLED), + }, + }, + }, + }, + }, nil) + + cmd.rmOrgClient = mockRmOrgClient + cmd.rmProjClient = mockRmProjClient + cmd.vsClient = mockVsClient + } + + result := cmd.Run([]string{"-cluster-id", "cluster-1"}) + assert.Equal(t, test.expectedResult, result) + }) + } } func Test_getOrganization(t *testing.T) { @@ -202,7 +274,9 @@ func Test_getOrganization(t *testing.T) { On("OrganizationServiceList", mock.Anything, nil). Return(tst.organizationServiceListResponse, tst.expectedError) - orgID, err := cmd.getOrganization(mockRmOrgClient) + cmd.rmOrgClient = mockRmOrgClient + + orgID, err := cmd.getOrganization() if tst.expectedError != nil { assert.Error(t, err) assert.EqualError(t, err, tst.expectedError.Error()) @@ -212,7 +286,6 @@ func Test_getOrganization(t *testing.T) { } }) } - } func Test_getProject(t *testing.T) { @@ -324,7 +397,9 @@ func Test_getProject(t *testing.T) { On("ProjectServiceList", mock.Anything, nil). Return(tst.projectServiceListResponse, tst.expectedError) - projID, err := cmd.getProject("", mockRmProjClient) + cmd.rmProjClient = mockRmProjClient + + projID, err := cmd.getProject("") if tst.expectedError != nil { assert.Error(t, tst.expectedError) } else { @@ -506,7 +581,9 @@ func Test_getCluster(t *testing.T) { Return(tst.listClustersServiceListResponse, tst.expectedError) } - proxyAddr, err := cmd.getCluster("", "", tst.userParamCluster, mockVsClient) + cmd.vsClient = mockVsClient + + proxyAddr, err := cmd.getCluster("", "", tst.userParamCluster) if tst.expectedError != nil { assert.Error(t, tst.expectedError) } else { @@ -698,7 +775,11 @@ func Test_getProxyAddr(t *testing.T) { Return(tst.projectServiceListResponse, nil) } - proxyAddr, err := cmd.getProxyAddr(mockRmOrgClient, mockRmProjClient, mockVsClient) + cmd.rmOrgClient = mockRmOrgClient + cmd.rmProjClient = mockRmProjClient + cmd.vsClient = mockVsClient + + proxyAddr, err := cmd.getProxyAddr() if tst.expectedError != nil { assert.Error(t, tst.expectedError) } else {