Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loading MFA configs upont restart #15261

Merged
merged 4 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions builtin/logical/transit/path_random.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package transit

import (
"context"

"github.com/hashicorp/vault/helper/random"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
Expand Down
3 changes: 3 additions & 0 deletions changelog/15261.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
auth: load login MFA configuration upon restart
```
43 changes: 36 additions & 7 deletions vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/identity/mfa"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/osutil"
Expand Down Expand Up @@ -2139,7 +2140,6 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
if err := c.setupQuotas(ctx, false); err != nil {
return err
}
c.setupCachedMFAResponseAuth()

if err := c.setupHeaderHMACKey(ctx, false); err != nil {
return err
Expand All @@ -2161,9 +2161,14 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
if err := c.loadIdentityStoreArtifacts(ctx); err != nil {
return err
}
if err := loadMFAConfigs(ctx, c); err != nil {
if err := loadPolicyMFAConfigs(ctx, c); err != nil {
return err
}
c.setupCachedMFAResponseAuth()
if err := c.loadLoginMFAConfigs(ctx); err != nil {
return err
}

if err := c.setupAuditedHeadersConfig(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -2325,10 +2330,6 @@ func (c *Core) preSeal() error {
result = multierror.Append(result, fmt.Errorf("error stopping expiration: %w", err))
}
c.stopActivityLog()
// Clear any cached auth response
c.mfaResponseAuthQueueLock.Lock()
c.mfaResponseAuthQueue = nil
c.mfaResponseAuthQueueLock.Unlock()

if err := c.teardownCredentials(context.Background()); err != nil {
result = multierror.Append(result, fmt.Errorf("error tearing down credentials: %w", err))
Expand Down Expand Up @@ -2356,10 +2357,13 @@ func (c *Core) preSeal() error {
seal.StopHealthCheck()
}

c.loginMFABackend.usedCodes = nil
if c.systemBackend != nil && c.systemBackend.mfaBackend != nil {
c.systemBackend.mfaBackend.usedCodes = nil
}
if err := c.teardownLoginMFA(); err != nil {
result = multierror.Append(result, fmt.Errorf("error tearing down login MFA, error: %w", err))
}

preSealPhysical(c)

c.logger.Info("pre-seal teardown complete")
Expand Down Expand Up @@ -3073,6 +3077,31 @@ type LicenseState struct {
Terminated bool
}

func (c *Core) loadLoginMFAConfigs(ctx context.Context) error {
eConfigs := make([]*mfa.MFAEnforcementConfig, 0)
allNamespaces := c.collectNamespaces()
for _, ns := range allNamespaces {
err := c.loginMFABackend.loadMFAMethodConfigs(ctx, ns)
if err != nil {
return fmt.Errorf("error loading MFA method Config, namespaceid %s, error: %w", ns.ID, err)
}

loadedConfigs, err := c.loginMFABackend.loadMFAEnforcementConfigs(ctx, ns)
if err != nil {
return fmt.Errorf("error loading MFA enforcement Config, namespaceid %s, error: %w", ns.ID, err)
}

eConfigs = append(eConfigs, loadedConfigs...)
}

for _, conf := range eConfigs {
if err := c.loginMFABackend.loginMFAMethodExistenceCheck(conf); err != nil {
c.loginMFABackend.mfaLogger.Error("failed to find all MFA methods that exist in MFA enforcement configs", "configID", conf.ID, "namespcaeID", conf.NamespaceID, "error", err.Error())
hghaf099 marked this conversation as resolved.
Show resolved Hide resolved
}
}
return nil
}

type MFACachedAuthResponse struct {
CachedAuth *logical.Auth
RequestPath string
Expand Down
2 changes: 1 addition & 1 deletion vault/core_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func postUnsealPhysical(c *Core) error {
return nil
}

func loadMFAConfigs(context.Context, *Core) error { return nil }
func loadPolicyMFAConfigs(context.Context, *Core) error { return nil }

func shouldStartClusterListener(*Core) bool { return true }

Expand Down
192 changes: 161 additions & 31 deletions vault/login_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ func NewLoginMFABackend(core *Core, logger hclog.Logger) *LoginMFABackend {
}

func NewMFABackend(core *Core, logger hclog.Logger, prefix string, schemaFuncs []func() *memdb.TableSchema) *MFABackend {
db, _ := SetupMFAMemDB(schemaFuncs)
return &MFABackend{
Core: core,
mfaLock: &sync.RWMutex{},
db: db,
mfaLogger: logger.Named("mfa"),
namespacer: core,
methodTable: prefix,
}
}

func SetupMFAMemDB(schemaFuncs []func() *memdb.TableSchema) (*memdb.MemDB, error) {
mfaSchemas := &memdb.DBSchema{
Tables: make(map[string]*memdb.TableSchema),
}
Expand All @@ -134,15 +146,24 @@ func NewMFABackend(core *Core, logger hclog.Logger, prefix string, schemaFuncs [
mfaSchemas.Tables[schema.Name] = schema
}

db, _ := memdb.NewMemDB(mfaSchemas)
return &MFABackend{
Core: core,
mfaLock: &sync.RWMutex{},
db: db,
mfaLogger: logger.Named("mfa"),
namespacer: core,
methodTable: prefix,
db, err := memdb.NewMemDB(mfaSchemas)
if err != nil {
return nil, err
}
return db, nil
}

func (b *LoginMFABackend) ResetLoginMFAMemDB() error {
var err error

db, err := SetupMFAMemDB(loginMFASchemaFuncs())
if err != nil {
return err
}

b.db = db

return nil
}

func (i *IdentityStore) handleMFAMethodListTOTP(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand Down Expand Up @@ -474,6 +495,103 @@ func (i *IdentityStore) handleLoginMFAAdminDestroyUpdate(ctx context.Context, re
return nil, nil
}

// loadMFAMethodConfigs loads MFA method configs for login MFA
func (b *LoginMFABackend) loadMFAMethodConfigs(ctx context.Context, ns *namespace.Namespace) error {
b.mfaLogger.Trace("loading login MFA configurations")
barrierView, err := b.Core.barrierViewForNamespace(ns.ID)
if err != nil {
return fmt.Errorf("error getting namespace view, namespaceid %s, error %w", ns.ID, err)
}
existing, err := barrierView.List(ctx, loginMFAConfigPrefix)
if err != nil {
return fmt.Errorf("failed to list MFA configurations for namespace path %s and prefix %s: %w", ns.Path, loginMFAConfigPrefix, err)
}
b.mfaLogger.Trace("methods collected", "num_existing", len(existing))

for _, key := range existing {
b.mfaLogger.Trace("loading method", "method", key)

// Read the config from storage
mConfig, err := b.getMFAConfig(ctx, loginMFAConfigPrefix+key, barrierView)
if err != nil {
return err
}

if mConfig == nil {
b.mfaLogger.Trace("failed to find the config related to a method", "namespace", ns.Path, "prefix", loginMFAConfigPrefix, "method", key)
continue
}

// Load the config in MemDB
err = b.MemDBUpsertMFAConfig(ctx, mConfig)
if err != nil {
return fmt.Errorf("failed to load configuration ID %s prefix %s in MemDB: %w", mConfig.ID, loginMFAConfigPrefix, err)
}
}

b.mfaLogger.Trace("configurations restored", "namespace", ns.Path, "prefix", loginMFAConfigPrefix)

return nil
}

// loadMFAEnforcementConfigs loads MFA method configs for login MFA
func (b *LoginMFABackend) loadMFAEnforcementConfigs(ctx context.Context, ns *namespace.Namespace) ([]*mfa.MFAEnforcementConfig, error) {
b.mfaLogger.Trace("loading login MFA enforcement configurations")
barrierView, err := b.Core.barrierViewForNamespace(ns.ID)
if err != nil {
return nil, fmt.Errorf("error getting namespace view, namespaceid %s, error %w", ns.ID, err)
}
existing, err := barrierView.List(ctx, mfaLoginEnforcementPrefix)
if err != nil {
return nil, fmt.Errorf("failed to list MFA enforcement configurations for namespace %s with prefix %s: %w", ns.Path, mfaLoginEnforcementPrefix, err)
}
b.mfaLogger.Trace("enforcements configs collected", "num_existing", len(existing))

eConfigs := make([]*mfa.MFAEnforcementConfig, 0)
for _, key := range existing {
b.mfaLogger.Trace("loading enforcement", "config", key)

// Read the config from storage
mConfig, err := b.getMFALoginEnforcementConfig(ctx, mfaLoginEnforcementPrefix+key, barrierView)
if err != nil {
return nil, err
}

if mConfig == nil {
b.mfaLogger.Trace("failed to find an enforcement config", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix, "config", key)
continue
}

// Load the config in MemDB
err = b.MemDBUpsertMFALoginEnforcementConfig(ctx, mConfig)
if err != nil {
return nil, fmt.Errorf("failed to load enforcement configuration ID %s with prefix %s in MemDB: %w", mConfig.ID, mfaLoginEnforcementPrefix, err)
}

eConfigs = append(eConfigs, mConfig)
}

b.mfaLogger.Trace("enforcement configurations restored", "namespace", ns.Path, "prefix", mfaLoginEnforcementPrefix)

return eConfigs, nil
}

func (b *LoginMFABackend) loginMFAMethodExistenceCheck(eConfig *mfa.MFAEnforcementConfig) error {
var aggErr *multierror.Error
for _, confID := range eConfig.MFAMethodIDs {
config, memErr := b.MemDBMFAConfigByID(confID)
if memErr != nil {
aggErr = multierror.Append(aggErr, memErr)
return aggErr.ErrorOrNil()
}
if config == nil {
aggErr = multierror.Append(aggErr, fmt.Errorf("found an MFA method ID in enforcement config, but failed to find the MFA method config method ID %s", confID))
}
}

return aggErr.ErrorOrNil()
}

func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logical.Request, d *framework.FieldData) (retResp *logical.Response, retErr error) {
// mfaReqID is the ID of the login request
mfaReqID := d.Get("mfa_request_id").(string)
Expand Down Expand Up @@ -551,6 +669,22 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic
return resp, nil
}

func (c *Core) teardownLoginMFA() error {
if !c.IsDRSecondary() {
// Clear any cached auth response
c.mfaResponseAuthQueueLock.Lock()
c.mfaResponseAuthQueue = nil
c.mfaResponseAuthQueueLock.Unlock()

c.loginMFABackend.usedCodes = nil

if err := c.loginMFABackend.ResetLoginMFAMemDB(); err != nil {
return err
}
}
return nil
}

// LoginMFACreateToken creates a token after the login MFA is validated.
// It also applies the lease quotas on the original login request path.
func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth) (*logical.Response, error) {
Expand Down Expand Up @@ -2320,7 +2454,6 @@ func (b *LoginMFABackend) deleteMFALoginEnforcementConfigByNameAndNamespace(ctx
}

entryIndex := mfaLoginEnforcementPrefix + eConfig.ID

barrierView, err := b.Core.barrierViewForNamespace(eConfig.NamespaceID)
if err != nil {
return err
Expand Down Expand Up @@ -2530,6 +2663,25 @@ func (b *MFABackend) getMFAConfig(ctx context.Context, path string, barrierView
return &mConfig, nil
}

func (b *LoginMFABackend) getMFALoginEnforcementConfig(ctx context.Context, path string, barrierView *BarrierView) (*mfa.MFAEnforcementConfig, error) {
entry, err := barrierView.Get(ctx, path)
if err != nil {
return nil, err
}

if entry == nil {
return nil, nil
}

var mConfig mfa.MFAEnforcementConfig
err = proto.Unmarshal(entry.Value, &mConfig)
if err != nil {
return nil, err
}

return &mConfig, nil
}

func (b *LoginMFABackend) putMFALoginEnforcementConfig(ctx context.Context, eConfig *mfa.MFAEnforcementConfig) error {
entryIndex := mfaLoginEnforcementPrefix + eConfig.ID
marshaledEntry, err := proto.Marshal(eConfig)
Expand All @@ -2548,28 +2700,6 @@ func (b *LoginMFABackend) putMFALoginEnforcementConfig(ctx context.Context, eCon
})
}

func (b *LoginMFABackend) getMFALoginEnforcementConfig(ctx context.Context, key, namespaceId string) (*mfa.MFAEnforcementConfig, error) {
barrierView, err := b.Core.barrierViewForNamespace(namespaceId)
if err != nil {
return nil, err
}
entry, err := barrierView.Get(ctx, mfaLoginEnforcementPrefix+key)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}

var eConfig mfa.MFAEnforcementConfig
err = proto.Unmarshal(entry.Value, &eConfig)
if err != nil {
return nil, err
}

return &eConfig, nil
}

var mfaHelp = map[string][2]string{
"methods-list": {
"Lists all the available MFA methods by their name.",
Expand Down