From 9f97fe1fc219ced56cc0bdb632e250932841f1f6 Mon Sep 17 00:00:00 2001 From: John Eikenberry Date: Mon, 12 Sep 2022 14:42:15 -0700 Subject: [PATCH] vault token management moved into separate watcher Refactor the vault token management (renewing, tokenfile watching, unwrapping, etc) into a dedicated watcher that is only responsible for that. Done to encapsulate the vault management token into one place where it can be more easily understood and tested. It was scattered about and inconsistent (eg. it only tried to unwrap the first token). Strips the vault token code out of current watcher and client_set. Most of this commit is the new test suite for this code as it wasn't really tested before. --- cli.go | 19 ++- dependency/client_set.go | 23 --- dependency/client_set_test.go | 31 ---- dependency/vault_agent_token.go | 10 +- dependency/vault_agent_token_test.go | 13 +- dependency/vault_read_test.go | 4 +- manager/manager_test.go | 14 +- manager/runner.go | 30 +--- manager/runner_test.go | 26 +-- watch/vault_token.go | 156 +++++++++++++++++ watch/vault_token_test.go | 242 +++++++++++++++++++++++++++ watch/watch_test.go | 198 ++++++++++++++++++++++ watch/watcher.go | 26 +-- watch/watcher_test.go | 45 +---- 14 files changed, 664 insertions(+), 173 deletions(-) create mode 100644 watch/vault_token.go create mode 100644 watch/vault_token_test.go create mode 100644 watch/watch_test.go diff --git a/cli.go b/cli.go index de5addc98..a7ce58f17 100644 --- a/cli.go +++ b/cli.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/consul-template/service_os" "github.com/hashicorp/consul-template/signals" "github.com/hashicorp/consul-template/version" + "github.com/hashicorp/consul-template/watch" ) // Exit codes are int values that represent an exit code for a particular error. @@ -103,8 +104,20 @@ func (cli *CLI) Run(args []string) int { return ExitCodeOK } + // Create the clientset + clients, err := manager.NewClientSet(config) + if err != nil { + return logError(err, ExitCodeConfigError) + } + + // vault token watcher + vtwatchErrCh := watch.VaultTokenWatcher(clients, config.Vault) + if err != nil { + return logError(err, ExitCodeRunnerError) + } + // Initial runner - runner, err := manager.NewRunner(config, dry) + runner, err := manager.NewRunner(clients, config, dry) if err != nil { return logError(err, ExitCodeRunnerError) } @@ -115,6 +128,8 @@ func (cli *CLI) Run(args []string) int { for { select { + case err := <-vtwatchErrCh: + return logError(err, ExitCodeRunnerError) case err := <-runner.ErrCh: // Check if the runner's error returned a specific exit status, and return // that value. If no value was given, return a generic exit status. @@ -150,7 +165,7 @@ func (cli *CLI) Run(args []string) int { return logError(err, ExitCodeConfigError) } - runner, err = manager.NewRunner(config, dry) + runner, err = manager.NewRunner(clients, config, dry) if err != nil { return logError(err, ExitCodeRunnerError) } diff --git a/dependency/client_set.go b/dependency/client_set.go index 93ea12fe9..45c3d6713 100644 --- a/dependency/client_set.go +++ b/dependency/client_set.go @@ -347,33 +347,10 @@ func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error { } } - // Set the token if given if i.Token != "" { client.SetToken(i.Token) } - // Check if we are unwrapping - if i.UnwrapToken { - secret, err := client.Logical().Unwrap(i.Token) - if err != nil { - return fmt.Errorf("client set: vault unwrap: %s", err) - } - - if secret == nil { - return fmt.Errorf("client set: vault unwrap: no secret") - } - - if secret.Auth == nil { - return fmt.Errorf("client set: vault unwrap: no secret auth") - } - - if secret.Auth.ClientToken == "" { - return fmt.Errorf("client set: vault unwrap: no token returned") - } - - client.SetToken(secret.Auth.ClientToken) - } - // Save the data on ourselves c.Lock() c.vault = &vaultClient{ diff --git a/dependency/client_set_test.go b/dependency/client_set_test.go index a085867fa..9a68f9d9e 100644 --- a/dependency/client_set_test.go +++ b/dependency/client_set_test.go @@ -14,37 +14,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestClientSet_unwrapVaultToken(t *testing.T) { - // Don't use t.Parallel() here as the SetWrappingLookupFunc is a global - // setting and breaks other tests if run in parallel - - vault := testClients.Vault() - - // Create a wrapped token - vault.SetWrappingLookupFunc(func(operation, path string) string { - return "30s" - }) - defer vault.SetWrappingLookupFunc(nil) - - wrappedToken, err := vault.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", - }) - if err != nil { - t.Fatal(err) - } - - token := vault.Token() - - if token == wrappedToken.WrapInfo.Token { - t.Errorf("expected %q to not be %q", token, - wrappedToken.WrapInfo.Token) - } - - if _, err := vault.Auth().Token().LookupSelf(); err != nil { - t.Fatal(err) - } -} - func TestClientSet_K8SServiceTokenAuth(t *testing.T) { t.Parallel() diff --git a/dependency/vault_agent_token.go b/dependency/vault_agent_token.go index f3103ecd0..f06704678 100644 --- a/dependency/vault_agent_token.go +++ b/dependency/vault_agent_token.go @@ -1,10 +1,8 @@ package dependency import ( - "io/ioutil" "log" "os" - "strings" "time" "github.com/pkg/errors" @@ -39,6 +37,7 @@ func NewVaultAgentTokenQuery(path string) (*VaultAgentTokenQuery, error) { func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interface{}, *ResponseMetadata, error) { log.Printf("[TRACE] %s: READ %s", d, d.path) + var token string select { case <-d.stopCh: log.Printf("[TRACE] %s: stopped", d) @@ -50,16 +49,15 @@ func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (in log.Printf("[TRACE] %s: reported change", d) - token, err := ioutil.ReadFile(d.path) + raw_token, err := os.ReadFile(d.path) if err != nil { return "", nil, errors.Wrap(err, d.String()) } - d.stat = r.stat - clients.Vault().SetToken(strings.TrimSpace(string(token))) + token = string(raw_token) } - return respWithMetadata("") + return respWithMetadata(token) } // CanShare returns if this dependency is sharable. diff --git a/dependency/vault_agent_token_test.go b/dependency/vault_agent_token_test.go index d06cc14e0..9a2099173 100644 --- a/dependency/vault_agent_token_test.go +++ b/dependency/vault_agent_token_test.go @@ -14,11 +14,6 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) { // Don't use t.Parallel() here as the SetToken() calls are global and break // other tests if run in parallel - // reset token back to original - vc := testClients.Vault() - token := vc.Token() - defer vc.SetToken(token) - // Set up the Vault token file. tokenFile, err := ioutil.TempFile("", "token1") if err != nil { @@ -33,22 +28,22 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) { } clientSet := testClients - _, _, err = d.Fetch(clientSet, nil) + token, _, err := d.Fetch(clientSet, nil) if err != nil { t.Fatal(err) } - assert.Equal(t, "token", clientSet.Vault().Token()) + assert.Equal(t, "token", token) // Update the contents. renderer.AtomicWrite( tokenFile.Name(), false, []byte("another_token"), 0o644, false) - _, _, err = d.Fetch(clientSet, nil) + token, _, err = d.Fetch(clientSet, nil) if err != nil { t.Fatal(err) } - assert.Equal(t, "another_token", clientSet.Vault().Token()) + assert.Equal(t, "another_token", token) } func TestVaultAgentTokenQuery_Fetch_missingFile(t *testing.T) { diff --git a/dependency/vault_read_test.go b/dependency/vault_read_test.go index 086085b37..fe1209c1e 100644 --- a/dependency/vault_read_test.go +++ b/dependency/vault_read_test.go @@ -571,8 +571,8 @@ func TestVaultReadQuery_Fetch_NonSecrets(t *testing.T) { vc := clients.Vault() err = vc.Sys().EnableAuth("approle", "approle", "") - if err != nil { - t.Fatal(err) + if err != nil && !strings.Contains(err.Error(), "path is already in use") { + t.Fatalf("(%T) %s\n", err, err) } _, err = vc.Logical().Write("auth/approle/role/my-approle", nil) diff --git a/manager/manager_test.go b/manager/manager_test.go index f2cfb81f9..259a2683d 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -33,12 +33,16 @@ func TestMain(m *testing.M) { } testConsul = consul - clients := dep.NewClientSet() - if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{ - Address: testConsul.HTTPAddr, - }); err != nil { + consulConfig := config.DefaultConsulConfig() + consulConfig.Address = &testConsul.HTTPAddr + clients, err := NewClientSet(&config.Config{ + Consul: consulConfig, + Vault: config.DefaultVaultConfig(), + Nomad: config.DefaultNomadConfig(), + }) + if err != nil { testConsul.Stop() - log.Fatal(err) + log.Fatal(fmt.Errorf("failed to start clients: %v", err)) } testClients = clients diff --git a/manager/runner.go b/manager/runner.go index 7f2f8bd47..0dda335d6 100644 --- a/manager/runner.go +++ b/manager/runner.go @@ -172,7 +172,7 @@ type RenderEvent struct { // NewRunner accepts a slice of TemplateConfigs and returns a pointer to the new // Runner and any error that occurred during creation. -func NewRunner(config *config.Config, dry bool) (*Runner, error) { +func NewRunner(clients *dep.ClientSet, config *config.Config, dry bool) (*Runner, error) { log.Printf("[INFO] (runner) creating new runner (dry: %v, once: %v)", dry, config.Once) @@ -181,7 +181,7 @@ func NewRunner(config *config.Config, dry bool) (*Runner, error) { dry: dry, } - if err := runner.init(); err != nil { + if err := runner.init(clients); err != nil { return nil, err } @@ -885,7 +885,7 @@ func (r *Runner) runTemplate(tmpl *template.Template, runCtx *templateRunCtx) (* // init() creates the Runner's underlying data structures and returns an error // if any problems occur. -func (r *Runner) init() error { +func (r *Runner) init(clients *dep.ClientSet) error { // Ensure default configuration values r.config = config.DefaultConfig().Merge(r.config) r.config.Finalize() @@ -900,18 +900,8 @@ func (r *Runner) init() error { dep.SetVaultDefaultLeaseDuration(config.TimeDurationVal(r.config.Vault.DefaultLeaseDuration)) dep.SetVaultLeaseRenewalThreshold(*r.config.Vault.LeaseRenewalThreshold) - // Create the clientset - clients, err := newClientSet(r.config) - if err != nil { - return fmt.Errorf("runner: %s", err) - } - // Create the watcher - watcher, err := newWatcher(r.config, clients, r.config.Once) - if err != nil { - return fmt.Errorf("runner: %s", err) - } - r.watcher = watcher + r.watcher = newWatcher(r.config, clients, r.config.Once) numTemplates := len(*r.config.Templates) templates := make([]*template.Template, 0, numTemplates) @@ -1291,8 +1281,8 @@ func findCommand(c *config.TemplateConfig, templates []*config.TemplateConfig) * return nil } -// newClientSet creates a new client set from the given config. -func newClientSet(c *config.Config) (*dep.ClientSet, error) { +// NewClientSet creates a new client set from the given config. +func NewClientSet(c *config.Config) (*dep.ClientSet, error) { clients := dep.NewClientSet() if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{ @@ -1378,10 +1368,10 @@ func newClientSet(c *config.Config) (*dep.ClientSet, error) { } // newWatcher creates a new watcher. -func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Watcher, error) { +func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) *watch.Watcher { log.Printf("[INFO] (runner) creating watcher") - w, err := watch.NewWatcher(&watch.NewWatcherInput{ + return watch.NewWatcher(&watch.NewWatcherInput{ Clients: clients, MaxStale: config.TimeDurationVal(c.MaxStale), Once: c.Once, @@ -1396,8 +1386,4 @@ func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Wat VaultToken: clients.Vault().Token(), RetryFuncNomad: watch.RetryFunc(c.Nomad.Retry.RetryFunc()), }) - if err != nil { - return nil, errors.Wrap(err, "runner") - } - return w, nil } diff --git a/manager/runner_test.go b/manager/runner_test.go index 0fb162d29..ef7572862 100644 --- a/manager/runner_test.go +++ b/manager/runner_test.go @@ -31,7 +31,7 @@ func TestRunner_initTemplates(t *testing.T) { }, }) - r, err := NewRunner(c, true) + r, err := NewRunner(testClients, c, true) if err != nil { t.Fatal(err) } @@ -51,7 +51,7 @@ func TestRunner_initTemplates(t *testing.T) { func TestRunner_Receive(t *testing.T) { c := config.TestConfig(&config.Config{Once: true}) - r, err := NewRunner(c, true) + r, err := NewRunner(testClients, c, true) if err != nil { t.Fatal(err) } @@ -476,7 +476,7 @@ func TestRunner_Run(t *testing.T) { c.Once = true c.Finalize() - r, err := NewRunner(c, true) + r, err := NewRunner(testClients, c, true) if err != nil { t.Fatal(err) } @@ -523,7 +523,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -564,7 +564,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -611,7 +611,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -659,7 +659,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -705,7 +705,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -760,7 +760,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -826,7 +826,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -902,7 +902,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } @@ -952,7 +952,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, true) + r, err := NewRunner(testClients, c, true) if err != nil { t.Fatal(err) } @@ -1010,7 +1010,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(c, false) + r, err := NewRunner(testClients, c, false) if err != nil { t.Fatal(err) } diff --git a/watch/vault_token.go b/watch/vault_token.go new file mode 100644 index 000000000..c53573aba --- /dev/null +++ b/watch/vault_token.go @@ -0,0 +1,156 @@ +package watch + +import ( + "encoding/json" + "fmt" + "log" + "strings" + "sync" + + "github.com/hashicorp/consul-template/config" + dep "github.com/hashicorp/consul-template/dependency" + "github.com/hashicorp/vault/api" +) + +// VaultTokenWatcher monitors the vault token for updates +func VaultTokenWatcher(clients *dep.ClientSet, c *config.VaultConfig) chan error { + // c.Vault.Token is populated by the config code from all places + // vault tokens are supported. So if there is no token set here, + // tokens are not being used. + raw_token := strings.TrimSpace(config.StringVal(c.Token)) + if raw_token == "" { + return nil + } + + unwrap := config.BoolVal(c.UnwrapToken) + vault := clients.Vault() + // buffer 1 error to allow for sequential errors to send and return + errChan := make(chan error, 1) + + // get/set token once when kicked off, async after that.. + token, err := getToken(vault, raw_token, unwrap) + if err != nil { + errChan <- err + return errChan + } + vault.SetToken(token) + + var once sync.Once + var watcher *Watcher + getWatcher := func() *Watcher { + once.Do(func() { + watcher = NewWatcher(&NewWatcherInput{ + Clients: clients, + RetryFuncVault: RetryFunc(c.Retry.RetryFunc()), + }) + }) + return watcher + } + + // Vault Agent Token File process // + tokenFile := strings.TrimSpace(config.StringVal(c.VaultAgentTokenFile)) + if tokenFile != "" { + atf, err := dep.NewVaultAgentTokenQuery(tokenFile) + if err != nil { + errChan <- fmt.Errorf("vaultwatcher: %w", err) + return errChan + } + w := getWatcher() + if _, err := w.Add(atf); err != nil { + errChan <- fmt.Errorf("vaultwatcher: %w", err) + return errChan + } + go func() { + for { + raw_token, err = waitforToken(w, raw_token, unwrap) + if err != nil { + errChan <- err + return + } + } + }() + } + + // Vault Token Renewal process // + renewVault := vault.Token() != "" && config.BoolVal(c.RenewToken) + if renewVault { + go func() { + vt, err := dep.NewVaultTokenQuery(token) + if err != nil { + errChan <- fmt.Errorf("vaultwatcher: %w", err) + } + w := getWatcher() + if _, err := w.Add(vt); err != nil { + errChan <- fmt.Errorf("vaultwatcher: %w", err) + } + + // VaultTokenQuery loops internally and never returns data, + // so we only care about if it errors out. + errChan <- <-w.ErrCh() + }() + } + + return errChan +} + +// waitforToken blocks until the tokenfile is updated, and it given the new +// data on the watcher's DataCh(annel) +// (as a variable to swap out in tests) +var waitforToken = func(w *Watcher, old_raw_token string, unwrap bool) (string, error) { + vault := w.clients.Vault() + var new_raw_token string + select { + case v := <-w.DataCh(): + new_raw_token = strings.TrimSpace(v.Data().(string)) + if new_raw_token == old_raw_token { + break + } + switch token, err := getToken(vault, new_raw_token, unwrap); err { + case nil: + vault.SetToken(token) + default: + log.Printf("[INFO] %s", err) + } + case err := <-w.ErrCh(): + return "", err + } + return new_raw_token, nil +} + +type vaultClient interface { + SetToken(string) + Logical() *api.Logical +} + +// getToken grabs the real token from raw_token (unwrap, etc) +func getToken(client vaultClient, token string, unwrap bool) (string, error) { + // If vault agent specifies wrap_ttl for the token it is returned as + // a SecretWrapInfo struct marshalled into JSON instead of the normal raw + // token. This checks for that and pulls out the token if it is the case. + var wrapinfo api.SecretWrapInfo + if err := json.Unmarshal([]byte(token), &wrapinfo); err == nil { + token = wrapinfo.Token + } + token = strings.TrimSpace(token) + if token == "" { + return "", fmt.Errorf("empty token") + } + + if unwrap { + client.SetToken(token) // needs to be set to unwrap + secret, err := client.Logical().Unwrap(token) + switch { + case err != nil: + return token, fmt.Errorf("vault unwrap: %s", err) + case secret == nil: + return token, fmt.Errorf("vault unwrap: no secret") + case secret.Auth == nil: + return token, fmt.Errorf("vault unwrap: no secret auth") + case secret.Auth.ClientToken == "": + return token, fmt.Errorf("vault unwrap: no token returned") + default: + token = secret.Auth.ClientToken + } + } + return token, nil +} diff --git a/watch/vault_token_test.go b/watch/vault_token_test.go new file mode 100644 index 000000000..b3f99d19c --- /dev/null +++ b/watch/vault_token_test.go @@ -0,0 +1,242 @@ +package watch + +import ( + "encoding/json" + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/hashicorp/consul-template/config" + dep "github.com/hashicorp/consul-template/dependency" + "github.com/hashicorp/vault/api" +) + +// approle auto-auth setup in watch_test.go, TestMain() +func TestVaultTokenWatcher(t *testing.T) { + // Don't set the below to run in parallel. They mess with the single + // running vault's permissions. + t.Run("noop", func(t *testing.T) { + conf := config.DefaultVaultConfig() + errCh := VaultTokenWatcher(testClients, conf) + + select { + case err := <-errCh: + if err != nil { + t.Error(err) + } + case <-time.After(time.Second): + return + } + }) + + t.Run("fixed_token", func(t *testing.T) { + testClients.Vault().SetToken(vaultToken) + conf := config.DefaultVaultConfig() + token := vaultToken + conf.Token = &token + _ = VaultTokenWatcher(testClients, conf) + if testClients.Vault().Token() != vaultToken { + t.Error("Token should be " + vaultToken) + } + }) + + t.Run("secretwrapped_token", func(t *testing.T) { + testClients.Vault().SetToken(vaultToken) + conf := config.DefaultVaultConfig() + data, err := json.Marshal(&api.SecretWrapInfo{Token: vaultToken}) + if err != nil { + t.Error(err) + } + jsonToken := string(data) + conf.Token = &jsonToken + _ = VaultTokenWatcher(testClients, conf) + if testClients.Vault().Token() != vaultToken { + t.Error("Token should be " + vaultToken) + } + }) + + t.Run("tokenfile", func(t *testing.T) { + // setup + testClients.Vault().SetToken(vaultToken) + tokenfile := runVaultAgent(testClients, tokenRoleId) + defer func() { os.Remove(tokenfile) }() + conf := config.DefaultVaultConfig() + token := vaultToken + conf.Token = &token + conf.VaultAgentTokenFile = &tokenfile + + // test data + d, _ := dep.NewVaultAgentTokenQuery("") + waitforCalled := fmt.Errorf("refresh called successfully") + _waitforToken := waitforToken + defer func() { + waitforToken = _waitforToken + }() + waitforToken = func(w *Watcher, raw_token string, unwrap bool) (string, error) { + _, ok := w.depViewMap[d.String()] + if !ok { + t.Error("missing tokenfile dependency") + } + return "", waitforCalled + } + + errCh := VaultTokenWatcher(testClients, conf) + // tests + err := <-errCh + switch err { + case waitforCalled, nil: + default: + t.Error(err) + } + + if testClients.Vault().Token() != vaultToken { + t.Error("Token should be " + vaultToken) + } + }) + + t.Run("renew", func(t *testing.T) { + // exercise the renewer: the action is all inside the vault api + // calls and vault so there's little to check.. so we just try + // to call it and make sure it doesn't error + testClients.Vault().SetToken(vaultToken) + renew := true + _, err := testClients.Vault().Auth().Token().Create( + &api.TokenCreateRequest{ + ID: "b_token", + TTL: "1m", + Renewable: &renew, + }) + if err != nil { + t.Error(err) + } + conf := config.DefaultVaultConfig() + token := "b_token" + conf.Token = &token // + conf.RenewToken = &renew + errCh := VaultTokenWatcher(testClients, conf) + + select { + case err := <-errCh: + if err != nil { + t.Error(err) + } + case <-time.After(time.Millisecond * 100): + // give it a chance to throw an error + } + }) +} + +func TestVaultTokenRefreshToken(t *testing.T) { + watcher := NewWatcher(&NewWatcherInput{ + Clients: testClients, + }) + wrapinfo := api.SecretWrapInfo{ + Token: "btoken", + } + b, _ := json.Marshal(wrapinfo) + type testcase struct { + name, raw_token, exp_token string + } + vault := testClients.Vault() + testcases := []testcase{ + {name: "noop", raw_token: "foo", exp_token: "foo"}, + {name: "spaces", raw_token: " foo ", exp_token: "foo"}, + {name: "secretwrap", raw_token: string(b), exp_token: "btoken"}, + } + for i, tc := range testcases { + tc := tc // avoid for-loop pointer wart + name := fmt.Sprintf("%d_%s", i, tc.name) + t.Run(name, func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(1) + go func(t *testing.T) { + defer wg.Done() + token, err := waitforToken(watcher, "", false) + switch { + case err != nil: + t.Error(err) + case vault.Token() != tc.exp_token: + t.Errorf("bad token, expected: '%s', received '%s'", + tc.exp_token, token) + } + }(t) + fd := fakeDep{name: name} + watcher.dataCh <- &View{dependency: fd, data: tc.raw_token} + wg.Wait() + }) + } + watcher.Stop() +} + +// When vault-agent uses the wrap_ttl option it writes a json blob instead of +// a raw token. This verifies it will extract the token from that when needed. +// It doesn't test unwrap. The integration test covers that for now. +func TestVaultTokenGetToken(t *testing.T) { + t.Run("table_test", func(t *testing.T) { + wrapinfo := api.SecretWrapInfo{ + Token: "btoken", + } + b, _ := json.Marshal(wrapinfo) + testcases := []struct{ in, out string }{ + {in: "", out: ""}, + {in: "atoken", out: "atoken"}, + {in: string(b), out: "btoken"}, + } + for _, tc := range testcases { + dummy := &setTokenFaker{} + token, _ := getToken(dummy, tc.in, false) + if token != tc.out { + t.Errorf("getToken, wanted: '%v', got: '%v'", tc.out, token) + } + } + }) + t.Run("unwrap_test", func(t *testing.T) { + vault := testClients.Vault() + vault.SetToken(vaultToken) + vault.SetWrappingLookupFunc(func(operation, path string) string { + if path == "auth/token/create" { + return "30s" + } + return "" + }) + defer vault.SetWrappingLookupFunc(nil) + + secret, err := vault.Auth().Token().Create(&api.TokenCreateRequest{ + Lease: "1h", + }) + if err != nil { + t.Fatal(err) + } + + unwrap := true + wrappedToken := secret.WrapInfo.Token + token, err := getToken(vault, wrappedToken, unwrap) + if err != nil { + t.Fatal(err) + } + if token == wrappedToken { + t.Errorf("tokens should not match") + } + }) +} + +type setTokenFaker struct { + Token string +} + +func (t *setTokenFaker) SetToken(token string) {} +func (t *setTokenFaker) Logical() *api.Logical { return nil } + +var _ dep.Dependency = (*fakeDep)(nil) + +type fakeDep struct{ name string } + +func (d fakeDep) String() string { return d.name } +func (d fakeDep) CanShare() bool { return false } +func (d fakeDep) Stop() {} +func (d fakeDep) Type() dep.Type { return dep.TypeConsul } +func (d fakeDep) Fetch(*dep.ClientSet, *dep.QueryOptions) (interface{}, *dep.ResponseMetadata, error) { + return d.name, nil, nil +} diff --git a/watch/watch_test.go b/watch/watch_test.go new file mode 100644 index 000000000..ae5c4a304 --- /dev/null +++ b/watch/watch_test.go @@ -0,0 +1,198 @@ +package watch + +import ( + "encoding/json" + "io/ioutil" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + dep "github.com/hashicorp/consul-template/dependency" + "github.com/hashicorp/vault/api" +) + +const ( + vaultAddr = "http://127.0.0.1:8200" + vaultToken = "a_token" +) + +var ( + testVault *vaultServer + testClients *dep.ClientSet + tokenRoleId string +) + +func TestMain(m *testing.M) { + os.Exit(main(m)) +} + +// sub-main so I can use defer +func main(m *testing.M) int { + log.SetOutput(ioutil.Discard) + testVault = newTestVault() + defer func() { testVault.Stop() }() + + clients := dep.NewClientSet() + if err := clients.CreateVaultClient(&dep.CreateVaultClientInput{ + Address: vaultAddr, + Token: vaultToken, + }); err != nil { + panic(err) + } + + testClients = clients + tokenRoleId = vaultTokenSetup(clients) + + return m.Run() +} + +type vaultServer struct { + cmd *exec.Cmd +} + +func (v vaultServer) Stop() error { + if v.cmd != nil && v.cmd.Process != nil { + return v.cmd.Process.Signal(os.Interrupt) + } + return nil +} + +func newTestVault() *vaultServer { + path, err := exec.LookPath("vault") + if err != nil || path == "" { + panic("vault not found on $PATH") + } + args := []string{ + "server", "-dev", "-dev-root-token-id", vaultToken, + "-dev-no-store-token", + } + cmd := exec.Command("vault", args...) + cmd.Stdout = ioutil.Discard + cmd.Stderr = ioutil.Discard + + if err := cmd.Start(); err != nil { + panic("vault failed to start: " + err.Error()) + } + return &vaultServer{ + cmd: cmd, + } +} + +// Sets up approle auto-auth for token generation/testing +func vaultTokenSetup(clients *dep.ClientSet) string { + vc := clients.Vault() + + // vault auth enable approle + err := vc.Sys().EnableAuthWithOptions("approle", + &api.MountInput{ + Type: "approle", + }) + if err != nil && !strings.Contains(err.Error(), "path is already in use") { + panic(err) + } + + // vault policy write foo 'path ...' + err = vc.Sys().PutPolicy("foo", + `path "secret/data/foo" { capabilities = ["read"] }`) + if err != nil { + panic(err) + } + + // vault write auth/approle/role/foo ... + _, err = vc.Logical().Write("auth/approle/role/foo", + map[string]interface{}{ + "token_policies": "foo", + "secret_id_num_uses": 100, + "secret_id_ttl": "5m", + "token_num_users": 10, + "token_ttl": "7m", + "token_max_ttl": "10m", + }) + if err != nil { + panic(err) + } + + var sec *api.Secret + // vault read -field=role_id auth/approle/role/foo/role-id + sec, err = vc.Logical().Read("auth/approle/role/foo/role-id") + if err != nil { + panic(err) + } + role_id := sec.Data["role_id"] + return role_id.(string) +} + +// returns path to token file (which is created by the agent run) +// token file isn't cleaned, so use returned path to remove it when done +func runVaultAgent(clients *dep.ClientSet, role_id string) string { + dir, err := os.MkdirTemp("", "consul-template-test") + if err != nil { + panic(err) + } + defer os.RemoveAll(dir) + + tokenFile := filepath.Join("", "vatoken") + + role_idPath := filepath.Join(dir, "roleid") + secret_idPath := filepath.Join(dir, "secretid") + vaconf := filepath.Join(dir, "vault-agent-config.json") + + // Generate secret_id, need new one for each agent run + // vault write -f -field secret_id auth/approle/role/foo/secret-id + vc := clients.Vault() + sec, err := vc.Logical().Write("auth/approle/role/foo/secret-id", nil) + if err != nil { + panic(err) + } + secret_id := sec.Data["secret_id"].(string) + err = os.WriteFile(secret_idPath, []byte(secret_id), 0o444) + if err != nil { + panic(err) + } + err = os.WriteFile(role_idPath, []byte(role_id), 0o444) + if err != nil { + panic(err) + } + + type obj map[string]interface{} + type list []obj + va := obj{ + "vault": obj{"address": vaultAddr}, + "auto_auth": obj{ + "method": obj{ + "type": "approle", + "config": obj{ + "role_id_file_path": role_idPath, + "secret_id_file_path": secret_idPath, + }, + "wrap_ttl": "5m", + }, + "sinks": list{ + {"sink": obj{"type": "file", "config": obj{"path": tokenFile}}}, + }, + }, + } + txt, err := json.Marshal(va) + if err != nil { + panic(err) + } + err = os.WriteFile(vaconf, txt, 0o644) + if err != nil { + panic(err) + } + + args := []string{ + "agent", "-exit-after-auth", "-config=" + vaconf, + } + cmd := exec.Command("vault", args...) + cmd.Stdout = ioutil.Discard + cmd.Stderr = ioutil.Discard + + if err := cmd.Run(); err != nil { + panic("vault agent failed to run: " + err.Error()) + } + return tokenFile +} diff --git a/watch/watcher.go b/watch/watcher.go index e4c3d6f07..2bb8a3e99 100644 --- a/watch/watcher.go +++ b/watch/watcher.go @@ -78,7 +78,7 @@ type NewWatcherInput struct { } // NewWatcher creates a new watcher using the given API client. -func NewWatcher(i *NewWatcherInput) (*Watcher, error) { +func NewWatcher(i *NewWatcherInput) *Watcher { w := &Watcher{ clients: i.Clients, depViewMap: make(map[string]*View), @@ -92,29 +92,7 @@ func NewWatcher(i *NewWatcherInput) (*Watcher, error) { retryFuncVault: i.RetryFuncVault, retryFuncNomad: i.RetryFuncNomad, } - - // Start a watcher for the Vault renew if that config was specified - if i.RenewVault { - vt, err := dep.NewVaultTokenQuery(i.VaultToken) - if err != nil { - return nil, errors.Wrap(err, "watcher") - } - if _, err := w.Add(vt); err != nil { - return nil, errors.Wrap(err, "watcher") - } - } - - if len(i.VaultAgentTokenFile) > 0 { - vag, err := dep.NewVaultAgentTokenQuery(i.VaultAgentTokenFile) - if err != nil { - return nil, errors.Wrap(err, "watcher") - } - if _, err := w.Add(vag); err != nil { - return nil, errors.Wrap(err, "watcher") - } - } - - return w, nil + return w } // DataCh returns a read-only channel of Views which is populated when a view diff --git a/watch/watcher_test.go b/watch/watcher_test.go index 67b7c5acf..0f47273ed 100644 --- a/watch/watcher_test.go +++ b/watch/watcher_test.go @@ -8,13 +8,10 @@ import ( ) func TestAdd_updatesMap(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if _, err := w.Add(d); err != nil { @@ -28,13 +25,10 @@ func TestAdd_updatesMap(t *testing.T) { } func TestAdd_exists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} w.depViewMap[d.String()] = &View{} @@ -50,13 +44,10 @@ func TestAdd_exists(t *testing.T) { } func TestAdd_startsViewPoll(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } added, err := w.Add(&TestDep{}) if err != nil { @@ -76,13 +67,10 @@ func TestAdd_startsViewPoll(t *testing.T) { } func TestWatching_notExists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if w.Watching(d) == true { @@ -91,13 +79,10 @@ func TestWatching_notExists(t *testing.T) { } func TestWatching_exists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if _, err := w.Add(d); err != nil { @@ -110,13 +95,10 @@ func TestWatching_exists(t *testing.T) { } func TestRemove_exists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if _, err := w.Add(d); err != nil { @@ -134,13 +116,10 @@ func TestRemove_exists(t *testing.T) { } func TestRemove_doesNotExist(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } removed := w.Remove(&TestDep{}) if removed != false { @@ -149,13 +128,10 @@ func TestRemove_doesNotExist(t *testing.T) { } func TestSize_empty(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } if w.Size() != 0 { t.Errorf("expected %d to be %d", w.Size(), 0) @@ -163,13 +139,10 @@ func TestSize_empty(t *testing.T) { } func TestSize_returnsNumViews(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } for i := 0; i < 10; i++ { d := &TestDep{name: fmt.Sprintf("%d", i)}