Skip to content

Commit

Permalink
CBG-4025 Invalidate user roles after resync (#6942)
Browse files Browse the repository at this point in the history
* CBG-4025 Invalidate user roles after resync

As part of this change, modified access invalidation functions to perform invalidation in a single update to the principal document, instead of once per collection (and one additional time for roles).  To support this:
- Switched existing invalidation functions (invalRoleChannels, invalUserChannels, etc) to have *DatabaseContext receiver instead of *DatabaseCollection, and added a ScopeAndCollectionNames parameter to specify the set of collections that should have access invalidated.
- For ease of use, maintained the existing invalidation functions on *DatabaseCollection - they now just call in to the *DatabaseContext functions with their single collection
- Added a new invalUserRolesAndChannels to invalidate a user’s roles and channels in a single user doc update.  Only currently used by resync

Query based resync still processes a single collection’s updates at a time - it’s structured a bit differently and didn’t seem to be worth refactoring at this point.  It has been updated to properly invalidate user roles.

The new test TestResyncInvalidatePrincipals covers the fix - have verified it with SG_TEST_USE_DEFAULT_COLLECTION=true/false.  Also made a test utility change to remove the password parameter from GetRolePayload since roles don’t have passwords.

* Lint fixes
  • Loading branch information
adamcfraser authored Jul 9, 2024
1 parent 61ffd6a commit 1409e08
Show file tree
Hide file tree
Showing 16 changed files with 265 additions and 76 deletions.
93 changes: 78 additions & 15 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,11 @@ func (auth *Authenticator) UpdateSequenceNumber(p Principal, seq uint64) error {
}

func (auth *Authenticator) InvalidateDefaultChannels(name string, isUser bool, invalSeq uint64) error {
return auth.InvalidateChannels(name, isUser, base.DefaultScope, base.DefaultCollection, invalSeq)
return auth.InvalidateChannels(name, isUser, base.ScopeAndCollectionNames{base.DefaultScopeAndCollectionName()}, invalSeq)
}

// Invalidates the channel list of a user/role by setting the ChannelInvalSeq to a non-zero value
func (auth *Authenticator) InvalidateChannels(name string, isUser bool, scope string, collection string, invalSeq uint64) error {
func (auth *Authenticator) InvalidateChannels(name string, isUser bool, collections base.ScopeAndCollectionNames, invalSeq uint64) error {
var princ Principal
var docID string

Expand All @@ -465,17 +465,23 @@ func (auth *Authenticator) InvalidateChannels(name string, isUser bool, scope st

base.InfofCtx(auth.LogCtx, base.KeyAccess, "Invalidate access of %q", base.UD(name))

subdocPath := "channel_inval_seq"
if scope != base.DefaultScope || collection != base.DefaultCollection {
subdocPath = "collection_access." + scope + "." + collection + "." + subdocPath
}
// Attempt to use subdoc if we're only modifying a single collection
if len(collections) == 1 {
scope := collections[0].ScopeName()
collection := collections[0].CollectionName()

if subdocStore, ok := base.AsSubdocStore(auth.datastore); ok {
err := subdocStore.SubdocInsert(auth.LogCtx, docID, subdocPath, 0, invalSeq)
if err != nil && err != base.ErrAlreadyExists && err != base.ErrPathExists && err != base.ErrPathNotFound && !base.IsDocNotFoundError(err) {
return err
subdocPath := "channel_inval_seq"
if scope != base.DefaultScope || collection != base.DefaultCollection {
subdocPath = "collection_access." + scope + "." + collection + "." + subdocPath
}

if subdocStore, ok := base.AsSubdocStore(auth.datastore); ok {
err := subdocStore.SubdocInsert(auth.LogCtx, docID, subdocPath, 0, invalSeq)
if err != nil && err != base.ErrAlreadyExists && err != base.ErrPathExists && err != base.ErrPathNotFound && !base.IsDocNotFoundError(err) {
return err
}
return nil
}
return nil
}

_, err := auth.datastore.Update(docID, 0, func(current []byte) (updated []byte, expiry *uint32, delete bool, err error) {
Expand All @@ -489,14 +495,21 @@ func (auth *Authenticator) InvalidateChannels(name string, isUser bool, scope st
return nil, nil, false, err
}

if princ.CollectionChannels(scope, collection) == nil {
return nil, nil, false, base.ErrUpdateCancel
changed := false
for _, collectionName := range collections {
scope := collectionName.ScopeName()
collection := collectionName.CollectionName()
if princ.CollectionChannels(scope, collection) != nil {
princ.setCollectionChannelInvalSeq(scope, collection, invalSeq)
changed = true
}
}

princ.setCollectionChannelInvalSeq(scope, collection, invalSeq)
if !changed {
return nil, nil, false, base.ErrUpdateCancel
}

updated, err = base.JSONMarshal(princ)

return updated, nil, false, err
})

Expand Down Expand Up @@ -550,6 +563,56 @@ func (auth *Authenticator) InvalidateRoles(username string, invalSeq uint64) err
return err
}

// Invalidates the computed roles and channels of a user by setting the ChannelInvalSeq to a non-zero value for all specified collections.
func (auth *Authenticator) InvalidateRolesAndChannels(username string, collections base.ScopeAndCollectionNames, invalSeq uint64) error {

docID := auth.DocIDForUser(username)
base.InfofCtx(auth.LogCtx, base.KeyAccess, "Invalidate computed role and channel access of %q for collections %v", base.UD(username), collections)

_, err := auth.datastore.Update(docID, 0, func(current []byte) (updated []byte, expiry *uint32, delete bool, err error) {
// If user/role doesn't exist cancel update
if current == nil {
return nil, nil, false, base.ErrUpdateCancel
}

var user userImpl
err = base.JSONUnmarshal(current, &user)
if err != nil {
return nil, nil, false, base.ErrUpdateCancel
}

changed := false
// Invalidate channel access per collection
for _, collection := range collections {
scope := collection.ScopeName()
collection := collection.CollectionName()
if user.CollectionChannels(scope, collection) != nil {
user.setCollectionChannelInvalSeq(scope, collection, invalSeq)
changed = true
}
}

// Invalidate role access
if user.RoleNames() != nil {
user.SetRoleInvalSeq(invalSeq)
changed = true
}

if !changed {
return nil, nil, false, base.ErrUpdateCancel
}

updated, err = base.JSONMarshal(&user)
return updated, nil, false, err
})

if err == base.ErrUpdateCancel {
return nil
}

return err
}

// Updates user email and writes user doc
func (auth *Authenticator) UpdateUserEmail(u User, email string) error {

Expand Down
4 changes: 2 additions & 2 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func TestRebuildUserChannelsMultiCollection(t *testing.T) {
err := auth.Save(user)
assert.NoError(t, err)

err = auth.InvalidateChannels("testUser", true, "scope1", "collection1", 2)
err = auth.InvalidateChannels("testUser", true, base.ScopeAndCollectionNames{base.NewScopeAndCollectionName("scope1", "collection1")}, 2)
assert.NoError(t, err)

user2, err := auth.GetUser("testUser")
Expand All @@ -452,7 +452,7 @@ func TestRebuildUserChannelsNamedCollection(t *testing.T) {
err := auth.Save(user)
assert.NoError(t, err)

err = auth.InvalidateChannels("testUser", true, "scope1", "collection1", 2)
err = auth.InvalidateChannels("testUser", true, base.ScopeAndCollectionNames{base.NewScopeAndCollectionName("scope1", "collection1")}, 2)
assert.NoError(t, err)

user2, err := auth.GetUser("testUser")
Expand Down
7 changes: 7 additions & 0 deletions base/collection_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ func DefaultScopeAndCollectionName() ScopeAndCollectionName {
return ScopeAndCollectionName{Scope: DefaultScope, Collection: DefaultCollection}
}

func NewScopeAndCollectionName(scope, collection string) ScopeAndCollectionName {
return ScopeAndCollectionName{
Scope: scope,
Collection: collection,
}
}

type ScopeAndCollectionNames []ScopeAndCollectionName

// ScopeAndCollectionNames returns a dot-separated formatted slice of scope and collection names.
Expand Down
5 changes: 4 additions & 1 deletion db/background_mgr_resync_dcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,12 @@ func (r *ResyncManagerDCP) Run(ctx context.Context, options map[string]interface
if err != nil {
return err
}

collectionNames := make(base.ScopeAndCollectionNames, 0)
for _, databaseCollection := range db.CollectionByID {
databaseCollection.invalidateAllPrincipalsCache(ctx, endSeq)
collectionNames = append(collectionNames, databaseCollection.ScopeAndCollectionName())
}
db.invalidateAllPrincipals(ctx, collectionNames, endSeq)

}

Expand Down
2 changes: 1 addition & 1 deletion db/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -2301,7 +2301,7 @@ func (db *DatabaseCollectionWithUser) MarkPrincipalsChanged(ctx context.Context,
if len(changedRoleUsers) > 0 {
base.InfofCtx(ctx, base.KeyAccess, "Rev %q / %q invalidates roles of %s", base.UD(docid), newRevID, base.UD(changedRoleUsers))
for _, name := range changedRoleUsers {
db.invalUserRoles(ctx, name, invalSeq)
db.dbCtx.invalUserRoles(ctx, name, invalSeq)
// If this is the current in memory db.user, reload to generate updated roles
if db.user != nil && db.user.Name() == name {
base.DebugfCtx(ctx, base.KeyAccess, "Role set for active user has been modified - user %q will be reloaded.", base.UD(db.user.Name()))
Expand Down
60 changes: 31 additions & 29 deletions db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1678,23 +1678,11 @@ func (db *DatabaseCollectionWithUser) UpdateAllDocChannels(ctx context.Context,
base.InfofCtx(ctx, base.KeyAll, "Finished re-running sync function; %d/%d docs changed", docsChanged, docsProcessed)

if docsChanged > 0 {
db.invalidateAllPrincipalsCache(ctx, endSeq)
db.invalidateAllPrincipals(ctx, endSeq)
}
return docsChanged, nil
}

// invalidate channel cache of all users/roles:
func (c *DatabaseCollection) invalidateAllPrincipalsCache(ctx context.Context, endSeq uint64) {
base.InfofCtx(ctx, base.KeyAll, "Invalidating channel caches of users/roles...")
users, roles, _ := c.allPrincipalIDs(ctx)
for _, name := range users {
c.invalUserChannels(ctx, name, endSeq)
}
for _, name := range roles {
c.invalRoleChannels(ctx, name, endSeq)
}
}

func (c *DatabaseCollection) updateAllPrincipalsSequences(ctx context.Context) error {
users, roles, err := c.allPrincipalIDs(ctx)
if err != nil {
Expand Down Expand Up @@ -1892,34 +1880,48 @@ func (db *DatabaseCollectionWithUser) resyncDocument(ctx context.Context, docid,
return updatedHighSeq, unusedSequences, err
}

func (c *DatabaseCollection) invalUserRoles(ctx context.Context, username string, invalSeq uint64) {
authr := c.Authenticator(ctx)
if err := authr.InvalidateRoles(username, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating roles for user %s: %v", base.UD(username), err)
// invalidateAllPrincipals invalidates computed channels and roles for all users/roles, for the specified collections:
func (dbCtx *DatabaseContext) invalidateAllPrincipals(ctx context.Context, collectionNames base.ScopeAndCollectionNames, endSeq uint64) {
base.InfofCtx(ctx, base.KeyAll, "Invalidating channel caches of users/roles...")
users, roles, _ := dbCtx.AllPrincipalIDs(ctx)
for _, name := range users {
dbCtx.invalUserRolesAndChannels(ctx, name, collectionNames, endSeq)
}
for _, name := range roles {
dbCtx.invalRoleChannels(ctx, name, collectionNames, endSeq)
}
}

func (c *DatabaseCollection) invalUserChannels(ctx context.Context, username string, invalSeq uint64) {
authr := c.Authenticator(ctx)
if err := authr.InvalidateChannels(username, true, c.ScopeName, c.Name, invalSeq); err != nil {
// invalUserChannels invalidates a user's computed channels for the specified collections
func (dbCtx *DatabaseContext) invalUserChannels(ctx context.Context, username string, collections base.ScopeAndCollectionNames, invalSeq uint64) {
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateChannels(username, true, collections, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating channels for user %s: %v", base.UD(username), err)
}
}

func (c *DatabaseCollection) invalRoleChannels(ctx context.Context, rolename string, invalSeq uint64) {
authr := c.Authenticator(ctx)
if err := authr.InvalidateChannels(rolename, false, c.ScopeName, c.Name, invalSeq); err != nil {
// invalRoleChannels invalidates a role's computed channels for the specified collections
func (dbCtx *DatabaseContext) invalRoleChannels(ctx context.Context, rolename string, collections base.ScopeAndCollectionNames, invalSeq uint64) {
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateChannels(rolename, false, collections, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating channels for role %s: %v", base.UD(rolename), err)
}
}

func (c *DatabaseCollection) invalUserOrRoleChannels(ctx context.Context, name string, invalSeq uint64) {
// invalUserRoles invalidates a user's computed roles
func (dbCtx *DatabaseContext) invalUserRoles(ctx context.Context, username string, invalSeq uint64) {

principalName, isRole := channels.AccessNameToPrincipalName(name)
if isRole {
c.invalRoleChannels(ctx, principalName, invalSeq)
} else {
c.invalUserChannels(ctx, principalName, invalSeq)
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateRoles(username, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating roles for user %s: %v", base.UD(username), err)
}
}

// invalUserRolesAndChannels invalidates the user's computed roles, and invalidates the computed channels for all specified collections
func (dbCtx *DatabaseContext) invalUserRolesAndChannels(ctx context.Context, username string, collections base.ScopeAndCollectionNames, invalSeq uint64) {
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateRolesAndChannels(username, collections, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating roles for user %s: %v", base.UD(username), err)
}
}

Expand Down
31 changes: 31 additions & 0 deletions db/database_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,37 @@ func (c *DatabaseCollection) UpdateSyncFun(ctx context.Context, syncFun string)
return
}

// DatabaseCollection helper methods for channel and role invalidation - invoke the multi-collection version on
// the databaseContext for a single collection.
// invalUserOrRoleChannels invalidates a user or role's channels for collection c
func (c *DatabaseCollection) invalUserOrRoleChannels(ctx context.Context, name string, invalSeq uint64) {
principalName, isRole := channels.AccessNameToPrincipalName(name)
if isRole {
c.invalRoleChannels(ctx, principalName, invalSeq)
} else {
c.invalUserChannels(ctx, principalName, invalSeq)
}
}

// invalRoleChannels invalidates a user's computed channels for collection c
func (c *DatabaseCollection) invalUserChannels(ctx context.Context, username string, invalSeq uint64) {
c.dbCtx.invalUserChannels(ctx, username, base.ScopeAndCollectionNames{c.ScopeAndCollectionName()}, invalSeq)
}

// invalRoleChannels invalidates a role's computed channels for collection c
func (c *DatabaseCollection) invalRoleChannels(ctx context.Context, rolename string, invalSeq uint64) {
c.dbCtx.invalRoleChannels(ctx, rolename, base.ScopeAndCollectionNames{c.ScopeAndCollectionName()}, invalSeq)
}

// invalidateAllPrincipals invalidates computed channels and roles for collection c, for all users and roles
func (c *DatabaseCollection) invalidateAllPrincipals(ctx context.Context, endSeq uint64) {
c.dbCtx.invalidateAllPrincipals(ctx, base.ScopeAndCollectionNames{c.ScopeAndCollectionName()}, endSeq)
}

func (c *DatabaseCollection) useMou() bool {
return c.dbCtx.UseMou()
}

func (c *DatabaseCollection) ScopeAndCollectionName() base.ScopeAndCollectionName {
return base.NewScopeAndCollectionName(c.ScopeName, c.Name)
}
2 changes: 1 addition & 1 deletion db/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2731,7 +2731,7 @@ func Test_invalidateAllPrincipalsCache(t *testing.T) {
assert.NoError(t, err)
assert.Greater(t, endSeq, uint64(0))

collection.invalidateAllPrincipalsCache(ctx, endSeq)
collection.invalidateAllPrincipals(ctx, endSeq)
err = collection.WaitForPendingChanges(ctx)
assert.NoError(t, err)

Expand Down
Loading

0 comments on commit 1409e08

Please sign in to comment.