diff --git a/config.go b/config.go index a877601..503f27e 100644 --- a/config.go +++ b/config.go @@ -7,7 +7,6 @@ import ( "encoding/json" "fmt" "os" - "os/user" "path/filepath" "strings" "time" @@ -20,6 +19,7 @@ const ( testDirectory = "hcptest" fileName = "hvd_proxy_config.json" directoryPermissions = 0o755 + defaultProxyURL = "https://hcp-proxy.addr:8200" envVarCacheTestMode = "HCP_CACHE_TEST_MODE" ) @@ -31,15 +31,15 @@ type HCPToken struct { } type HCPTokenHelper interface { - GetHCPToken() (*HCPToken, error) + GetHCPToken(string) (*HCPToken, error) } var _ HCPTokenHelper = (*InternalHCPTokenHelper)(nil) type InternalHCPTokenHelper struct{} -func (h InternalHCPTokenHelper) GetHCPToken() (*HCPToken, error) { - configCache, err := readConfig() +func (h InternalHCPTokenHelper) GetHCPToken(path string) (*HCPToken, error) { + configCache, err := readConfig(path) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (h InternalHCPTokenHelper) GetHCPToken() (*HCPToken, error) { tk, err := hcp.Token() if err != nil { if strings.Contains(err.Error(), "no valid credential source available") { - _ = eraseConfig() + _ = eraseConfig(path) return nil, nil } @@ -84,21 +84,22 @@ type TestingHCPTokenHelper struct { ValidCache bool } -func (h TestingHCPTokenHelper) GetHCPToken() (*HCPToken, error) { - userHome := getHomeFolder() - credentialDir := filepath.Join(userHome, testDirectory) - err := os.RemoveAll(credentialDir) - if err != nil { +func (h TestingHCPTokenHelper) GetHCPToken(path string) (*HCPToken, error) { + if path == "" { + return nil, fmt.Errorf("HCP token path may not be an empty string") + } + + credentialDir := filepath.Join(path, testDirectory) + if err := os.RemoveAll(credentialDir); err != nil { return nil, err } if h.ValidCache { - err = writeConfig("https://hcp-proxy.addr:8200", "", "") - if err != nil { + if err := writeConfig(defaultProxyURL, "", "", path); err != nil { return nil, err } - configCache, err := readConfig() + configCache, err := readConfig(path) if err != nil { return nil, err } @@ -126,8 +127,8 @@ type HCPConfigCache struct { } // Write saves HCP auth data in a common location in the home directory. -func writeConfig(addr string, clientID string, secretID string) error { - credentialPath, credentialDirectory, err := getConfigPaths() +func writeConfig(addr, clientID, secretID, path string) error { + credentialPath, credentialDirectory, err := getConfigPaths(path) if err != nil { return fmt.Errorf("failed to retrieve credential path and directory: %v", err) } @@ -156,8 +157,8 @@ func writeConfig(addr string, clientID string, secretID string) error { } // readConfig opens the saved HCP auth data and returns the token. -func readConfig() (*HCPConfigCache, error) { - configPath, _, err := getConfigPaths() +func readConfig(path string) (*HCPConfigCache, error) { + configPath, _, err := getConfigPaths(path) if err != nil { return nil, fmt.Errorf("failed to retrieve config path and directory: %v", err) } @@ -179,8 +180,8 @@ func readConfig() (*HCPConfigCache, error) { return &cache, nil } -func eraseConfig() error { - _, credentialDirectory, err := getConfigPaths() +func eraseConfig(path string) error { + _, credentialDirectory, err := getConfigPaths(path) if err != nil { return fmt.Errorf("failed to retrieve credential path and directory: %v", err) } @@ -194,10 +195,10 @@ func eraseConfig() error { } // getCredentialPaths returns the complete credential path and directory. -func getConfigPaths() (configPath string, configDirectory string, err error) { - // Get the user's home directory. - userHome := getHomeFolder() - +func getConfigPaths(path string) (configPath string, configDirectory string, err error) { + if path == "" { + return "", "", fmt.Errorf("path may not be empty") + } directoryName := defaultDirectory // If in test mode, use test directory. if testMode, ok := os.LookupEnv(envVarCacheTestMode); ok { @@ -207,21 +208,8 @@ func getConfigPaths() (configPath string, configDirectory string, err error) { } // Determine absolute path to config file and directory. - configDirectory = filepath.Join(userHome, directoryName) - configPath = filepath.Join(userHome, directoryName, fileName) + configDirectory = filepath.Join(path, directoryName) + configPath = filepath.Join(path, directoryName, fileName) return configPath, configDirectory, nil } - -func getHomeFolder() string { - current, e := user.Current() - if e != nil { - // Give up and try to return something sensible - home, err := os.UserHomeDir() - if err != nil { - return "" - } - return home - } - return current.HomeDir -} diff --git a/config_test.go b/config_test.go index 21f6f53..150c49e 100644 --- a/config_test.go +++ b/config_test.go @@ -12,30 +12,42 @@ import ( func Test_GetHCPConfiguration(t *testing.T) { cases := map[string]struct { - Valid bool + Valid bool + Path string + ExpectedErr bool }{ "valid hcp configuration": { - Valid: true, + Valid: true, + Path: os.TempDir(), + ExpectedErr: false, }, "empty hcp configuration": { - Valid: false, + Valid: false, + Path: os.TempDir(), + ExpectedErr: false, + }, + "empty path configuration": { + Valid: false, + Path: "", + ExpectedErr: true, }, } for n, tst := range cases { t.Run(n, func(t *testing.T) { tkHelper := &TestingHCPTokenHelper{ValidCache: tst.Valid} - tk, err := tkHelper.GetHCPToken() - - assert.NoError(t, err) + tk, err := tkHelper.GetHCPToken(tst.Path) - if tst.Valid { - assert.Equal(t, "https://hcp-proxy.addr:8200", tk.ProxyAddr) - assert.Contains(t, tk.AccessToken, "Test.Access.Token") - assert.NotEmpty(t, tk.AccessTokenExpiry) - } else { - assert.Nil(t, tk) - assert.Nil(t, err) + if !tst.ExpectedErr { + assert.NoError(t, err) + if tst.Valid { + assert.Equal(t, "https://hcp-proxy.addr:8200", tk.ProxyAddr) + assert.Contains(t, tk.AccessToken, "Test.Access.Token") + assert.NotEmpty(t, tk.AccessTokenExpiry) + } else { + assert.Nil(t, tk) + assert.Nil(t, err) + } } }) } @@ -45,10 +57,10 @@ func Test_GetHCPConfiguration_EraseConfig(t *testing.T) { err := os.Setenv(envVarCacheTestMode, "true") assert.NoError(t, err) - err = eraseConfig() + err = eraseConfig(os.TempDir()) assert.NoError(t, err) tkHelper := &TestingHCPTokenHelper{} - _, err = tkHelper.GetHCPToken() + _, err = tkHelper.GetHCPToken(os.TempDir()) assert.NoError(t, err) } diff --git a/connect.go b/connect.go index 90ad8a4..adb4ded 100644 --- a/connect.go +++ b/connect.go @@ -19,6 +19,7 @@ import ( hcpvsm "github.com/hashicorp/hcp-sdk-go/clients/cloud-vault-service/stable/2020-11-25/models" "github.com/hashicorp/hcp-sdk-go/config" "github.com/hashicorp/hcp-sdk-go/httpclient" + "github.com/mitchellh/go-homedir" ) var ( @@ -89,7 +90,13 @@ func (c *HCPConnectCommand) Run(args []string) int { return 1 } - err = writeConfig(proxyAddr, c.flagClientID, c.flagSecretID) + path, err := homedir.Dir() + if err != nil { + c.Ui.Error(fmt.Sprintf("\nFailed to find home directory: %s", err)) + return 1 + } + + err = writeConfig(proxyAddr, c.flagClientID, c.flagSecretID, path) if err != nil { c.Ui.Error(fmt.Sprintf("\nFailed to connect to HCP Vault Cluster: %s", err)) return 1 diff --git a/disconnect.go b/disconnect.go index eeea12d..8c583e4 100644 --- a/disconnect.go +++ b/disconnect.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/hashicorp/hcp-sdk-go/config" + "github.com/mitchellh/go-homedir" "github.com/hashicorp/cli" ) @@ -30,8 +31,13 @@ Usage: vault hcp disconnect [options] } func (c *HCPDisconnectCommand) Run(_ []string) int { - err := eraseConfig() + path, err := homedir.Dir() if err != nil { + c.Ui.Error(fmt.Sprintf("\nFailed to find home directory: %s", err)) + return 1 + } + + if err := eraseConfig(path); err != nil { c.Ui.Error(fmt.Sprintf("Failed to disconnect from HCP Vault Cluster: %s", err)) return 1 } diff --git a/go.mod b/go.mod index b832f3e..e0e4d40 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/mitchellh/copystructure v1.0.0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.0 // indirect github.com/oklog/ulid v1.3.1 // indirect diff --git a/go.sum b/go.sum index edc6195..6eaddf3 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2Em github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=