Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Access Token callback Vault token reload #285

Merged
merged 8 commits into from
Oct 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions bootstrap/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (cp *Processor) Process(

configProviderUrl := cp.flags.ConfigProviderUrl()

// Create new ProviderInfo and initialize it from command-line flag or Variables variables
// Create new ProviderInfo and initialize it from command-line flag or Variables
configProviderInfo, err := NewProviderInfo(cp.envVars, configProviderUrl)
if err != nil {
return err
Expand All @@ -142,23 +142,29 @@ func (cp *Processor) Process(
switch configProviderInfo.UseProvider() {
case true:
var accessToken string
var getAccessToken types.GetAccessTokenCallback

// secretProvider will be nil if not configured to be used. In that case, no access token required.
if secretProvider != nil {
accessToken, err = secretProvider.GetAccessToken(configProviderInfo.serviceConfig.Type, serviceKey)
if err != nil {
return fmt.Errorf(
"failed to get Configuration Provider (%s) access token: %s",
configProviderInfo.serviceConfig.Type,
err.Error())
// Define the callback function to retrieve the Access Token
getAccessToken = func() (string, error) {
accessToken, err = secretProvider.GetAccessToken(configProviderInfo.serviceConfig.Type, serviceKey)
if err != nil {
return "", fmt.Errorf(
"failed to get Configuration Provider (%s) access token: %s",
configProviderInfo.serviceConfig.Type,
err.Error())
}

cp.lc.Infof("Using Configuration Provider access token of length %d", len(accessToken))
return accessToken, nil
}

cp.lc.Infof("Using Config Provider access token of length %d", len(accessToken))
} else {
cp.lc.Info("Not configured to use Config Provider access token")
}

configClient, err := cp.createProviderClient(serviceKey, configStem, accessToken, configProviderInfo.ServiceConfig())
configClient, err := cp.createProviderClient(serviceKey, configStem, getAccessToken, configProviderInfo.ServiceConfig())
if err != nil {
return fmt.Errorf("failed to create Configuration Provider client: %s", err.Error())
}
Expand Down Expand Up @@ -305,7 +311,8 @@ func (cp *Processor) ListenForCustomConfigChanges(
for {
select {
case <-cp.ctx.Done():
cp.lc.Infof("Exiting waiting for custom configuration changes")
configClient.StopWatching()
cp.lc.Infof("Watching for '%s' configuration changes has stopped", sectionName)
return

case ex := <-errorStream:
Expand All @@ -325,11 +332,18 @@ func (cp *Processor) ListenForCustomConfigChanges(
func (cp *Processor) createProviderClient(
serviceKey string,
configStem string,
accessTokenFile string,
getAccessToken types.GetAccessTokenCallback,
providerConfig types.ServiceConfig) (configuration.Client, error) {

var err error
providerConfig.BasePath = filepath.Join(configStem, ConfigVersion, serviceKey)
providerConfig.AccessToken = accessTokenFile
if getAccessToken != nil {
providerConfig.AccessToken, err = getAccessToken()
if err != nil {
return nil, err
}
providerConfig.GetAccessToken = getAccessToken
}

cp.lc.Info(fmt.Sprintf(
"Using Configuration provider (%s) from: %s with base path of %s",
Expand Down Expand Up @@ -431,6 +445,8 @@ func (cp *Processor) listenForChanges(serviceConfig interfaces.Configuration, co
for {
select {
case <-cp.ctx.Done():
configClient.StopWatching()
lc.Infof("Watching for '%s' configuration changes has stopped", writableKey)
return

case ex := <-errorStream:
Expand Down
4 changes: 3 additions & 1 deletion bootstrap/interfaces/secret.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package interfaces

import "time"
import (
"time"
)

// SecretProvider defines the contract for secret provider implementations that
// allow secrets to be retrieved/stored from/to a services Secret Store.
Expand Down
28 changes: 20 additions & 8 deletions bootstrap/registration/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ import (
"errors"
"fmt"

"github.com/edgexfoundry/go-mod-bootstrap/v2/config"
"github.com/edgexfoundry/go-mod-core-contracts/v2/clients/logger"
"github.com/edgexfoundry/go-mod-core-contracts/v2/common"

registryTypes "github.com/edgexfoundry/go-mod-registry/v2/pkg/types"
"github.com/edgexfoundry/go-mod-registry/v2/registry"

"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v2/bootstrap/startup"
"github.com/edgexfoundry/go-mod-bootstrap/v2/config"
"github.com/edgexfoundry/go-mod-bootstrap/v2/di"
)

Expand All @@ -42,19 +43,29 @@ func createRegistryClient(

var err error
var accessToken string
var getAccessToken registryTypes.GetAccessTokenCallback

secretProvider := container.SecretProviderFrom(dic.Get)
// secretProvider will be nil if not configured to be used. In that case, no access token required.
if secretProvider != nil {
accessToken, err = secretProvider.GetAccessToken(bootstrapConfig.Registry.Type, serviceKey)
if err != nil {
return nil, fmt.Errorf(
"failed to get Registry (%s) access token: %s",
bootstrapConfig.Registry.Type,
err.Error())
// Define the callback function to retrieve the Access Token
getAccessToken = func() (string, error) {
accessToken, err = secretProvider.GetAccessToken(bootstrapConfig.Registry.Type, serviceKey)
if err != nil {
return "", fmt.Errorf(
"failed to get Registry (%s) access token: %s",
bootstrapConfig.Registry.Type,
err.Error())
}

lc.Infof("Using Registry access token of length %d", len(accessToken))
return accessToken, nil
}

lc.Infof("Using Registry access token of length %d", len(accessToken))
accessToken, err = getAccessToken()
if err != nil {
return nil, err
}
}

registryConfig := registryTypes.Config{
Expand All @@ -68,6 +79,7 @@ func createRegistryClient(
ServiceProtocol: config.DefaultHttpProtocol,
CheckInterval: bootstrapConfig.Service.HealthCheckInterval,
CheckRoute: common.ApiPingRoute,
GetAccessToken: getAccessToken,
}

lc.Info(fmt.Sprintf("Using Registry (%s) from %s", registryConfig.Type, registryConfig.GetRegistryUrl()))
Expand Down
5 changes: 4 additions & 1 deletion bootstrap/secret/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func NewSecretProvider(

secretConfig, err = getSecretConfig(secretStoreConfig, tokenLoader)
if err == nil {
secureProvider := NewSecureProvider(configuration, lc, tokenLoader)
secureProvider := NewSecureProvider(ctx, configuration, lc, tokenLoader)
var secretClient secrets.SecretClient

lc.Info("Attempting to create secret client")
Expand All @@ -82,6 +82,9 @@ func NewSecretProvider(
break
}

provider = secureProvider
lc.Info("Created SecretClient")

err = secureProvider.LoadServiceSecrets(secretStoreConfig)
if err != nil {
return nil, err
Expand Down
64 changes: 61 additions & 3 deletions bootstrap/secret/secure.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package secret

import (
"context"
"errors"
"fmt"
"os"
Expand All @@ -35,7 +36,11 @@ import (
"github.com/edgexfoundry/go-mod-secrets/v2/secrets"
)

const TokenTypeConsul = "consul"
const (
TokenTypeConsul = "consul"
AccessTokenAuthError = "HTTP response with status code 403"
SecretsAuthError = "Received a '403' response"
)

// SecureProvider implements the SecretProvider interface
type SecureProvider struct {
Expand All @@ -46,17 +51,19 @@ type SecureProvider struct {
secretsCache map[string]map[string]string // secret's path, key, value
cacheMutex *sync.RWMutex
lastUpdated time.Time
ctx context.Context
}

// NewSecureProvider creates & initializes Provider instance for secure secrets.
func NewSecureProvider(config interfaces.Configuration, lc logger.LoggingClient, loader authtokenloader.AuthTokenLoader) *SecureProvider {
func NewSecureProvider(ctx context.Context, config interfaces.Configuration, lc logger.LoggingClient, loader authtokenloader.AuthTokenLoader) *SecureProvider {
provider := &SecureProvider{
configuration: config,
lc: lc,
loader: loader,
secretsCache: make(map[string]map[string]string),
cacheMutex: &sync.RWMutex{},
lastUpdated: time.Now(),
ctx: ctx,
}
return provider
}
Expand All @@ -80,6 +87,13 @@ func (p *SecureProvider) GetSecret(path string, keys ...string) (map[string]stri
}

secureSecrets, err := p.secretClient.GetSecrets(path, keys...)

retry, err := p.reloadTokenOnAuthError(err)
if retry {
// Retry with potential new token
secureSecrets, err = p.secretClient.GetSecrets(path, keys...)
}

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -142,6 +156,13 @@ func (p *SecureProvider) StoreSecret(path string, secrets map[string]string) err
}

err := p.secretClient.StoreSecrets(path, secrets)

retry, err := p.reloadTokenOnAuthError(err)
if retry {
// Retry with potential new token
err = p.secretClient.StoreSecrets(path, secrets)
}

if err != nil {
return err
}
Expand All @@ -156,6 +177,30 @@ func (p *SecureProvider) StoreSecret(path string, secrets map[string]string) err
return nil
}

func (p *SecureProvider) reloadTokenOnAuthError(err error) (bool, error) {
if err == nil {
return false, nil
}

if !strings.Contains(err.Error(), SecretsAuthError) &&
!strings.Contains(err.Error(), AccessTokenAuthError) {
return false, err
}

// Reload token in case new token was created causing the auth error
token, err := p.loader.Load(p.configuration.GetBootstrap().SecretStore.TokenFile)
if err != nil {
return false, err
}

err = p.secretClient.SetAuthToken(p.ctx, token)
if err != nil {
return false, err
}

return true, nil
}

// SecretsUpdated is not need for secure secrets as this is handled when secrets are stored.
func (p *SecureProvider) SecretsUpdated() {
// Do nothing
Expand All @@ -170,7 +215,20 @@ func (p *SecureProvider) SecretsLastUpdated() time.Time {
func (p *SecureProvider) GetAccessToken(tokenType string, serviceKey string) (string, error) {
switch tokenType {
case TokenTypeConsul:
return p.secretClient.GenerateConsulToken(serviceKey)
token, err := p.secretClient.GenerateConsulToken(serviceKey)

retry, err := p.reloadTokenOnAuthError(err)
if retry {
// Retry with potential new token
token, err = p.secretClient.GenerateConsulToken(serviceKey)
}

if err != nil {
return "", err
}

return token, nil

default:
return "", fmt.Errorf("invalid access token type '%s'", tokenType)
}
Expand Down
Loading