Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

identity: group refresh shouldn't lock unless an update is needed #8795

Merged
merged 1 commit into from
Apr 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions helper/storagepacker/storagepacker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"strconv"
"strings"
"time"

"github.com/armon/go-metrics"
"github.com/golang/protobuf/proto"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-hclog"
Expand Down Expand Up @@ -129,6 +131,7 @@ func (s *StoragePacker) DeleteItem(_ context.Context, itemID string) error {
}

func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Logger, itemIDs ...string) error {
defer metrics.MeasureSince([]string{"storage_packer", "delete_items"}, time.Now())
var err error
switch len(itemIDs) {
case 0:
Expand Down Expand Up @@ -254,6 +257,7 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo
}

func (s *StoragePacker) putBucket(ctx context.Context, bucket *Bucket) error {
defer metrics.MeasureSince([]string{"storage_packer", "put_bucket"}, time.Now())
if bucket == nil {
return fmt.Errorf("nil bucket entry")
}
Expand Down Expand Up @@ -293,6 +297,8 @@ func (s *StoragePacker) putBucket(ctx context.Context, bucket *Bucket) error {
// GetItem fetches the storage entry for a given key from its corresponding
// bucket.
func (s *StoragePacker) GetItem(itemID string) (*Item, error) {
defer metrics.MeasureSince([]string{"storage_packer", "get_item"}, time.Now())

if itemID == "" {
return nil, fmt.Errorf("empty item ID")
}
Expand Down Expand Up @@ -320,6 +326,8 @@ func (s *StoragePacker) GetItem(itemID string) (*Item, error) {

// PutItem stores the given item in its respective bucket
func (s *StoragePacker) PutItem(_ context.Context, item *Item) error {
defer metrics.MeasureSince([]string{"storage_packer", "put_item"}, time.Now())

if item == nil {
return fmt.Errorf("nil item")
}
Expand Down
4 changes: 4 additions & 0 deletions vault/identity_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"fmt"
"strings"
"time"

metrics "github.com/armon/go-metrics"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
Expand Down Expand Up @@ -478,6 +480,8 @@ func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor,
// CreateOrFetchEntity creates a new entity. This is used by core to
// associate each login attempt by an alias to a unified entity in Vault.
func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.Alias) (*identity.Entity, error) {
defer metrics.MeasureSince([]string{"identity", "create_or_fetch_entity"}, time.Now())

var entity *identity.Entity
var err error
var update bool
Expand Down
163 changes: 105 additions & 58 deletions vault/identity_store_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"strings"
"sync"
"time"

metrics "github.com/armon/go-metrics"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/errwrap"
memdb "github.com/hashicorp/go-memdb"
Expand Down Expand Up @@ -330,6 +332,7 @@ func (i *IdentityStore) loadEntities(ctx context.Context) error {
// updated, in which case, callers should send in both entity and
// previousEntity.
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now())
var err error

if txn == nil {
Expand Down Expand Up @@ -485,6 +488,7 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
// updated, in which case, callers should send in both entity and
// previousEntity.
func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
defer metrics.MeasureSince([]string{"identity", "upsert_entity"}, time.Now())

// Create a MemDB transaction to update both alias and entity
txn := i.db.Txn(true)
Expand Down Expand Up @@ -1365,6 +1369,8 @@ func (i *IdentityStore) MemDBGroupByName(ctx context.Context, groupName string,
}

func (i *IdentityStore) UpsertGroup(ctx context.Context, group *identity.Group, persist bool) error {
defer metrics.MeasureSince([]string{"identity", "upsert_group"}, time.Now())

txn := i.db.Txn(true)
defer txn.Abort()

Expand All @@ -1379,6 +1385,8 @@ func (i *IdentityStore) UpsertGroup(ctx context.Context, group *identity.Group,
}

func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, group *identity.Group, persist bool) error {
defer metrics.MeasureSince([]string{"identity", "upsert_group_txn"}, time.Now())

var err error

if txn == nil {
Expand Down Expand Up @@ -1879,90 +1887,129 @@ func (i *IdentityStore) MemDBGroupByAliasID(aliasID string, clone bool) (*identi
}

func (i *IdentityStore) refreshExternalGroupMembershipsByEntityID(ctx context.Context, entityID string, groupAliases []*logical.Alias) ([]*logical.Alias, error) {
i.logger.Debug("refreshing external group memberships", "entity_id", entityID, "group_aliases", groupAliases)
defer metrics.MeasureSince([]string{"identity", "refresh_external_groups"}, time.Now())

if entityID == "" {
return nil, fmt.Errorf("empty entity ID")
}

i.groupLock.Lock()
defer i.groupLock.Unlock()

txn := i.db.Txn(true)
defer txn.Abort()
refreshFunc := func(dryRun bool) (bool, []*logical.Alias, error) {

oldGroups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, true, true)
if err != nil {
return nil, err
}
if !dryRun {
i.groupLock.Lock()
defer i.groupLock.Unlock()
}

mountAccessor := ""
if len(groupAliases) != 0 {
mountAccessor = groupAliases[0].MountAccessor
}
txn := i.db.Txn(!dryRun)
defer txn.Abort()

var newGroups []*identity.Group
var validAliases []*logical.Alias
for _, alias := range groupAliases {
aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, true, true)
oldGroups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, true, true)
if err != nil {
return nil, err
return false, nil, err
}
if aliasByFactors == nil {
continue
}
mappingGroup, err := i.MemDBGroupByAliasID(aliasByFactors.ID, true)
if err != nil {
return nil, err

mountAccessor := ""
if len(groupAliases) != 0 {
mountAccessor = groupAliases[0].MountAccessor
}
if mappingGroup == nil {
return nil, fmt.Errorf("group unavailable for a valid alias ID %q", aliasByFactors.ID)

var newGroups []*identity.Group
var validAliases []*logical.Alias
for _, alias := range groupAliases {
aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, true, true)
if err != nil {
return false, nil, err
}
if aliasByFactors == nil {
continue
}
mappingGroup, err := i.MemDBGroupByAliasIDInTxn(txn, aliasByFactors.ID, true)
if err != nil {
return false, nil, err
}
if mappingGroup == nil {
return false, nil, fmt.Errorf("group unavailable for a valid alias ID %q", aliasByFactors.ID)
}

newGroups = append(newGroups, mappingGroup)
validAliases = append(validAliases, alias)
}

newGroups = append(newGroups, mappingGroup)
validAliases = append(validAliases, alias)
}
diff := diffGroups(oldGroups, newGroups)

diff := diffGroups(oldGroups, newGroups)
// Add the entity ID to all the new groups
for _, group := range diff.New {
if group.Type != groupTypeExternal {
continue
}

// Add the entity ID to all the new groups
for _, group := range diff.New {
if group.Type != groupTypeExternal {
continue
}
// We need to update a group, if we are in a dry run we should
// report back that a change needs to take place.
if dryRun {
return true, nil, nil
}

i.logger.Debug("adding member entity ID to external group", "member_entity_id", entityID, "group_id", group.ID)
i.logger.Debug("adding member entity ID to external group", "member_entity_id", entityID, "group_id", group.ID)

group.MemberEntityIDs = append(group.MemberEntityIDs, entityID)
group.MemberEntityIDs = append(group.MemberEntityIDs, entityID)

err = i.UpsertGroupInTxn(ctx, txn, group, true)
if err != nil {
return nil, err
err = i.UpsertGroupInTxn(ctx, txn, group, true)
if err != nil {
return false, nil, err
}
}
}

// Remove the entity ID from all the deleted groups
for _, group := range diff.Deleted {
if group.Type != groupTypeExternal {
continue
}
// Remove the entity ID from all the deleted groups
for _, group := range diff.Deleted {
if group.Type != groupTypeExternal {
continue
}

// If the external group is from a different mount, don't remove the
// entity ID from it.
if mountAccessor != "" && group.Alias != nil && group.Alias.MountAccessor != mountAccessor {
continue
}
// If the external group is from a different mount, don't remove the
// entity ID from it.
if mountAccessor != "" && group.Alias != nil && group.Alias.MountAccessor != mountAccessor {
continue
}

i.logger.Debug("removing member entity ID from external group", "member_entity_id", entityID, "group_id", group.ID)
// We need to update a group, if we are in a dry run we should
// report back that a change needs to take place.
if dryRun {
return true, nil, nil
}

group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, entityID)
i.logger.Debug("removing member entity ID from external group", "member_entity_id", entityID, "group_id", group.ID)

err = i.UpsertGroupInTxn(ctx, txn, group, true)
if err != nil {
return nil, err
group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, entityID)

err = i.UpsertGroupInTxn(ctx, txn, group, true)
if err != nil {
return false, nil, err
}
}

txn.Commit()
return false, validAliases, nil
}

txn.Commit()
// dryRun
needsUpdate, validAliases, err := refreshFunc(true)
if err != nil {
return nil, err
}

if needsUpdate || len(groupAliases) > 0 {
i.logger.Debug("refreshing external group memberships", "entity_id", entityID, "group_aliases", groupAliases)
}

if !needsUpdate {
return validAliases, nil
}

// Run the update
_, validAliases, err = refreshFunc(false)
if err != nil {
return nil, err
}

return validAliases, nil
}
Expand Down