diff --git a/dependency/client_set.go b/dependency/client_set.go index 93ea12fe9..45c3d6713 100644 --- a/dependency/client_set.go +++ b/dependency/client_set.go @@ -347,33 +347,10 @@ func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error { } } - // Set the token if given if i.Token != "" { client.SetToken(i.Token) } - // Check if we are unwrapping - if i.UnwrapToken { - secret, err := client.Logical().Unwrap(i.Token) - if err != nil { - return fmt.Errorf("client set: vault unwrap: %s", err) - } - - if secret == nil { - return fmt.Errorf("client set: vault unwrap: no secret") - } - - if secret.Auth == nil { - return fmt.Errorf("client set: vault unwrap: no secret auth") - } - - if secret.Auth.ClientToken == "" { - return fmt.Errorf("client set: vault unwrap: no token returned") - } - - client.SetToken(secret.Auth.ClientToken) - } - // Save the data on ourselves c.Lock() c.vault = &vaultClient{ diff --git a/dependency/client_set_test.go b/dependency/client_set_test.go index a085867fa..9a68f9d9e 100644 --- a/dependency/client_set_test.go +++ b/dependency/client_set_test.go @@ -14,37 +14,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestClientSet_unwrapVaultToken(t *testing.T) { - // Don't use t.Parallel() here as the SetWrappingLookupFunc is a global - // setting and breaks other tests if run in parallel - - vault := testClients.Vault() - - // Create a wrapped token - vault.SetWrappingLookupFunc(func(operation, path string) string { - return "30s" - }) - defer vault.SetWrappingLookupFunc(nil) - - wrappedToken, err := vault.Auth().Token().Create(&api.TokenCreateRequest{ - Lease: "1h", - }) - if err != nil { - t.Fatal(err) - } - - token := vault.Token() - - if token == wrappedToken.WrapInfo.Token { - t.Errorf("expected %q to not be %q", token, - wrappedToken.WrapInfo.Token) - } - - if _, err := vault.Auth().Token().LookupSelf(); err != nil { - t.Fatal(err) - } -} - func TestClientSet_K8SServiceTokenAuth(t *testing.T) { t.Parallel() diff 
--git a/dependency/vault_agent_token.go b/dependency/vault_agent_token.go index f3103ecd0..f06704678 100644 --- a/dependency/vault_agent_token.go +++ b/dependency/vault_agent_token.go @@ -1,10 +1,8 @@ package dependency import ( - "io/ioutil" "log" "os" - "strings" "time" "github.com/pkg/errors" @@ -39,6 +37,7 @@ func NewVaultAgentTokenQuery(path string) (*VaultAgentTokenQuery, error) { func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interface{}, *ResponseMetadata, error) { log.Printf("[TRACE] %s: READ %s", d, d.path) + var token string select { case <-d.stopCh: log.Printf("[TRACE] %s: stopped", d) @@ -50,16 +49,15 @@ func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (in log.Printf("[TRACE] %s: reported change", d) - token, err := ioutil.ReadFile(d.path) + raw_token, err := os.ReadFile(d.path) if err != nil { return "", nil, errors.Wrap(err, d.String()) } - d.stat = r.stat - clients.Vault().SetToken(strings.TrimSpace(string(token))) + token = string(raw_token) } - return respWithMetadata("") + return respWithMetadata(token) } // CanShare returns if this dependency is sharable. diff --git a/dependency/vault_agent_token_test.go b/dependency/vault_agent_token_test.go index d06cc14e0..9a2099173 100644 --- a/dependency/vault_agent_token_test.go +++ b/dependency/vault_agent_token_test.go @@ -14,11 +14,6 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) { // Don't use t.Parallel() here as the SetToken() calls are global and break // other tests if run in parallel - // reset token back to original - vc := testClients.Vault() - token := vc.Token() - defer vc.SetToken(token) - // Set up the Vault token file. 
tokenFile, err := ioutil.TempFile("", "token1") if err != nil { @@ -33,22 +28,22 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) { } clientSet := testClients - _, _, err = d.Fetch(clientSet, nil) + token, _, err := d.Fetch(clientSet, nil) if err != nil { t.Fatal(err) } - assert.Equal(t, "token", clientSet.Vault().Token()) + assert.Equal(t, "token", token) // Update the contents. renderer.AtomicWrite( tokenFile.Name(), false, []byte("another_token"), 0o644, false) - _, _, err = d.Fetch(clientSet, nil) + token, _, err = d.Fetch(clientSet, nil) if err != nil { t.Fatal(err) } - assert.Equal(t, "another_token", clientSet.Vault().Token()) + assert.Equal(t, "another_token", token) } func TestVaultAgentTokenQuery_Fetch_missingFile(t *testing.T) { diff --git a/dependency/vault_read_test.go b/dependency/vault_read_test.go index 086085b37..fe1209c1e 100644 --- a/dependency/vault_read_test.go +++ b/dependency/vault_read_test.go @@ -571,8 +571,8 @@ func TestVaultReadQuery_Fetch_NonSecrets(t *testing.T) { vc := clients.Vault() err = vc.Sys().EnableAuth("approle", "approle", "") - if err != nil { - t.Fatal(err) + if err != nil && !strings.Contains(err.Error(), "path is already in use") { + t.Fatalf("(%T) %s\n", err, err) } _, err = vc.Logical().Write("auth/approle/role/my-approle", nil) diff --git a/manager/manager_test.go b/manager/manager_test.go index f2cfb81f9..259a2683d 100644 --- a/manager/manager_test.go +++ b/manager/manager_test.go @@ -33,12 +33,16 @@ func TestMain(m *testing.M) { } testConsul = consul - clients := dep.NewClientSet() - if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{ - Address: testConsul.HTTPAddr, - }); err != nil { + consulConfig := config.DefaultConsulConfig() + consulConfig.Address = &testConsul.HTTPAddr + clients, err := NewClientSet(&config.Config{ + Consul: consulConfig, + Vault: config.DefaultVaultConfig(), + Nomad: config.DefaultNomadConfig(), + }) + if err != nil { testConsul.Stop() - log.Fatal(err) + 
log.Fatal(fmt.Errorf("failed to start clients: %v", err)) } testClients = clients diff --git a/manager/runner.go b/manager/runner.go index 7f2f8bd47..fac062c02 100644 --- a/manager/runner.go +++ b/manager/runner.go @@ -74,6 +74,8 @@ type Runner struct { // dependenciesLock is a lock around touching the dependencies map. dependenciesLock sync.Mutex + // token watcher + vaultTokenWatcher *watch.Watcher // watcher is the watcher this runner is using. watcher *watch.Watcher @@ -181,10 +183,21 @@ func NewRunner(config *config.Config, dry bool) (*Runner, error) { dry: dry, } - if err := runner.init(); err != nil { + // Create the clientset + clients, err := NewClientSet(config) + if err != nil { + return nil, fmt.Errorf("runner: %w", err) + } + // needs to be run early to do initial token handling + runner.vaultTokenWatcher, err = watch.VaultTokenWatcher( + clients, config.Vault, runner.DoneCh) + if err != nil { return nil, err } + if err := runner.init(clients); err != nil { + return nil, err + } return runner, nil } @@ -226,7 +239,7 @@ func (r *Runner) Start() { if r.child != nil { r.stopDedup() - r.stopWatcher() + r.stopWatchers() log.Printf("[INFO] (runner) waiting for child process to exit") select { @@ -330,7 +343,7 @@ func (r *Runner) Start() { if r.child != nil { r.stopDedup() - r.stopWatcher() + r.stopWatchers() log.Printf("[INFO] (runner) waiting for child process to exit") select { @@ -384,6 +397,12 @@ func (r *Runner) Start() { r.ErrCh <- err return + case err := <-r.vaultTokenWatcher.ErrCh(): + // Push the error back up the stack + log.Printf("[ERR] (runner): %s", err) + r.ErrCh <- err + return + case tmpl := <-r.quiescenceCh: // Remove the quiescence for this template from the map. This will force // the upcoming Run call to actually evaluate and render the template. 
@@ -455,7 +474,7 @@ func (r *Runner) internalStop(immediately bool) { log.Printf("[INFO] (runner) stopping") r.stopDedup() - r.stopWatcher() + r.stopWatchers() r.stopChild(immediately) if err := r.deletePid(); err != nil { @@ -475,11 +494,15 @@ func (r *Runner) stopDedup() { } } -func (r *Runner) stopWatcher() { +func (r *Runner) stopWatchers() { if r.watcher != nil { log.Printf("[DEBUG] (runner) stopping watcher") r.watcher.Stop() } + if r.vaultTokenWatcher != nil { + log.Printf("[DEBUG] (runner) stopping vault token watcher") + r.vaultTokenWatcher.Stop() + } } func (r *Runner) stopChild(immediately bool) { @@ -885,7 +908,7 @@ func (r *Runner) runTemplate(tmpl *template.Template, runCtx *templateRunCtx) (* // init() creates the Runner's underlying data structures and returns an error // if any problems occur. -func (r *Runner) init() error { +func (r *Runner) init(clients *dep.ClientSet) error { // Ensure default configuration values r.config = config.DefaultConfig().Merge(r.config) r.config.Finalize() @@ -900,18 +923,8 @@ func (r *Runner) init() error { dep.SetVaultDefaultLeaseDuration(config.TimeDurationVal(r.config.Vault.DefaultLeaseDuration)) dep.SetVaultLeaseRenewalThreshold(*r.config.Vault.LeaseRenewalThreshold) - // Create the clientset - clients, err := newClientSet(r.config) - if err != nil { - return fmt.Errorf("runner: %s", err) - } - // Create the watcher - watcher, err := newWatcher(r.config, clients, r.config.Once) - if err != nil { - return fmt.Errorf("runner: %s", err) - } - r.watcher = watcher + r.watcher = newWatcher(r.config, clients, r.config.Once) numTemplates := len(*r.config.Templates) templates := make([]*template.Template, 0, numTemplates) @@ -1291,8 +1304,8 @@ func findCommand(c *config.TemplateConfig, templates []*config.TemplateConfig) * return nil } -// newClientSet creates a new client set from the given config. 
-func newClientSet(c *config.Config) (*dep.ClientSet, error) { +// NewClientSet creates a new client set from the given config. +func NewClientSet(c *config.Config) (*dep.ClientSet, error) { clients := dep.NewClientSet() if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{ @@ -1378,10 +1391,10 @@ func newClientSet(c *config.Config) (*dep.ClientSet, error) { } // newWatcher creates a new watcher. -func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Watcher, error) { +func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) *watch.Watcher { log.Printf("[INFO] (runner) creating watcher") - w, err := watch.NewWatcher(&watch.NewWatcherInput{ + return watch.NewWatcher(&watch.NewWatcherInput{ Clients: clients, MaxStale: config.TimeDurationVal(c.MaxStale), Once: c.Once, @@ -1396,8 +1409,4 @@ func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Wat VaultToken: clients.Vault().Token(), RetryFuncNomad: watch.RetryFunc(c.Nomad.Retry.RetryFunc()), }) - if err != nil { - return nil, errors.Wrap(err, "runner") - } - return w, nil } diff --git a/watch/vault_token.go b/watch/vault_token.go new file mode 100644 index 000000000..6e22c2179 --- /dev/null +++ b/watch/vault_token.go @@ -0,0 +1,149 @@ +package watch + +import ( + "encoding/json" + "fmt" + "log" + "strings" + "sync" + + "github.com/hashicorp/consul-template/config" + dep "github.com/hashicorp/consul-template/dependency" + "github.com/hashicorp/vault/api" +) + +// VaultTokenWatcher monitors the vault token for updates +func VaultTokenWatcher( + clients *dep.ClientSet, c *config.VaultConfig, doneCh chan struct{}, +) (*Watcher, error) { + // c.Vault.Token is populated by the config code from all places + // vault tokens are supported. So if there is no token set here, + // tokens are not being used. 
+ raw_token := strings.TrimSpace(config.StringVal(c.Token)) + if raw_token == "" { + return nil, nil + } + + unwrap := config.BoolVal(c.UnwrapToken) + vault := clients.Vault() + // get/set token once when kicked off, async after that.. + token, err := unpackToken(vault, raw_token, unwrap) + if err != nil { + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + vault.SetToken(token) + + var once sync.Once + var watcher *Watcher + getWatcher := func() *Watcher { + once.Do(func() { + watcher = NewWatcher(&NewWatcherInput{ + Clients: clients, + RetryFuncVault: RetryFunc(c.Retry.RetryFunc()), + }) + }) + return watcher + } + + // Vault Agent Token File process // + tokenFile := strings.TrimSpace(config.StringVal(c.VaultAgentTokenFile)) + if tokenFile != "" { + w := getWatcher() + watchLoop, err := watchTokenFile(w, tokenFile, raw_token, unwrap, doneCh) + if err != nil { + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + go watchLoop() + } + + // Vault Token Renewal process // + renewVault := vault.Token() != "" && config.BoolVal(c.RenewToken) + if renewVault { + w := getWatcher() + vt, err := dep.NewVaultTokenQuery(token) + if err != nil { + w.Stop() + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + if _, err := w.Add(vt); err != nil { + w.Stop() + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + } + + return watcher, nil +} + +func watchTokenFile( + w *Watcher, tokenFile, raw_token string, unwrap bool, doneCh chan struct{}, +) (func(), error) { + // watcher, tokenFile, raw_token, unwrap, doneCh + atf, err := dep.NewVaultAgentTokenQuery(tokenFile) + if err != nil { + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + if _, err := w.Add(atf); err != nil { + w.Stop() + return nil, fmt.Errorf("vaultwatcher: %w", err) + } + vault := w.clients.Vault() + return func() { + for { + select { + case v := <-w.DataCh(): + new_raw_token := strings.TrimSpace(v.Data().(string)) + if new_raw_token == raw_token { + break + } + token, err := unpackToken(vault, 
new_raw_token, unwrap) + switch err { + case nil: + raw_token = new_raw_token + vault.SetToken(token) + default: + log.Printf("[INFO] %s", err) + } + case <-doneCh: + return + } + } + }, nil +} + +type vaultClient interface { + SetToken(string) + Logical() *api.Logical +} + +// unpackToken grabs the real token from raw_token (unwrap, etc) +func unpackToken(client vaultClient, token string, unwrap bool) (string, error) { + // If vault agent specifies wrap_ttl for the token it is returned as + // a SecretWrapInfo struct marshalled into JSON instead of the normal raw + // token. This checks for that and pulls out the token if it is the case. + var wrapinfo api.SecretWrapInfo + if err := json.Unmarshal([]byte(token), &wrapinfo); err == nil { + token = wrapinfo.Token + } + token = strings.TrimSpace(token) + if token == "" { + return "", fmt.Errorf("empty token") + } + + if unwrap { + client.SetToken(token) // needs to be set to unwrap + secret, err := client.Logical().Unwrap(token) + switch { + case err != nil: + return token, fmt.Errorf("vault unwrap: %s", err) + case secret == nil: + return token, fmt.Errorf("vault unwrap: no secret") + case secret.Auth == nil: + return token, fmt.Errorf("vault unwrap: no secret auth") + case secret.Auth.ClientToken == "": + return token, fmt.Errorf("vault unwrap: no token returned") + default: + token = secret.Auth.ClientToken + } + } + return token, nil +} diff --git a/watch/vault_token_test.go b/watch/vault_token_test.go new file mode 100644 index 000000000..b1c0d4b61 --- /dev/null +++ b/watch/vault_token_test.go @@ -0,0 +1,259 @@ +package watch + +import ( + "encoding/json" + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/hashicorp/consul-template/config" + dep "github.com/hashicorp/consul-template/dependency" + "github.com/hashicorp/vault/api" +) + +// approle auto-auth setup in watch_test.go, TestMain() +func TestVaultTokenWatcher(t *testing.T) { + // Don't set the below to run in parallel. 
They mess with the single + // running vault's permissions. + t.Run("noop", func(t *testing.T) { + conf := config.DefaultVaultConfig() + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() + select { + case err := <-watcher.ErrCh(): + if err != nil { + t.Error(err) + } + case <-time.After(time.Second): + return + } + }) + + t.Run("fixed_token", func(t *testing.T) { + testClients.Vault().SetToken(vaultToken) + conf := config.DefaultVaultConfig() + token := vaultToken + conf.Token = &token + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() + if watcher != nil { + t.Error("watcher should be nil") + } + if testClients.Vault().Token() != vaultToken { + t.Error("Token should be " + vaultToken) + } + }) + + t.Run("secretwrapped_token", func(t *testing.T) { + testClients.Vault().SetToken(vaultToken) + conf := config.DefaultVaultConfig() + data, err := json.Marshal(&api.SecretWrapInfo{Token: vaultToken}) + if err != nil { + t.Error(err) + } + jsonToken := string(data) + conf.Token = &jsonToken + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() + if testClients.Vault().Token() != vaultToken { + t.Error("Token should be " + vaultToken) + } + }) + + t.Run("tokenfile", func(t *testing.T) { + // setup + testClients.Vault().SetToken(vaultToken) + tokenfile := runVaultAgent(testClients, tokenRoleId) + defer func() { + testClients.Vault().SetToken(vaultToken) + os.Remove(tokenfile) + }() + conf := config.DefaultVaultConfig() + token := vaultToken + conf.Token = &token + conf.VaultAgentTokenFile = &tokenfile + // test data + doneCh := make(chan struct{}) + watcher, err := VaultTokenWatcher(testClients, conf, doneCh) + if err != nil { + t.Error(err) + } + defer watcher.Stop() + // tests + select { + case <-time.After(time.Millisecond): + // XXX remove this timer in hashicat port + 
doneCh <- struct{}{} + case err := <-watcher.ErrCh(): + t.Error(err) + } + + if testClients.Vault().Token() == vaultToken { + t.Error("Token should not be " + vaultToken) + } + }) + + t.Run("renew", func(t *testing.T) { + // exercise the renewer: the action is all inside the vault api + // calls and vault so there's little to check.. so we just try + // to call it and make sure it doesn't error + testClients.Vault().SetToken(vaultToken) + renew := true + _, err := testClients.Vault().Auth().Token().Create( + &api.TokenCreateRequest{ + ID: "b_token", + TTL: "1m", + Renewable: &renew, + }) + if err != nil { + t.Error(err) + } + conf := config.DefaultVaultConfig() + token := "b_token" + conf.Token = &token // + conf.RenewToken = &renew + watcher, err := VaultTokenWatcher(testClients, conf, nil) + if err != nil { + t.Error(err) + } + defer watcher.Stop() + + select { + case err := <-watcher.ErrCh(): + if err != nil { + t.Error(err) + } + case <-time.After(time.Millisecond * 100): + // give it a chance to throw an error + } + }) +} + +func TestVaultTokenRefreshToken(t *testing.T) { + watcher := NewWatcher(&NewWatcherInput{ + Clients: testClients, + }) + // force watcher to be synchronous so we can control test flow + watcher.dataCh = make(chan *View) // no buffer + wrapinfo := api.SecretWrapInfo{ + Token: "btoken", + } + b, _ := json.Marshal(wrapinfo) + type testcase struct { + name, raw_token, exp_token string + } + vault := testClients.Vault() + testcases := []testcase{ + {name: "noop", raw_token: "foo", exp_token: "foo"}, + {name: "spaces", raw_token: " foo ", exp_token: "foo"}, + {name: "secretwrap", raw_token: string(b), exp_token: "btoken"}, + } + for i, tc := range testcases { + tc := tc // avoid for-loop pointer wart + name := fmt.Sprintf("%d_%s", i, tc.name) + t.Run(name, func(t *testing.T) { + var wg sync.WaitGroup + dCh := make(chan struct{}) + watchLoop, err := watchTokenFile(watcher, "", "XXX", false, dCh) + if err != nil { + t.Error(err) + } + wg.Add(1) + 
go func() {
+				watchLoop()
+				wg.Done()
+			}()
+			fd := fakeDep{name: name}
+			watcher.dataCh <- &View{dependency: fd, data: tc.raw_token}
+			close(dCh) // close doneCh to stop watchLoop
+			wg.Wait()
+			if vault.Token() != tc.exp_token {
+				t.Errorf("bad token, expected: '%s', received '%s'",
+					tc.exp_token, vault.Token())
+			}
+		})
+	}
+	watcher.Stop()
+}
+
+// When vault-agent uses the wrap_ttl option it writes a json blob instead of
+// a raw token. This verifies it will extract the token from that when needed.
+// It doesn't test unwrap. The integration test covers that for now.
+func TestVaultTokenGetToken(t *testing.T) {
+	t.Run("table_test", func(t *testing.T) {
+		wrapinfo := api.SecretWrapInfo{
+			Token: "btoken",
+		}
+		b, _ := json.Marshal(wrapinfo)
+		testcases := []struct{ in, out string }{
+			{in: "", out: ""},
+			{in: "atoken", out: "atoken"},
+			{in: string(b), out: "btoken"},
+		}
+		for _, tc := range testcases {
+			dummy := &setTokenFaker{}
+			token, _ := unpackToken(dummy, tc.in, false)
+			if token != tc.out {
+				t.Errorf("unpackToken, wanted: '%v', got: '%v'", tc.out, token)
+			}
+		}
+	})
+	t.Run("unwrap_test", func(t *testing.T) {
+		vault := testClients.Vault()
+		vault.SetToken(vaultToken)
+		vault.SetWrappingLookupFunc(func(operation, path string) string {
+			if path == "auth/token/create" {
+				return "30s"
+			}
+			return ""
+		})
+		defer vault.SetWrappingLookupFunc(nil)
+
+		secret, err := vault.Auth().Token().Create(&api.TokenCreateRequest{
+			Lease: "1h",
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		unwrap := true
+		wrappedToken := secret.WrapInfo.Token
+		token, err := unpackToken(vault, wrappedToken, unwrap)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if token == wrappedToken {
+			t.Errorf("tokens should not match")
+		}
+	})
+}
+
+type setTokenFaker struct {
+	Token string
+}
+
+func (t *setTokenFaker) SetToken(token string) {}
+func (t *setTokenFaker) Logical() *api.Logical { return nil }
+
+var _ dep.Dependency = (*fakeDep)(nil)
+
+type fakeDep struct{ name string }
+
+func 
(d fakeDep) String() string { return d.name } +func (d fakeDep) CanShare() bool { return false } +func (d fakeDep) Stop() {} +func (d fakeDep) Type() dep.Type { return dep.TypeConsul } +func (d fakeDep) Fetch(*dep.ClientSet, *dep.QueryOptions) (interface{}, *dep.ResponseMetadata, error) { + return d.name, nil, nil +} diff --git a/watch/view.go b/watch/view.go index f75ae191b..cb378e410 100644 --- a/watch/view.go +++ b/watch/view.go @@ -189,6 +189,7 @@ func (v *View) fetch(doneCh, successCh chan<- struct{}, errCh chan<- error) { allowStale = true } + firstLoop := true // to disable rate limiting on first pass for { // If the view was stopped, short-circuit this loop. This prevents a bug // where a view can get "lost" in the event Consul Template is reloaded. @@ -239,9 +240,10 @@ func (v *View) fetch(doneCh, successCh chan<- struct{}, errCh chan<- error) { allowStale = true } - if dur := rateLimiter(start); dur > 1 { + if dur := rateLimiter(start); dur > 1 && !firstLoop { time.Sleep(dur) } + firstLoop = false // blocking queries that return due to block timeout // will have the same index diff --git a/watch/watch_test.go b/watch/watch_test.go new file mode 100644 index 000000000..1bee8cb53 --- /dev/null +++ b/watch/watch_test.go @@ -0,0 +1,198 @@ +package watch + +import ( + "encoding/json" + "io/ioutil" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + dep "github.com/hashicorp/consul-template/dependency" + "github.com/hashicorp/vault/api" +) + +const ( + vaultAddr = "http://127.0.0.1:8200" + vaultToken = "a_token" +) + +var ( + testVault *vaultServer + testClients *dep.ClientSet + tokenRoleId string +) + +func TestMain(m *testing.M) { + os.Exit(main(m)) +} + +// sub-main so I can use defer +func main(m *testing.M) int { + log.SetOutput(ioutil.Discard) + testVault = newTestVault() + defer func() { testVault.Stop() }() + + clients := dep.NewClientSet() + if err := clients.CreateVaultClient(&dep.CreateVaultClientInput{ + Address: vaultAddr, + 
Token: vaultToken, + }); err != nil { + panic(err) + } + + testClients = clients + tokenRoleId = vaultTokenSetup(clients) + + return m.Run() +} + +type vaultServer struct { + cmd *exec.Cmd +} + +func (v vaultServer) Stop() error { + if v.cmd != nil && v.cmd.Process != nil { + return v.cmd.Process.Signal(os.Interrupt) + } + return nil +} + +func newTestVault() *vaultServer { + path, err := exec.LookPath("vault") + if err != nil || path == "" { + panic("vault not found on $PATH") + } + args := []string{ + "server", "-dev", "-dev-root-token-id", vaultToken, + "-dev-no-store-token", + } + cmd := exec.Command("vault", args...) + cmd.Stdout = ioutil.Discard + cmd.Stderr = ioutil.Discard + + if err := cmd.Start(); err != nil { + panic("vault failed to start: " + err.Error()) + } + return &vaultServer{ + cmd: cmd, + } +} + +// Sets up approle auto-auth for token generation/testing +func vaultTokenSetup(clients *dep.ClientSet) string { + vc := clients.Vault() + + // vault auth enable approle + err := vc.Sys().EnableAuthWithOptions("approle", + &api.MountInput{ + Type: "approle", + }) + if err != nil && !strings.Contains(err.Error(), "path is already in use") { + panic(err) + } + + // vault policy write foo 'path ...' + err = vc.Sys().PutPolicy("foo", + `path "secret/data/foo" { capabilities = ["read"] }`) + if err != nil { + panic(err) + } + + // vault write auth/approle/role/foo ... 
+	_, err = vc.Logical().Write("auth/approle/role/foo",
+		map[string]interface{}{
+			"token_policies":     "foo",
+			"secret_id_num_uses": 100,
+			"secret_id_ttl":      "5m",
+			"token_num_uses":     10,
+			"token_ttl":          "7m",
+			"token_max_ttl":      "10m",
+		})
+	if err != nil {
+		panic(err)
+	}
+
+	var sec *api.Secret
+	// vault read -field=role_id auth/approle/role/foo/role-id
+	sec, err = vc.Logical().Read("auth/approle/role/foo/role-id")
+	if err != nil {
+		panic(err)
+	}
+	role_id := sec.Data["role_id"]
+	return role_id.(string)
+}
+
+// returns path to token file (which is created by the agent run)
+// token file isn't cleaned, so use returned path to remove it when done
+func runVaultAgent(clients *dep.ClientSet, role_id string) string {
+	dir, err := os.MkdirTemp("", "consul-template-test")
+	if err != nil {
+		panic(err)
+	}
+	defer os.RemoveAll(dir)
+
+	tokenFile := filepath.Join("", "vatoken.txt")
+
+	role_idPath := filepath.Join(dir, "roleid")
+	secret_idPath := filepath.Join(dir, "secretid")
+	vaconf := filepath.Join(dir, "vault-agent-config.json")
+
+	// Generate secret_id, need new one for each agent run
+	// vault write -f -field secret_id auth/approle/role/foo/secret-id
+	vc := clients.Vault()
+	sec, err := vc.Logical().Write("auth/approle/role/foo/secret-id", nil)
+	if err != nil {
+		panic(err)
+	}
+	secret_id := sec.Data["secret_id"].(string)
+	err = os.WriteFile(secret_idPath, []byte(secret_id), 0o444)
+	if err != nil {
+		panic(err)
+	}
+	err = os.WriteFile(role_idPath, []byte(role_id), 0o444)
+	if err != nil {
+		panic(err)
+	}
+
+	type obj map[string]interface{}
+	type list []obj
+	va := obj{
+		"vault": obj{"address": vaultAddr},
+		"auto_auth": obj{
+			"method": obj{
+				"type": "approle",
+				"config": obj{
+					"role_id_file_path":   role_idPath,
+					"secret_id_file_path": secret_idPath,
+				},
+				"wrap_ttl": "5m",
+			},
+			"sinks": list{
+				{"sink": obj{"type": "file", "config": obj{"path": tokenFile}}},
+			},
+		},
+	}
+	txt, err := json.Marshal(va)
+	if err != nil {
+		panic(err)
+	}
+ err = os.WriteFile(vaconf, txt, 0o644) + if err != nil { + panic(err) + } + + args := []string{ + "agent", "-exit-after-auth", "-config=" + vaconf, + } + cmd := exec.Command("vault", args...) + cmd.Stdout = ioutil.Discard + cmd.Stderr = ioutil.Discard + + if err := cmd.Run(); err != nil { + panic("vault agent failed to run: " + err.Error()) + } + return tokenFile +} diff --git a/watch/watcher.go b/watch/watcher.go index e4c3d6f07..af9cfb2b9 100644 --- a/watch/watcher.go +++ b/watch/watcher.go @@ -78,7 +78,7 @@ type NewWatcherInput struct { } // NewWatcher creates a new watcher using the given API client. -func NewWatcher(i *NewWatcherInput) (*Watcher, error) { +func NewWatcher(i *NewWatcherInput) *Watcher { w := &Watcher{ clients: i.Clients, depViewMap: make(map[string]*View), @@ -92,29 +92,7 @@ func NewWatcher(i *NewWatcherInput) (*Watcher, error) { retryFuncVault: i.RetryFuncVault, retryFuncNomad: i.RetryFuncNomad, } - - // Start a watcher for the Vault renew if that config was specified - if i.RenewVault { - vt, err := dep.NewVaultTokenQuery(i.VaultToken) - if err != nil { - return nil, errors.Wrap(err, "watcher") - } - if _, err := w.Add(vt); err != nil { - return nil, errors.Wrap(err, "watcher") - } - } - - if len(i.VaultAgentTokenFile) > 0 { - vag, err := dep.NewVaultAgentTokenQuery(i.VaultAgentTokenFile) - if err != nil { - return nil, errors.Wrap(err, "watcher") - } - if _, err := w.Add(vag); err != nil { - return nil, errors.Wrap(err, "watcher") - } - } - - return w, nil + return w } // DataCh returns a read-only channel of Views which is populated when a view @@ -125,6 +103,9 @@ func (w *Watcher) DataCh() <-chan *View { // ErrCh returns a read-only channel of errors returned by the upstream. func (w *Watcher) ErrCh() <-chan error { + if w == nil { + return nil + } return w.errCh } @@ -233,6 +214,9 @@ func (w *Watcher) Size() int { // Stop halts this watcher and any currently polling views immediately. 
If a // view was in the middle of a poll, no data will be returned. func (w *Watcher) Stop() { + if w == nil { + return + } w.Lock() defer w.Unlock() diff --git a/watch/watcher_test.go b/watch/watcher_test.go index 67b7c5acf..0f47273ed 100644 --- a/watch/watcher_test.go +++ b/watch/watcher_test.go @@ -8,13 +8,10 @@ import ( ) func TestAdd_updatesMap(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if _, err := w.Add(d); err != nil { @@ -28,13 +25,10 @@ func TestAdd_updatesMap(t *testing.T) { } func TestAdd_exists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} w.depViewMap[d.String()] = &View{} @@ -50,13 +44,10 @@ func TestAdd_exists(t *testing.T) { } func TestAdd_startsViewPoll(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } added, err := w.Add(&TestDep{}) if err != nil { @@ -76,13 +67,10 @@ func TestAdd_startsViewPoll(t *testing.T) { } func TestWatching_notExists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if w.Watching(d) == true { @@ -91,13 +79,10 @@ func TestWatching_notExists(t *testing.T) { } func TestWatching_exists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if _, err := w.Add(d); err != nil { @@ -110,13 +95,10 @@ func TestWatching_exists(t *testing.T) { } func TestRemove_exists(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := 
NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } d := &TestDep{} if _, err := w.Add(d); err != nil { @@ -134,13 +116,10 @@ func TestRemove_exists(t *testing.T) { } func TestRemove_doesNotExist(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } removed := w.Remove(&TestDep{}) if removed != false { @@ -149,13 +128,10 @@ func TestRemove_doesNotExist(t *testing.T) { } func TestSize_empty(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } if w.Size() != 0 { t.Errorf("expected %d to be %d", w.Size(), 0) @@ -163,13 +139,10 @@ func TestSize_empty(t *testing.T) { } func TestSize_returnsNumViews(t *testing.T) { - w, err := NewWatcher(&NewWatcherInput{ + w := NewWatcher(&NewWatcherInput{ Clients: dep.NewClientSet(), Once: true, }) - if err != nil { - t.Fatal(err) - } for i := 0; i < 10; i++ { d := &TestDep{name: fmt.Sprintf("%d", i)}