Skip to content

Commit

Permalink
vault token management moved into separate watcher
Browse files Browse the repository at this point in the history
Refactor the vault token management (renewing, tokenfile watching,
unwrapping, etc) into a dedicated watcher that is only responsible for
that.

Done to encapsulate the vault management token into one place where it
can be more easily understood and tested. It was scattered about and
inconsistent (eg. it only tried to unwrap the first token).

Strips the vault token code out of current watcher and client_set.

Most of this commit is the new test suite for this code as it wasn't
really tested before.
  • Loading branch information
eikenb committed Sep 29, 2022
1 parent 7e3c9ba commit 9f97fe1
Show file tree
Hide file tree
Showing 14 changed files with 664 additions and 173 deletions.
19 changes: 17 additions & 2 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/hashicorp/consul-template/service_os"
"github.com/hashicorp/consul-template/signals"
"github.com/hashicorp/consul-template/version"
"github.com/hashicorp/consul-template/watch"
)

// Exit codes are int values that represent an exit code for a particular error.
Expand Down Expand Up @@ -103,8 +104,20 @@ func (cli *CLI) Run(args []string) int {
return ExitCodeOK
}

// Create the clientset
clients, err := manager.NewClientSet(config)
if err != nil {
return logError(err, ExitCodeConfigError)
}

// vault token watcher
vtwatchErrCh := watch.VaultTokenWatcher(clients, config.Vault)
if err != nil {
return logError(err, ExitCodeRunnerError)
}

