Skip to content

Commit

Permalink
Merge pull request #1645 from hashicorp/vault-token-file-refactor
Browse files Browse the repository at this point in the history
vault token management moved into separate watcher
  • Loading branch information
eikenb authored Sep 29, 2022
2 parents c677df8 + 938aed7 commit 3c60253
Show file tree
Hide file tree
Showing 13 changed files with 679 additions and 162 deletions.
23 changes: 0 additions & 23 deletions dependency/client_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,33 +347,10 @@ func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error {
}
}

// Set the token if given
if i.Token != "" {
client.SetToken(i.Token)
}

// Check if we are unwrapping
if i.UnwrapToken {
secret, err := client.Logical().Unwrap(i.Token)
if err != nil {
return fmt.Errorf("client set: vault unwrap: %s", err)
}

if secret == nil {
return fmt.Errorf("client set: vault unwrap: no secret")
}

if secret.Auth == nil {
return fmt.Errorf("client set: vault unwrap: no secret auth")
}

if secret.Auth.ClientToken == "" {
return fmt.Errorf("client set: vault unwrap: no token returned")
}

client.SetToken(secret.Auth.ClientToken)
}

// Save the data on ourselves
c.Lock()
c.vault = &vaultClient{
Expand Down
31 changes: 0 additions & 31 deletions dependency/client_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,6 @@ import (
"github.com/stretchr/testify/require"
)

func TestClientSet_unwrapVaultToken(t *testing.T) {
// Don't use t.Parallel() here as the SetWrappingLookupFunc is a global
// setting and breaks other tests if run in parallel

vault := testClients.Vault()

// Create a wrapped token
vault.SetWrappingLookupFunc(func(operation, path string) string {
return "30s"
})
defer vault.SetWrappingLookupFunc(nil)

wrappedToken, err := vault.Auth().Token().Create(&api.TokenCreateRequest{
Lease: "1h",
})
if err != nil {
t.Fatal(err)
}

token := vault.Token()

if token == wrappedToken.WrapInfo.Token {
t.Errorf("expected %q to not be %q", token,
wrappedToken.WrapInfo.Token)
}

if _, err := vault.Auth().Token().LookupSelf(); err != nil {
t.Fatal(err)
}
}

func TestClientSet_K8SServiceTokenAuth(t *testing.T) {
t.Parallel()

Expand Down
10 changes: 4 additions & 6 deletions dependency/vault_agent_token.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package dependency

import (
"io/ioutil"
"log"
"os"
"strings"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -39,6 +37,7 @@ func NewVaultAgentTokenQuery(path string) (*VaultAgentTokenQuery, error) {
func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interface{}, *ResponseMetadata, error) {
log.Printf("[TRACE] %s: READ %s", d, d.path)

var token string
select {
case <-d.stopCh:
log.Printf("[TRACE] %s: stopped", d)
Expand All @@ -50,16 +49,15 @@ func (d *VaultAgentTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (in

log.Printf("[TRACE] %s: reported change", d)

token, err := ioutil.ReadFile(d.path)
raw_token, err := os.ReadFile(d.path)
if err != nil {
return "", nil, errors.Wrap(err, d.String())
}

d.stat = r.stat
clients.Vault().SetToken(strings.TrimSpace(string(token)))
token = string(raw_token)
}

return respWithMetadata("")
return respWithMetadata(token)
}

// CanShare returns if this dependency is sharable.
Expand Down
13 changes: 4 additions & 9 deletions dependency/vault_agent_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) {
// Don't use t.Parallel() here as the SetToken() calls are global and break
// other tests if run in parallel

// reset token back to original
vc := testClients.Vault()
token := vc.Token()
defer vc.SetToken(token)

// Set up the Vault token file.
tokenFile, err := ioutil.TempFile("", "token1")
if err != nil {
Expand All @@ -33,22 +28,22 @@ func TestVaultAgentTokenQuery_Fetch(t *testing.T) {
}

clientSet := testClients
_, _, err = d.Fetch(clientSet, nil)
token, _, err := d.Fetch(clientSet, nil)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, "token", clientSet.Vault().Token())
assert.Equal(t, "token", token)

// Update the contents.
renderer.AtomicWrite(
tokenFile.Name(), false, []byte("another_token"), 0o644, false)
_, _, err = d.Fetch(clientSet, nil)
token, _, err = d.Fetch(clientSet, nil)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, "another_token", clientSet.Vault().Token())
assert.Equal(t, "another_token", token)
}

func TestVaultAgentTokenQuery_Fetch_missingFile(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions dependency/vault_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ func TestVaultReadQuery_Fetch_NonSecrets(t *testing.T) {
vc := clients.Vault()

err = vc.Sys().EnableAuth("approle", "approle", "")
if err != nil {
t.Fatal(err)
if err != nil && !strings.Contains(err.Error(), "path is already in use") {
t.Fatalf("(%T) %s\n", err, err)
}

_, err = vc.Logical().Write("auth/approle/role/my-approle", nil)
Expand Down
14 changes: 9 additions & 5 deletions manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ func TestMain(m *testing.M) {
}
testConsul = consul

clients := dep.NewClientSet()
if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{
Address: testConsul.HTTPAddr,
}); err != nil {
consulConfig := config.DefaultConsulConfig()
consulConfig.Address = &testConsul.HTTPAddr
clients, err := NewClientSet(&config.Config{
Consul: consulConfig,
Vault: config.DefaultVaultConfig(),
Nomad: config.DefaultNomadConfig(),
})
if err != nil {
testConsul.Stop()
log.Fatal(err)
log.Fatal(fmt.Errorf("failed to start clients: %v", err))
}
testClients = clients

Expand Down
59 changes: 34 additions & 25 deletions manager/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ type Runner struct {
// dependenciesLock is a lock around touching the dependencies map.
dependenciesLock sync.Mutex

// token watcher
vaultTokenWatcher *watch.Watcher
// watcher is the watcher this runner is using.
watcher *watch.Watcher

Expand Down Expand Up @@ -181,10 +183,21 @@ func NewRunner(config *config.Config, dry bool) (*Runner, error) {
dry: dry,
}

if err := runner.init(); err != nil {
// Create the clientset
clients, err := NewClientSet(config)
if err != nil {
return nil, fmt.Errorf("runner: %w", err)
}
// needs to be run early to do initial token handling
runner.vaultTokenWatcher, err = watch.VaultTokenWatcher(
clients, config.Vault, runner.DoneCh)
if err != nil {
return nil, err
}

if err := runner.init(clients); err != nil {
return nil, err
}
return runner, nil
}

Expand Down Expand Up @@ -226,7 +239,7 @@ func (r *Runner) Start() {

if r.child != nil {
r.stopDedup()
r.stopWatcher()
r.stopWatchers()

log.Printf("[INFO] (runner) waiting for child process to exit")
select {
Expand Down Expand Up @@ -330,7 +343,7 @@ func (r *Runner) Start() {

if r.child != nil {
r.stopDedup()
r.stopWatcher()
r.stopWatchers()

log.Printf("[INFO] (runner) waiting for child process to exit")
select {
Expand Down Expand Up @@ -384,6 +397,12 @@ func (r *Runner) Start() {
r.ErrCh <- err
return

case err := <-r.vaultTokenWatcher.ErrCh():
// Push the error back up the stack
log.Printf("[ERR] (runner): %s", err)
r.ErrCh <- err
return

case tmpl := <-r.quiescenceCh:
// Remove the quiescence for this template from the map. This will force
// the upcoming Run call to actually evaluate and render the template.
Expand Down Expand Up @@ -455,7 +474,7 @@ func (r *Runner) internalStop(immediately bool) {

log.Printf("[INFO] (runner) stopping")
r.stopDedup()
r.stopWatcher()
r.stopWatchers()
r.stopChild(immediately)

if err := r.deletePid(); err != nil {
Expand All @@ -475,11 +494,15 @@ func (r *Runner) stopDedup() {
}
}

func (r *Runner) stopWatcher() {
func (r *Runner) stopWatchers() {
if r.watcher != nil {
log.Printf("[DEBUG] (runner) stopping watcher")
r.watcher.Stop()
}
if r.vaultTokenWatcher != nil {
log.Printf("[DEBUG] (runner) stopping vault token watcher")
r.vaultTokenWatcher.Stop()
}
}

func (r *Runner) stopChild(immediately bool) {
Expand Down Expand Up @@ -885,7 +908,7 @@ func (r *Runner) runTemplate(tmpl *template.Template, runCtx *templateRunCtx) (*

// init() creates the Runner's underlying data structures and returns an error
// if any problems occur.
func (r *Runner) init() error {
func (r *Runner) init(clients *dep.ClientSet) error {
// Ensure default configuration values
r.config = config.DefaultConfig().Merge(r.config)
r.config.Finalize()
Expand All @@ -900,18 +923,8 @@ func (r *Runner) init() error {
dep.SetVaultDefaultLeaseDuration(config.TimeDurationVal(r.config.Vault.DefaultLeaseDuration))
dep.SetVaultLeaseRenewalThreshold(*r.config.Vault.LeaseRenewalThreshold)

// Create the clientset
clients, err := newClientSet(r.config)
if err != nil {
return fmt.Errorf("runner: %s", err)
}

// Create the watcher
watcher, err := newWatcher(r.config, clients, r.config.Once)
if err != nil {
return fmt.Errorf("runner: %s", err)
}
r.watcher = watcher
r.watcher = newWatcher(r.config, clients, r.config.Once)

numTemplates := len(*r.config.Templates)
templates := make([]*template.Template, 0, numTemplates)
Expand Down Expand Up @@ -1291,8 +1304,8 @@ func findCommand(c *config.TemplateConfig, templates []*config.TemplateConfig) *
return nil
}

// newClientSet creates a new client set from the given config.
func newClientSet(c *config.Config) (*dep.ClientSet, error) {
// NewClientSet creates a new client set from the given config.
func NewClientSet(c *config.Config) (*dep.ClientSet, error) {
clients := dep.NewClientSet()

if err := clients.CreateConsulClient(&dep.CreateConsulClientInput{
Expand Down Expand Up @@ -1378,10 +1391,10 @@ func newClientSet(c *config.Config) (*dep.ClientSet, error) {
}

// newWatcher creates a new watcher.
func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Watcher, error) {
func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) *watch.Watcher {
log.Printf("[INFO] (runner) creating watcher")

w, err := watch.NewWatcher(&watch.NewWatcherInput{
return watch.NewWatcher(&watch.NewWatcherInput{
Clients: clients,
MaxStale: config.TimeDurationVal(c.MaxStale),
Once: c.Once,
Expand All @@ -1396,8 +1409,4 @@ func newWatcher(c *config.Config, clients *dep.ClientSet, once bool) (*watch.Wat
VaultToken: clients.Vault().Token(),
RetryFuncNomad: watch.RetryFunc(c.Nomad.Retry.RetryFunc()),
})
if err != nil {
return nil, errors.Wrap(err, "runner")
}
return w, nil
}
Loading

0 comments on commit 3c60253

Please sign in to comment.