diff --git a/cli.go b/cli.go index a7ce58f17..de5addc98 100644 --- a/cli.go +++ b/cli.go @@ -17,7 +17,6 @@ import ( "github.com/hashicorp/consul-template/service_os" "github.com/hashicorp/consul-template/signals" "github.com/hashicorp/consul-template/version" - "github.com/hashicorp/consul-template/watch" ) // Exit codes are int values that represent an exit code for a particular error. @@ -104,20 +103,8 @@ func (cli *CLI) Run(args []string) int { return ExitCodeOK } - // Create the clientset - clients, err := manager.NewClientSet(config) - if err != nil { - return logError(err, ExitCodeConfigError) - } - - // vault token watcher - vtwatchErrCh := watch.VaultTokenWatcher(clients, config.Vault) - if err != nil { - return logError(err, ExitCodeRunnerError) - } - // Initial runner - runner, err := manager.NewRunner(clients, config, dry) + runner, err := manager.NewRunner(config, dry) if err != nil { return logError(err, ExitCodeRunnerError) } @@ -128,8 +115,6 @@ func (cli *CLI) Run(args []string) int { for { select { - case err := <-vtwatchErrCh: - return logError(err, ExitCodeRunnerError) case err := <-runner.ErrCh: // Check if the runner's error returned a specific exit status, and return // that value. If no value was given, return a generic exit status. @@ -165,7 +150,7 @@ func (cli *CLI) Run(args []string) int { return logError(err, ExitCodeConfigError) } - runner, err = manager.NewRunner(clients, config, dry) + runner, err = manager.NewRunner(config, dry) if err != nil { return logError(err, ExitCodeRunnerError) } diff --git a/manager/runner.go b/manager/runner.go index 0dda335d6..fac062c02 100644 --- a/manager/runner.go +++ b/manager/runner.go @@ -74,6 +74,8 @@ type Runner struct { // dependenciesLock is a lock around touching the dependencies map. dependenciesLock sync.Mutex + // token watcher + vaultTokenWatcher *watch.Watcher // watcher is the watcher this runner is using. watcher *watch.Watcher @@ -172,7 +174,7 @@ type RenderEvent struct { // NewRunner accepts a slice of TemplateConfigs and returns a pointer to the new // Runner and any error that occurred during creation. -func NewRunner(clients *dep.ClientSet, config *config.Config, dry bool) (*Runner, error) { +func NewRunner(config *config.Config, dry bool) (*Runner, error) { log.Printf("[INFO] (runner) creating new runner (dry: %v, once: %v)", dry, config.Once) @@ -181,10 +183,21 @@ func NewRunner(clients *dep.ClientSet, config *config.Config, dry bool) (*Runner dry: dry, } - if err := runner.init(clients); err != nil { + // Create the clientset + clients, err := NewClientSet(config) + if err != nil { + return nil, fmt.Errorf("runner: %w", err) + } + // needs to be run early to do initial token handling + runner.vaultTokenWatcher, err = watch.VaultTokenWatcher( + clients, config.Vault, runner.DoneCh) + if err != nil { return nil, err } + if err := runner.init(clients); err != nil { + return nil, err + } return runner, nil } @@ -226,7 +239,7 @@ func (r *Runner) Start() { if r.child != nil { r.stopDedup() - r.stopWatcher() + r.stopWatchers() log.Printf("[INFO] (runner) waiting for child process to exit") select { @@ -330,7 +343,7 @@ func (r *Runner) Start() { if r.child != nil { r.stopDedup() - r.stopWatcher() + r.stopWatchers() log.Printf("[INFO] (runner) waiting for child process to exit") select { @@ -384,6 +397,12 @@ func (r *Runner) Start() { r.ErrCh <- err return + case err := <-r.vaultTokenWatcher.ErrCh(): + // Push the error back up the stack + log.Printf("[ERR] (runner): %s", err) + r.ErrCh <- err + return + case tmpl := <-r.quiescenceCh: // Remove the quiescence for this template from the map. This will force // the upcoming Run call to actually evaluate and render the template. @@ -455,7 +474,7 @@ func (r *Runner) internalStop(immediately bool) { log.Printf("[INFO] (runner) stopping") r.stopDedup() - r.stopWatcher() + r.stopWatchers() r.stopChild(immediately) if err := r.deletePid(); err != nil { @@ -475,11 +494,15 @@ func (r *Runner) stopDedup() { } } -func (r *Runner) stopWatcher() { +func (r *Runner) stopWatchers() { if r.watcher != nil { log.Printf("[DEBUG] (runner) stopping watcher") r.watcher.Stop() } + if r.vaultTokenWatcher != nil { + log.Printf("[DEBUG] (runner) stopping vault token watcher") + r.vaultTokenWatcher.Stop() + } } func (r *Runner) stopChild(immediately bool) { diff --git a/manager/runner_test.go b/manager/runner_test.go index ef7572862..0fb162d29 100644 --- a/manager/runner_test.go +++ b/manager/runner_test.go @@ -31,7 +31,7 @@ func TestRunner_initTemplates(t *testing.T) { }, }) - r, err := NewRunner(testClients, c, true) + r, err := NewRunner(c, true) if err != nil { t.Fatal(err) } @@ -51,7 +51,7 @@ func TestRunner_initTemplates(t *testing.T) { func TestRunner_Receive(t *testing.T) { c := config.TestConfig(&config.Config{Once: true}) - r, err := NewRunner(testClients, c, true) + r, err := NewRunner(c, true) if err != nil { t.Fatal(err) } @@ -476,7 +476,7 @@ func TestRunner_Run(t *testing.T) { c.Once = true c.Finalize() - r, err := NewRunner(testClients, c, true) + r, err := NewRunner(c, true) if err != nil { t.Fatal(err) } @@ -523,7 +523,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -564,7 +564,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -611,7 +611,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -659,7 +659,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -705,7 +705,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -760,7 +760,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -826,7 +826,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -902,7 +902,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } @@ -952,7 +952,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, true) + r, err := NewRunner(c, true) if err != nil { t.Fatal(err) } @@ -1010,7 +1010,7 @@ func TestRunner_Start(t *testing.T) { }) c.Finalize() - r, err := NewRunner(testClients, c, false) + r, err := NewRunner(c, false) if err != nil { t.Fatal(err) } diff --git a/watch/vault_token.go b/watch/vault_token.go index c53573aba..6e22c2179 100644 --- a/watch/vault_token.go +++ b/watch/vault_token.go @@ -13,25 +13,23 @@ import ( ) // VaultTokenWatcher monitors the vault token for updates -func VaultTokenWatcher(clients *dep.ClientSet, c *config.VaultConfig) chan error { +func VaultTokenWatcher( + clients *dep.ClientSet, c *config.VaultConfig, doneCh chan struct{}, +) (*Watcher, error) { // c.Vault.Token is populated by the config code from all places // vault tokens are supported. So if there is no token set here, // tokens are not being used. raw_token := strings.TrimSpace(config.StringVal(c.Token)) if raw_token == "" { - return nil + return nil, nil } unwrap := config.BoolVal(c.UnwrapToken) vault := clients.Vault() - // buffer 1 error to allow for sequential errors to send and return - errChan := make(chan error, 1) - // get/set token once when kicked off, async after that.. - token, err := getToken(vault, raw_token, unwrap) + token, err := unpackToken(vault, raw_token, unwrap) if err != nil { - errChan <- err - return errChan + return nil, fmt.Errorf("vaultwatcher: %w", err) } vault.SetToken(token) @@ -50,71 +48,66 @@ func VaultTokenWatcher(clients *dep.ClientSet, c *config.VaultConfig) chan error // Vault Agent Token File process // tokenFile := strings.TrimSpace(config.StringVal(c.VaultAgentTokenFile)) if tokenFile != "" { - atf, err := dep.NewVaultAgentTokenQuery(tokenFile) - if err != nil { - errChan <- fmt.Errorf("vaultwatcher: %w", err) - return errChan - } w := getWatcher() - if _, err := w.Add(atf); err != nil { - errChan <- fmt.Errorf("vaultwatcher: %w", err) - return errChan + watchLoop, err := watchTokenFile(w, tokenFile, raw_token, unwrap, doneCh) + if err != nil { + return nil, fmt.Errorf("vaultwatcher: %w", err) } - go func() { - for { - raw_token, err = waitforToken(w, raw_token, unwrap) - if err != nil { - errChan <- err - return - } - } - }() + go watchLoop() } // Vault Token Renewal process // renewVault := vault.Token() != "" && config.BoolVal(c.RenewToken) if renewVault { - go func() { - vt, err := dep.NewVaultTokenQuery(token) - if err != nil { - errChan <- fmt.Errorf("vaultwatcher: %w", err) - } - w := getWatcher() - if _, err := w.Add(vt); err != nil { - errChan <- fmt.Errorf("vaultwatcher: %w", err) - } - - // VaultTokenQuery loops internally and never returns data, - // so we only care about if it errors out. - errChan <- <-w.ErrCh() - }() + w := getWatcher() + vt, err := dep.NewVaultTokenQuery(token) + if err != nil { + w.Stop() + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + if _, err := w.Add(vt); err != nil { + w.Stop() + return nil, fmt.Errorf("vaultwatcher: %w", err) + } } - return errChan + return watcher, nil } -// waitforToken blocks until the tokenfile is updated, and it given the new -// data on the watcher's DataCh(annel) -// (as a variable to swap out in tests) -var waitforToken = func(w *Watcher, old_raw_token string, unwrap bool) (string, error) { +func watchTokenFile( + w *Watcher, tokenFile, raw_token string, unwrap bool, doneCh chan struct{}, +) (func(), error) { + // watcher, tokenFile, raw_token, unwrap, doneCh + atf, err := dep.NewVaultAgentTokenQuery(tokenFile) + if err != nil { + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + if _, err := w.Add(atf); err != nil { + w.Stop() + return nil, fmt.Errorf("vaultwatcher: %w", err) + } vault := w.clients.Vault() - var new_raw_token string - select { - case v := <-w.DataCh(): - new_raw_token = strings.TrimSpace(v.Data().(string)) - if new_raw_token == old_raw_token { - break - } - switch token, err := getToken(vault, new_raw_token, unwrap); err { - case nil: - vault.SetToken(token) - default: - log.Printf("[INFO] %s", err) + return func() { + for { + select { + case v := <-w.DataCh(): + new_raw_token := strings.TrimSpace(v.Data().(string)) + if new_raw_token == raw_token { + break + } + token, err := unpackToken(vault, new_raw_token, unwrap) + switch err { + case nil: + raw_token = new_raw_token + vault.SetToken(token) + default: + log.Printf("[INFO] %s", err) + } + case <-doneCh: + return + } } - case err := <-w.ErrCh(): - return "", err - } - return new_raw_token, nil + }, nil } type vaultClient interface { @@ -122,8 +115,8 @@ type vaultClient interface { Logical() *api.Logical } -// getToken grabs the real token from raw_token (unwrap, etc) -func getToken(client vaultClient, token string, unwrap bool) (string, error) { +// unpackToken grabs the real token from raw_token (unwrap, etc) +func unpackToken(client vaultClient, token string, unwrap bool) (string, error) { // If vault agent specifies wrap_ttl for the token it is returned as // a SecretWrapInfo struct marshalled into JSON instead of the normal raw // token. This checks for that and pulls out the token if it is the case. diff --git a/watch/vault_token_test.go b/watch/vault_token_test.go index b3f99d19c..b1c0d4b61 100644 --- a/watch/vault_token_test.go +++ b/watch/vault_token_test.go @@ -19,10 +19,13 @@ func TestVaultTokenWatcher(t *testing.T) { // running vault's permissions. t.Run("noop", func(t *testing.T) { conf := config.DefaultVaultConfig() - errCh := VaultTokenWatcher(testClients, conf) - + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() select { - case err := <-errCh: + case err := <-watcher.ErrCh(): if err != nil { t.Error(err) } @@ -36,7 +39,14 @@ func TestVaultTokenWatcher(t *testing.T) { conf := config.DefaultVaultConfig() token := vaultToken conf.Token = &token - _ = VaultTokenWatcher(testClients, conf) + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() + if watcher != nil { + t.Error("watcher should be nil") + } if testClients.Vault().Token() != vaultToken { t.Error("Token should be " + vaultToken) } @@ -51,7 +61,11 @@ func TestVaultTokenWatcher(t *testing.T) { } jsonToken := string(data) conf.Token = &jsonToken - _ = VaultTokenWatcher(testClients, conf) + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() if testClients.Vault().Token() != vaultToken { t.Error("Token should be " + vaultToken) } @@ -61,38 +75,32 @@ func TestVaultTokenWatcher(t *testing.T) { // setup testClients.Vault().SetToken(vaultToken) tokenfile := runVaultAgent(testClients, tokenRoleId) - defer func() { os.Remove(tokenfile) }() + defer func() { + testClients.Vault().SetToken(vaultToken) + os.Remove(tokenfile) + }() conf := config.DefaultVaultConfig() token := vaultToken conf.Token = &token conf.VaultAgentTokenFile = &tokenfile - // test data - d, _ := dep.NewVaultAgentTokenQuery("") - waitforCalled := fmt.Errorf("refresh called successfully") - _waitforToken := waitforToken - defer func() { - waitforToken = _waitforToken - }() - waitforToken = func(w *Watcher, raw_token string, unwrap bool) (string, error) { - _, ok := w.depViewMap[d.String()] - if !ok { - t.Error("missing tokenfile dependency") - } - return "", waitforCalled + doneCh := make(chan struct{}) + watcher, err := VaultTokenWatcher(testClients, conf, doneCh) + if err != nil { + t.Error(err) } - - errCh := VaultTokenWatcher(testClients, conf) + defer watcher.Stop() // tests - err := <-errCh - switch err { - case waitforCalled, nil: - default: + select { + case <-time.After(time.Millisecond): + // XXX remove this timer in hashicat port + doneCh <- struct{}{} + case err := <-watcher.ErrCh(): t.Error(err) } - if testClients.Vault().Token() != vaultToken { - t.Error("Token should be " + vaultToken) + if testClients.Vault().Token() == vaultToken { + t.Error("Token should not be " + vaultToken) } }) @@ -115,10 +123,14 @@ func TestVaultTokenWatcher(t *testing.T) { token := "b_token" conf.Token = &token // conf.RenewToken = &renew - errCh := VaultTokenWatcher(testClients, conf) + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() select { - case err := <-errCh: + case err := <-watcher.ErrCh(): if err != nil { t.Error(err) } @@ -132,6 +144,8 @@ func TestVaultTokenRefreshToken(t *testing.T) { watcher := NewWatcher(&NewWatcherInput{ Clients: testClients, }) + // force watcher to be synchronous so we can control test flow + watcher.dataCh = make(chan *View) // no buffer wrapinfo := api.SecretWrapInfo{ Token: "btoken", } @@ -150,21 +164,24 @@ func TestVaultTokenRefreshToken(t *testing.T) { name := fmt.Sprintf("%d_%s", i, tc.name) t.Run(name, func(t *testing.T) { var wg sync.WaitGroup + dCh := make(chan struct{}) + watchLoop, err := watchTokenFile(watcher, "", "XXX", false, dCh) + if err != nil { + t.Error(err) + } wg.Add(1) - go func(t *testing.T) { - defer wg.Done() - token, err := waitforToken(watcher, "", false) - switch { - case err != nil: - t.Error(err) - case vault.Token() != tc.exp_token: - t.Errorf("bad token, expected: '%s', received '%s'", - tc.exp_token, token) - } - }(t) + go func() { + watchLoop() + wg.Done() + }() fd := fakeDep{name: name} watcher.dataCh <- &View{dependency: fd, data: tc.raw_token} + close(dCh) // close doneCh to stop watchLoop wg.Wait() + if vault.Token() != tc.exp_token { + t.Errorf("bad token, expected: '%s', received '%s'", + tc.exp_token, tc.raw_token) + } }) } watcher.Stop() @@ -186,9 +203,9 @@ func TestVaultTokenGetToken(t *testing.T) { } for _, tc := range testcases { dummy := &setTokenFaker{} - token, _ := getToken(dummy, tc.in, false) + token, _ := unpackToken(dummy, tc.in, false) if token != tc.out { - t.Errorf("getToken, wanted: '%v', got: '%v'", tc.out, token) + t.Errorf("unpackToken, wanted: '%v', got: '%v'", tc.out, token) } } }) @@ -212,7 +229,7 @@ func TestVaultTokenGetToken(t *testing.T) { unwrap := true wrappedToken := secret.WrapInfo.Token - token, err := getToken(vault, wrappedToken, unwrap) + token, err := unpackToken(vault, wrappedToken, unwrap) if err != nil { t.Fatal(err) } diff --git a/watch/view.go b/watch/view.go index f75ae191b..cb378e410 100644 --- a/watch/view.go +++ b/watch/view.go @@ -189,6 +189,7 @@ func (v *View) fetch(doneCh, successCh chan<- struct{}, errCh chan<- error) { allowStale = true } + firstLoop := true // to disable rate limiting on first pass for { // If the view was stopped, short-circuit this loop. This prevents a bug // where a view can get "lost" in the event Consul Template is reloaded. @@ -239,9 +240,10 @@ func (v *View) fetch(doneCh, successCh chan<- struct{}, errCh chan<- error) { allowStale = true } - if dur := rateLimiter(start); dur > 1 { + if dur := rateLimiter(start); dur > 1 && !firstLoop { time.Sleep(dur) } + firstLoop = false // blocking queries that return due to block timeout // will have the same index diff --git a/watch/watch_test.go b/watch/watch_test.go index ae5c4a304..1bee8cb53 100644 --- a/watch/watch_test.go +++ b/watch/watch_test.go @@ -134,7 +134,7 @@ func runVaultAgent(clients *dep.ClientSet, role_id string) string { } defer os.RemoveAll(dir) - tokenFile := filepath.Join("", "vatoken") + tokenFile := filepath.Join("", "vatoken.txt") role_idPath := filepath.Join(dir, "roleid") secret_idPath := filepath.Join(dir, "secretid") diff --git a/watch/watcher.go b/watch/watcher.go index 2bb8a3e99..af9cfb2b9 100644 --- a/watch/watcher.go +++ b/watch/watcher.go @@ -103,6 +103,9 @@ func (w *Watcher) DataCh() <-chan *View { // ErrCh returns a read-only channel of errors returned by the upstream. func (w *Watcher) ErrCh() <-chan error { + if w == nil { + return nil + } return w.errCh } @@ -211,6 +214,9 @@ func (w *Watcher) Size() int { // Stop halts this watcher and any currently polling views immediately. If a // view was in the middle of a poll, no data will be returned. func (w *Watcher) Stop() { + if w == nil { + return + } w.Lock() defer w.Unlock()