// Initial runner
runner, err := manager.NewRunner(config, dry)
runner, err := manager.NewRunner(clients, config, dry)
if err != nil {
return logError(err, ExitCodeRunnerError)
}
Expand All @@ -115,6 +128,8 @@ func (cli *CLI) Run(args []string) int {

for {
select {
case err := <-vtwatchErrCh:
return logError(err, ExitCodeRunnerError)
case err := <-runner.ErrCh:
// Check if the runner's error returned a specific exit status, and return
// that value. If no value was given, return a generic exit status.
Expand Down Expand Up @@ -150,7 +165,7 @@ func (cli *CLI) Run(args []string) int {
return logError(err, ExitCodeConfigError)
}

runner, err = manager.NewRunner(config, dry)
runner, err = manager.NewRunner(clients, config, dry)
if err != nil {
return logError(err, ExitCodeRunnerError)
}
Expand Down
23 changes: 0 additions & 23 deletions dependency/client_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,33 +347,10 @@ func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error {
}
}

// Set the token if given
if i.Token != "" {
client.SetToken(i.Token)
}

// Check if we are unwrapping
if i.UnwrapToken {
secret, err := client.Logical().Unwrap(i.Token)
if err != nil {
return fmt.Errorf("client set: vault unwrap: %s", err)
}

if secret == nil {
return fmt.Errorf("client set: vault unwrap: no secret")
}

if secret.Auth == nil {
return fmt.Errorf("client set: vault unwrap: no secret auth")
}

if secret.Auth.ClientToken == "" {
return fmt.Errorf("client set: vault unwrap: no token returned")
}

client.SetToken(secret.Auth.ClientToken)
}

// Save the data on ourselves
c.Lock()
c.vault = &vaultClient{
Expand Down
31 changes: 0 additions & 31 deletions dependency/client_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,6 @@ import (
"github.com/stretchr/testify/require"
)

func TestClientSet_unwrapVaultToken(t *testing.T) {
// Don't use t.Parallel() here as the SetWrappingLookupFunc is a global
// setting and breaks other tests if run in parallel

vault := testClients.Vault()

// Create a wrapped token
vault.SetWrappingLookupFunc(func(operation, path string) string {
return "30s"
})
defer vault.SetWrappingLookupFunc(nil)

wrappedToken, err := vault.Auth().Token().Create(&api.TokenCreateRequest{
Lease: "1h",
})
if err != nil {
t.Fatal(err)
}

token := vault.Token()

if token == wrappedToken.WrapInfo.Token {
t.Errorf("expected %q to not be %q", token,
wrappedToken.WrapInfo.Token)
}

if _, err := vault.Auth().Token().LookupSelf(); err != nil {
t.Fatal(err)
}
}

func TestClientSet_K8SServiceTokenAuth(t *testing.T) {
t.Parallel()

Expand Down
10 changes: 4 additions & 6 deletions dependency/vault_agent_token.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package dependency

import (
"io/ioutil"
"log"
"os"
"strings"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -39,6 +37,7 @@ func NewVaultAgentTokenQuery(path string) (*VaultAgentTokenQuery, error) {
func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interface{}, *ResponseMetadata, error) {
log.Printf("[TRACE] %s: READ %s", d, d.path)

var token string
select {
case <-d.stopCh:
log.Printf("[TRACE] %s: stopped", d)
Expand All @@ -50,16 +49,15 @@ func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (in

log.Printf("[TRACE] %s: reported change", d)

token, err := ioutil.ReadFile(d.path)
raw_token, err := os.ReadFile(d.path)
if err != nil {
return "", nil, errors.Wrap(err, d.String())
}

d.stat = r.stat
clients.Vault().SetToken(strings.TrimSpace(string(token)))
token = string(raw_token)
}

return respWithMetadata("")
return respWithMetadata(token)
}

// CanShare returns if this dependency is sharable.
Expand Down
13 changes: 4 additions & 9 deletions dependency/vault_agent_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) {
// Don't use t.Parallel() here as the SetToken() calls are global and break
// other tests if run in parallel

// reset token back to original
vc := testClients.Vault()
token := vc.Token()
defer vc.SetToken(token)

// Set up the Vault token file.
tokenFile, err := ioutil.TempFile("", "token1")
if err != nil {
Expand All @@ -33,22 +28,22 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) {
}

clientSet := testClients
_, _, err = d.Fetch(clientSet, nil)
token, _, err := d.Fetch(clientSet, nil)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, "token", clientSet.Vault().Token())
assert.Equal(t, "token", token)

// Update the contents.
renderer.AtomicWrite(
tokenFile.Name(), false, []byte("another_token"), 0o644, false)
_, _, err = d.Fetch(clientSet, nil)
token, _, err = d.Fetch(clientSet, nil)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, "another_token", clientSet.Vault().Token())
assert.Equal(t, "another_token", token)
}

func TestVaultAgentTokenQuery_Fetch_missingFile(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions dependency/vault_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ func TestVaultReadQuery_Fetch_NonSecrets(t *testing.T) {
vc := clients.Vault()

err = vc.Sys().EnableAuth("approle", "approle", "")
if err != nil {
t.Fatal(err)
if err != nil && !strings.Contains(err.Error(), "path is already in use") {
t.Fatalf("(%T) %s\n", err, err)
}

_, err = vc.Logical().Write("auth/approle/role/my-approle", nil)
Expand Down
14 changes: 9 additions & 5 deletions manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ func TestMain(m *testing.M) {
}
testConsul = consul

clients := dep.NewClientSet()
if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{
Address: testConsul.HTTPAddr,
}); err != nil {
consulConfig := config.DefaultConsulConfig()
consulConfig.Address = &testConsul.HTTPAddr
clients, err := NewClientSet(&config.Config{
Consul: consulConfig,
Vault: config.DefaultVaultConfig(),
Nomad: config.DefaultNomadConfig(),
})
if err != nil {
testConsul.Stop()
log.Fatal(err)
log.Fatal(fmt.Errorf("failed to start clients: %v", err))
}
testClients = clients

Expand Down
30 changes: 8 additions & 22 deletions manager/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ type RenderEvent struct {

// NewRunner accepts a slice of TemplateConfigs and returns a pointer to the new
// Runner and any error that occurred during creation.
func NewRunner(config *config.Config, dry bool) (*Runner, error) {
func NewRunner(clients *dep.ClientSet, config *config.Config, dry bool) (*Runner, error) {
log.Printf("[INFO] (runner) creating new runner (dry: %v, once: %v)",
dry, config.Once)

Expand All @@ -181,7 +181,7 @@ func NewRunner(config *config.Config, dry bool) (*Runner, error) {
dry: dry,
}

if err := runner.init(); err != nil {
if err := runner.init(clients); err != nil {
return nil, err
}

Expand Down Expand Up @@ -885,7 +885,7 @@ func (r *Runner) runTemplate(tmpl *template.Template, runCtx *templateRunCtx) (*

// init() creates the Runner's underlying data structures and returns an error
// if any problems occur.
func (r *Runner) init() error {
func (r *Runner) init(clients *dep.ClientSet) error {
// Ensure default configuration values
r.config = config.DefaultConfig().Merge(r.config)
r.config.Finalize()
Expand All @@ -900,18 +900,8 @@ func (r *Runner) init() error {
dep.SetVaultDefaultLeaseDuration(config.TimeDurationVal(r.config.Vault.DefaultLeaseDuration))
dep.SetVaultLeaseRenewalThreshold(*r.config.Vault.LeaseRenewalThreshold)

// Create the clientset
clients, err := newClientSet(r.config)
if err != nil {
return fmt.Errorf("runner: %s", err)
}

// Create the watcher
watcher, err := newWatcher(r.config, clients, r.config.Once)
if err != nil {
return fmt.Errorf("runner: %s", err)
}
r.watcher = watcher
r.watcher = newWatcher(r.config, clients, r.config.Once)

numTemplates := len(*r.config.Templates)
templates := make([]*template.Template, 0, numTemplates)
Expand Down Expand Up @@ -1291,8 +1281,8 @@ func findCommand(c *config.TemplateConfig, templates []*config.TemplateConfig) *
return nil
}

// newClientSet creates a new client set from the given config.
func newClientSet(c *config.Config) (*dep.ClientSet, error) {
// NewClientSet creates a new client set from the given config.
func NewClientSet(c *config.Config) (*dep.ClientSet, error) {
clients := dep.NewClientSet()

if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{
Expand Down Expand Up @@ -1378,10 +1368,10 @@ func newClientSet(c *config.Config) (*dep.ClientSet, error) {
}

// newWatcher creates a new watcher.
func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Watcher, error) {
func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) *watch.Watcher {
log.Printf("[INFO] (runner) creating watcher")

w, err := watch.NewWatcher(&watch.NewWatcherInput{
return watch.NewWatcher(&watch.NewWatcherInput{
Clients: clients,
MaxStale: config.TimeDurationVal(c.MaxStale),
Once: c.Once,
Expand All @@ -1396,8 +1386,4 @@ func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Wat
VaultToken: clients.Vault().Token(),
RetryFuncNomad: watch.RetryFunc(c.Nomad.Retry.RetryFunc()),
})
if err != nil {
return nil, errors.Wrap(err, "runner")
}
return w, nil
}
Loading

0 comments on commit 9f97fe1

Please sign in to comment.