Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CBG-4025 Invalidate user roles after resync #6942

Merged
merged 2 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 78 additions & 15 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,11 @@ func (auth *Authenticator) UpdateSequenceNumber(p Principal, seq uint64) error {
}

func (auth *Authenticator) InvalidateDefaultChannels(name string, isUser bool, invalSeq uint64) error {
return auth.InvalidateChannels(name, isUser, base.DefaultScope, base.DefaultCollection, invalSeq)
return auth.InvalidateChannels(name, isUser, base.ScopeAndCollectionNames{base.DefaultScopeAndCollectionName()}, invalSeq)
}

// Invalidates the channel list of a user/role by setting the ChannelInvalSeq to a non-zero value
func (auth *Authenticator) InvalidateChannels(name string, isUser bool, scope string, collection string, invalSeq uint64) error {
func (auth *Authenticator) InvalidateChannels(name string, isUser bool, collections base.ScopeAndCollectionNames, invalSeq uint64) error {
var princ Principal
var docID string

Expand All @@ -465,17 +465,23 @@ func (auth *Authenticator) InvalidateChannels(name string, isUser bool, scope st

base.InfofCtx(auth.LogCtx, base.KeyAccess, "Invalidate access of %q", base.UD(name))

subdocPath := "channel_inval_seq"
if scope != base.DefaultScope || collection != base.DefaultCollection {
subdocPath = "collection_access." + scope + "." + collection + "." + subdocPath
}
// Attempt to use subdoc if we're only modifying a single collection
if len(collections) == 1 {
Comment on lines +468 to +469
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more of an optimization than might be necessary, but can we change the API of SubdocInsert to be map[string]any so we can do multiple subdoc operations at once?

Later we are going to run through the logic of updating the document, though not the write operation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered adding that functionality in this PR, but decided against it due to the increase in scope/risk. Since the more common usage of InvalidateChannels (at write time) is always single collection, I thought it was reasonable to omit.

scope := collections[0].ScopeName()
collection := collections[0].CollectionName()

if subdocStore, ok := base.AsSubdocStore(auth.datastore); ok {
err := subdocStore.SubdocInsert(auth.LogCtx, docID, subdocPath, 0, invalSeq)
if err != nil && err != base.ErrAlreadyExists && err != base.ErrPathExists && err != base.ErrPathNotFound && !base.IsDocNotFoundError(err) {
return err
subdocPath := "channel_inval_seq"
if scope != base.DefaultScope || collection != base.DefaultCollection {
subdocPath = "collection_access." + scope + "." + collection + "." + subdocPath
}

if subdocStore, ok := base.AsSubdocStore(auth.datastore); ok {
err := subdocStore.SubdocInsert(auth.LogCtx, docID, subdocPath, 0, invalSeq)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgbucket.SubDocStore is part of sgbucket.DataStore so we don't need to do this cast.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AsSubdocStore is also doing a check on ds.IsSupported(sgbucket.BucketStoreFeatureSubdocOperations) (which I believe is still relevant for walrus support for backports to 3.1 that support walrus). I think I'd prefer to leave this code as it was in this PR since we expect a backport.

if err != nil && err != base.ErrAlreadyExists && err != base.ErrPathExists && err != base.ErrPathNotFound && !base.IsDocNotFoundError(err) {
return err
}
return nil
}
return nil
}

_, err := auth.datastore.Update(docID, 0, func(current []byte) (updated []byte, expiry *uint32, delete bool, err error) {
Expand All @@ -489,14 +495,21 @@ func (auth *Authenticator) InvalidateChannels(name string, isUser bool, scope st
return nil, nil, false, err
}

if princ.CollectionChannels(scope, collection) == nil {
return nil, nil, false, base.ErrUpdateCancel
changed := false
for _, collectionName := range collections {
scope := collectionName.ScopeName()
collection := collectionName.CollectionName()
if princ.CollectionChannels(scope, collection) != nil {
princ.setCollectionChannelInvalSeq(scope, collection, invalSeq)
changed = true
}
}

princ.setCollectionChannelInvalSeq(scope, collection, invalSeq)
if !changed {
return nil, nil, false, base.ErrUpdateCancel
}

updated, err = base.JSONMarshal(princ)

return updated, nil, false, err
})

Expand Down Expand Up @@ -550,6 +563,56 @@ func (auth *Authenticator) InvalidateRoles(username string, invalSeq uint64) err
return err
}

// Invalidates the computed roles and channels of a user by setting the ChannelInvalSeq to a non-zero value for all specified collections.
func (auth *Authenticator) InvalidateRolesAndChannels(username string, collections base.ScopeAndCollectionNames, invalSeq uint64) error {

docID := auth.DocIDForUser(username)
base.InfofCtx(auth.LogCtx, base.KeyAccess, "Invalidate computed role and channel access of %q for collections %v", base.UD(username), collections)

_, err := auth.datastore.Update(docID, 0, func(current []byte) (updated []byte, expiry *uint32, delete bool, err error) {
// If user/role doesn't exist cancel update
if current == nil {
return nil, nil, false, base.ErrUpdateCancel
}

var user userImpl
err = base.JSONUnmarshal(current, &user)
if err != nil {
return nil, nil, false, base.ErrUpdateCancel
}

changed := false
// Invalidate channel access per collection
for _, collection := range collections {
scope := collection.ScopeName()
collection := collection.CollectionName()
if user.CollectionChannels(scope, collection) != nil {
user.setCollectionChannelInvalSeq(scope, collection, invalSeq)
changed = true
}
}

// Invalidate role access
if user.RoleNames() != nil {
user.SetRoleInvalSeq(invalSeq)
changed = true
}

if !changed {
return nil, nil, false, base.ErrUpdateCancel
}

updated, err = base.JSONMarshal(&user)
return updated, nil, false, err
})

if err == base.ErrUpdateCancel {
return nil
}

return err
}

// Updates user email and writes user doc
func (auth *Authenticator) UpdateUserEmail(u User, email string) error {

Expand Down
4 changes: 2 additions & 2 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func TestRebuildUserChannelsMultiCollection(t *testing.T) {
err := auth.Save(user)
assert.NoError(t, err)

err = auth.InvalidateChannels("testUser", true, "scope1", "collection1", 2)
err = auth.InvalidateChannels("testUser", true, base.ScopeAndCollectionNames{base.NewScopeAndCollectionName("scope1", "collection1")}, 2)
assert.NoError(t, err)

user2, err := auth.GetUser("testUser")
Expand All @@ -452,7 +452,7 @@ func TestRebuildUserChannelsNamedCollection(t *testing.T) {
err := auth.Save(user)
assert.NoError(t, err)

err = auth.InvalidateChannels("testUser", true, "scope1", "collection1", 2)
err = auth.InvalidateChannels("testUser", true, base.ScopeAndCollectionNames{base.NewScopeAndCollectionName("scope1", "collection1")}, 2)
assert.NoError(t, err)

user2, err := auth.GetUser("testUser")
Expand Down
7 changes: 7 additions & 0 deletions base/collection_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ func DefaultScopeAndCollectionName() ScopeAndCollectionName {
return ScopeAndCollectionName{Scope: DefaultScope, Collection: DefaultCollection}
}

func NewScopeAndCollectionName(scope, collection string) ScopeAndCollectionName {
return ScopeAndCollectionName{
Scope: scope,
Collection: collection,
}
}

type ScopeAndCollectionNames []ScopeAndCollectionName

// ScopeAndCollectionNames returns a dot-separated formatted slice of scope and collection names.
Expand Down
5 changes: 4 additions & 1 deletion db/background_mgr_resync_dcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,12 @@ func (r *ResyncManagerDCP) Run(ctx context.Context, options map[string]interface
if err != nil {
return err
}

collectionNames := make(base.ScopeAndCollectionNames, 0)
for _, databaseCollection := range db.CollectionByID {
databaseCollection.invalidateAllPrincipalsCache(ctx, endSeq)
collectionNames = append(collectionNames, databaseCollection.ScopeAndCollectionName())
}
db.invalidateAllPrincipals(ctx, collectionNames, endSeq)

}

Expand Down
2 changes: 1 addition & 1 deletion db/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -2301,7 +2301,7 @@ func (db *DatabaseCollectionWithUser) MarkPrincipalsChanged(ctx context.Context,
if len(changedRoleUsers) > 0 {
base.InfofCtx(ctx, base.KeyAccess, "Rev %q / %q invalidates roles of %s", base.UD(docid), newRevID, base.UD(changedRoleUsers))
for _, name := range changedRoleUsers {
db.invalUserRoles(ctx, name, invalSeq)
db.dbCtx.invalUserRoles(ctx, name, invalSeq)
// If this is the current in memory db.user, reload to generate updated roles
if db.user != nil && db.user.Name() == name {
base.DebugfCtx(ctx, base.KeyAccess, "Role set for active user has been modified - user %q will be reloaded.", base.UD(db.user.Name()))
Expand Down
60 changes: 31 additions & 29 deletions db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1679,23 +1679,11 @@ func (db *DatabaseCollectionWithUser) UpdateAllDocChannels(ctx context.Context,
base.InfofCtx(ctx, base.KeyAll, "Finished re-running sync function; %d/%d docs changed", docsChanged, docsProcessed)

if docsChanged > 0 {
db.invalidateAllPrincipalsCache(ctx, endSeq)
db.invalidateAllPrincipals(ctx, endSeq)
}
return docsChanged, nil
}

// invalidate channel cache of all users/roles:
func (c *DatabaseCollection) invalidateAllPrincipalsCache(ctx context.Context, endSeq uint64) {
base.InfofCtx(ctx, base.KeyAll, "Invalidating channel caches of users/roles...")
users, roles, _ := c.allPrincipalIDs(ctx)
for _, name := range users {
c.invalUserChannels(ctx, name, endSeq)
}
for _, name := range roles {
c.invalRoleChannels(ctx, name, endSeq)
}
}

func (c *DatabaseCollection) updateAllPrincipalsSequences(ctx context.Context) error {
users, roles, err := c.allPrincipalIDs(ctx)
if err != nil {
Expand Down Expand Up @@ -1893,34 +1881,48 @@ func (db *DatabaseCollectionWithUser) resyncDocument(ctx context.Context, docid,
return updatedHighSeq, unusedSequences, err
}

func (c *DatabaseCollection) invalUserRoles(ctx context.Context, username string, invalSeq uint64) {
authr := c.Authenticator(ctx)
if err := authr.InvalidateRoles(username, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating roles for user %s: %v", base.UD(username), err)
// invalidateAllPrincipals invalidates computed channels and roles for all users/roles, for the specified collections:
func (dbCtx *DatabaseContext) invalidateAllPrincipals(ctx context.Context, collectionNames base.ScopeAndCollectionNames, endSeq uint64) {
base.InfofCtx(ctx, base.KeyAll, "Invalidating channel caches of users/roles...")
users, roles, _ := dbCtx.AllPrincipalIDs(ctx)
for _, name := range users {
dbCtx.invalUserRolesAndChannels(ctx, name, collectionNames, endSeq)
}
for _, name := range roles {
dbCtx.invalRoleChannels(ctx, name, collectionNames, endSeq)
}
}

func (c *DatabaseCollection) invalUserChannels(ctx context.Context, username string, invalSeq uint64) {
authr := c.Authenticator(ctx)
if err := authr.InvalidateChannels(username, true, c.ScopeName, c.Name, invalSeq); err != nil {
// invalUserChannels invalidates a user's computed channels for the specified collections
func (dbCtx *DatabaseContext) invalUserChannels(ctx context.Context, username string, collections base.ScopeAndCollectionNames, invalSeq uint64) {
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateChannels(username, true, collections, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating channels for user %s: %v", base.UD(username), err)
}
}

func (c *DatabaseCollection) invalRoleChannels(ctx context.Context, rolename string, invalSeq uint64) {
authr := c.Authenticator(ctx)
if err := authr.InvalidateChannels(rolename, false, c.ScopeName, c.Name, invalSeq); err != nil {
// invalRoleChannels invalidates a role's computed channels for the specified collections
func (dbCtx *DatabaseContext) invalRoleChannels(ctx context.Context, rolename string, collections base.ScopeAndCollectionNames, invalSeq uint64) {
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateChannels(rolename, false, collections, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating channels for role %s: %v", base.UD(rolename), err)
}
}

func (c *DatabaseCollection) invalUserOrRoleChannels(ctx context.Context, name string, invalSeq uint64) {
// invalUserRoles invalidates a user's computed roles
func (dbCtx *DatabaseContext) invalUserRoles(ctx context.Context, username string, invalSeq uint64) {

principalName, isRole := channels.AccessNameToPrincipalName(name)
if isRole {
c.invalRoleChannels(ctx, principalName, invalSeq)
} else {
c.invalUserChannels(ctx, principalName, invalSeq)
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateRoles(username, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating roles for user %s: %v", base.UD(username), err)
}
}

// invalUserRolesAndChannels invalidates the user's computed roles, and invalidates the computed channels for all specified collections
func (dbCtx *DatabaseContext) invalUserRolesAndChannels(ctx context.Context, username string, collections base.ScopeAndCollectionNames, invalSeq uint64) {
authr := dbCtx.Authenticator(ctx)
if err := authr.InvalidateRolesAndChannels(username, collections, invalSeq); err != nil {
base.WarnfCtx(ctx, "Error invalidating roles for user %s: %v", base.UD(username), err)
}
}

Expand Down
31 changes: 31 additions & 0 deletions db/database_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,37 @@ func (c *DatabaseCollection) UpdateSyncFun(ctx context.Context, syncFun string)
return
}

// DatabaseCollection helper methods for channel and role invalidation - invoke the multi-collection version on
// the databaseContext for a single collection.
// invalUserOrRoleChannels invalidates a user or role's channels for collection c
func (c *DatabaseCollection) invalUserOrRoleChannels(ctx context.Context, name string, invalSeq uint64) {
principalName, isRole := channels.AccessNameToPrincipalName(name)
if isRole {
c.invalRoleChannels(ctx, principalName, invalSeq)
} else {
c.invalUserChannels(ctx, principalName, invalSeq)
}
}

// invalRoleChannels invalidates a user's computed channels for collection c
func (c *DatabaseCollection) invalUserChannels(ctx context.Context, username string, invalSeq uint64) {
c.dbCtx.invalUserChannels(ctx, username, base.ScopeAndCollectionNames{c.ScopeAndCollectionName()}, invalSeq)
}

// invalRoleChannels invalidates a role's computed channels for collection c
func (c *DatabaseCollection) invalRoleChannels(ctx context.Context, rolename string, invalSeq uint64) {
c.dbCtx.invalRoleChannels(ctx, rolename, base.ScopeAndCollectionNames{c.ScopeAndCollectionName()}, invalSeq)
}

// invalidateAllPrincipals invalidates computed channels and roles for collection c, for all users and roles
func (c *DatabaseCollection) invalidateAllPrincipals(ctx context.Context, endSeq uint64) {
c.dbCtx.invalidateAllPrincipals(ctx, base.ScopeAndCollectionNames{c.ScopeAndCollectionName()}, endSeq)
}

func (c *DatabaseCollection) useMou() bool {
return c.dbCtx.UseMou()
}

func (c *DatabaseCollection) ScopeAndCollectionName() base.ScopeAndCollectionName {
return base.NewScopeAndCollectionName(c.ScopeName, c.Name)
}
2 changes: 1 addition & 1 deletion db/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2731,7 +2731,7 @@ func Test_invalidateAllPrincipalsCache(t *testing.T) {
assert.NoError(t, err)
assert.Greater(t, endSeq, uint64(0))

collection.invalidateAllPrincipalsCache(ctx, endSeq)
collection.invalidateAllPrincipals(ctx, endSeq)
err = collection.WaitForPendingChanges(ctx)
assert.NoError(t, err)

Expand Down
Loading
Loading