diff --git a/users-sync/attrsync/sync.go b/users-sync/attrsync/sync.go index 317c119b0..29e6d5ffb 100644 --- a/users-sync/attrsync/sync.go +++ b/users-sync/attrsync/sync.go @@ -98,7 +98,9 @@ func (c *AttributeSyncer) EnqueueOrgsSync(ctx context.Context, orgExternalIDs [] // we could do this DB lookup later and remove it from the Enqueue call path // but it should be relatively cheap. for _, externalID := range orgExternalIDs { - users, err := c.db.ListOrganizationUsers(ctx, externalID) + // We request memberships for deleted orgs too so that we update users who + // were connected to the deleted orgs + users, err := c.db.ListOrganizationUsers(ctx, externalID, true) if err != nil { return err } diff --git a/users/api/admin.go b/users/api/admin.go index 6e0c941f7..7f3b27a65 100644 --- a/users/api/admin.go +++ b/users/api/admin.go @@ -105,7 +105,7 @@ func (a *API) adminListUsersForOrganization(w http.ResponseWriter, r *http.Reque renderError(w, r, users.ErrNotFound) return } - us, err := a.db.ListOrganizationUsers(r.Context(), orgID) + us, err := a.db.ListOrganizationUsers(r.Context(), orgID, true) if err != nil { renderError(w, r, err) return @@ -130,7 +130,7 @@ func (a *API) adminRemoveUserFromOrganization(w http.ResponseWriter, r *http.Req orgExternalID := vars["orgExternalID"] userID := vars["userID"] - if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID); err != nil { + if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID, false); err != nil { renderError(w, r, err) return } else if len(members) == 1 { diff --git a/users/api/org.go b/users/api/org.go index b55d4609c..3dc3f0557 100644 --- a/users/api/org.go +++ b/users/api/org.go @@ -420,7 +420,7 @@ func (a *API) listOrganizationUsers(currentUser *users.User, w http.ResponseWrit return } - users, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID) + users, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID, false) if err != nil { renderError(w, r, err) return @@ -492,7 +492,7 @@ func (a *API) removeUser(currentUser *users.User, w http.ResponseWriter, r *http return } - if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID); err != nil { + if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID, false); err != nil { renderError(w, r, err) return } else if len(members) == 1 { @@ -599,7 +599,7 @@ func (a *API) extendOrgTrialPeriod(ctx context.Context, org *users.Organization, return nil } - members, err := a.db.ListOrganizationUsers(ctx, org.ExternalID) + members, err := a.db.ListOrganizationUsers(ctx, org.ExternalID, false) if err != nil { return err } diff --git a/users/db/db.go b/users/db/db.go index a72ff26b7..c1abbed80 100644 --- a/users/db/db.go +++ b/users/db/db.go @@ -57,7 +57,7 @@ type DB interface { ListUsers(ctx context.Context, f filter.User, page uint64) ([]*users.User, error) ListOrganizations(ctx context.Context, f filter.Organization, page uint64) ([]*users.Organization, error) ListAllOrganizations(ctx context.Context, f filter.Organization, page uint64) ([]*users.Organization, error) - ListOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) + ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) // ListOrganizationsForUserIDs lists all organizations these users have // access to. diff --git a/users/db/db_test.go b/users/db/db_test.go index f5d5acd96..0590655a8 100644 --- a/users/db/db_test.go +++ b/users/db/db_test.go @@ -29,14 +29,14 @@ func TestDB_RemoveOtherUsersAccess(t *testing.T) { require.NoError(t, err) require.Len(t, otherUserOrganizations, 1) - orgUsers, err := db.ListOrganizationUsers(context.Background(), org.ExternalID) + orgUsers, err := db.ListOrganizationUsers(context.Background(), org.ExternalID, false) require.NoError(t, err) require.Len(t, orgUsers, 2) err = db.RemoveUserFromOrganization(context.Background(), org.ExternalID, otherUser.Email) require.NoError(t, err) - orgUsers, err = db.ListOrganizationUsers(context.Background(), org.ExternalID) + orgUsers, err = db.ListOrganizationUsers(context.Background(), org.ExternalID, false) require.NoError(t, err) require.Len(t, orgUsers, 1) } @@ -73,7 +73,7 @@ func TestDB_RemoveOtherUsersAccessWithTeams(t *testing.T) { require.Len(t, otherUserTeams, 1) require.Equal(t, team.ID, otherUserTeams[0].ID) - orgUsers, err := db.ListOrganizationUsers(ctx, org.ExternalID) + orgUsers, err := db.ListOrganizationUsers(ctx, org.ExternalID, false) require.NoError(t, err) require.Len(t, orgUsers, 2) @@ -84,7 +84,7 @@ func TestDB_RemoveOtherUsersAccessWithTeams(t *testing.T) { err = db.RemoveUserFromOrganization(ctx, org.ExternalID, otherUser.Email) require.NoError(t, err) - orgUsers, err = db.ListOrganizationUsers(ctx, org.ExternalID) + orgUsers, err = db.ListOrganizationUsers(ctx, org.ExternalID, false) require.NoError(t, err) require.Len(t, orgUsers, 1) diff --git a/users/db/memory/organization.go b/users/db/memory/organization.go index d414b18a2..8aa03767f 100644 --- a/users/db/memory/organization.go +++ b/users/db/memory/organization.go @@ -125,22 +125,26 @@ func (d *DB) ListAllOrganizations(_ context.Context, f filter.Organization, page } // ListOrganizationUsers lists all the users in an organization -func (d *DB) ListOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) { +func (d *DB) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) { d.mtx.Lock() defer d.mtx.Unlock() - return d.listOrganizationUsers(ctx, orgExternalID) + return d.listOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs) } // listOrganizationUsers lists all the users in an organization // This is a lock-free version of the above, in order to be able to re-use the actual logic // in other methods as otherwise, calling mtx.Lock() twice on the same goroutine deadlocks it. -func (d *DB) listOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) { +func (d *DB) listOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) { o, err := d.findOrganizationByExternalID(orgExternalID) if err != nil { return nil, err } var users []*users.User + if !o.DeletedAt.IsZero() && !includeDeletedOrgs { + return users, nil + } + for _, m := range d.memberships[o.ID] { u, err := d.findUserByID(m) if err != nil { @@ -808,7 +812,7 @@ func (d *DB) GetSummary(ctx context.Context) ([]*users.SummaryEntry, error) { entries := []*users.SummaryEntry{} for _, org := range d.organizations { team := d.teams[org.TeamID] - orgUsers, err := d.listOrganizationUsers(ctx, org.ExternalID) + orgUsers, err := d.listOrganizationUsers(ctx, org.ExternalID, false) if err != nil { return nil, err } diff --git a/users/db/postgres/organization.go b/users/db/postgres/organization.go index 224fcffce..6f71256bd 100644 --- a/users/db/postgres/organization.go +++ b/users/db/postgres/organization.go @@ -192,12 +192,13 @@ func (d DB) ListAllOrganizations(ctx context.Context, f filter.Organization, pag } // ListOrganizationUsers lists all the users in an organization -func (d DB) ListOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) { - orgUsers, err := d.listDirectOrganizationUsers(ctx, orgExternalID) +// it will still return a user list for 'deleted' organizations +func (d DB) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) { + orgUsers, err := d.listDirectOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs) if err != nil { return nil, err } - teamUsers, err := d.listTeamOrganizationUsers(ctx, orgExternalID) + teamUsers, err := d.listTeamOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs) if err != nil { return nil, err } @@ -206,14 +207,18 @@ func (d DB) ListOrganizationUsers(ctx context.Context, orgExternalID string) ([] return users, nil } -func (d DB) listDirectOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) { +func (d DB) listDirectOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) { + filter := squirrel.Eq{ + "organizations.external_id": orgExternalID, + "memberships.deleted_at": nil, + } + if !includeDeletedOrgs { + filter["organizations.deleted_at"] = nil + } rows, err := d.usersQuery(). Join("memberships on (memberships.user_id = users.id)"). Join("organizations on (memberships.organization_id = organizations.id)"). - Where(squirrel.Eq{ - "organizations.external_id": orgExternalID, - "memberships.deleted_at": nil, - }). + Where(filter). OrderBy("users.created_at"). QueryContext(ctx) if err != nil { @@ -223,14 +228,18 @@ func (d DB) listDirectOrganizationUsers(ctx context.Context, orgExternalID strin return d.scanUsers(rows) } -func (d DB) listTeamOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) { +func (d DB) listTeamOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) { + filter := squirrel.Eq{ + "organizations.external_id": orgExternalID, + "team_memberships.deleted_at": nil, + } + if !includeDeletedOrgs { + filter["organizations.deleted_at"] = nil + } rows, err := d.usersQuery(). Join("team_memberships on (team_memberships.user_id = users.id)"). Join("organizations on (team_memberships.team_id = organizations.team_id)"). - Where(squirrel.Eq{ - "organizations.external_id": orgExternalID, - "team_memberships.deleted_at": nil, - }). + Where(filter). QueryContext(ctx) if err != nil { return nil, err diff --git a/users/db/timed.go b/users/db/timed.go index f4bd29fa4..f1639034b 100644 --- a/users/db/timed.go +++ b/users/db/timed.go @@ -143,9 +143,9 @@ func (t timed) ListAllOrganizations(ctx context.Context, f filter.Organization, return } -func (t timed) ListOrganizationUsers(ctx context.Context, orgExternalID string) (us []*users.User, err error) { +func (t timed) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) (us []*users.User, err error) { t.timeRequest(ctx, "ListOrganizationUsers", func(ctx context.Context) error { - us, err = t.d.ListOrganizationUsers(ctx, orgExternalID) + us, err = t.d.ListOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs) return err }) return diff --git a/users/db/traced.go b/users/db/traced.go index 962353e42..2a21a3989 100644 --- a/users/db/traced.go +++ b/users/db/traced.go @@ -94,9 +94,9 @@ func (t traced) ListAllOrganizations(ctx context.Context, f filter.Organization, return t.d.ListAllOrganizations(ctx, f, page) } -func (t traced) ListOrganizationUsers(ctx context.Context, orgExternalID string) (us []*users.User, err error) { +func (t traced) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) (us []*users.User, err error) { defer t.trace("ListOrganizationUsers", orgExternalID, us, err) - return t.d.ListOrganizationUsers(ctx, orgExternalID) + return t.d.ListOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs) } func (t traced) ListOrganizationsForUserIDs(ctx context.Context, userIDs ...string) (os []*users.Organization, err error) { diff --git a/users/grpc/lookup.go b/users/grpc/lookup.go index 447b811e3..9dd78bd40 100644 --- a/users/grpc/lookup.go +++ b/users/grpc/lookup.go @@ -301,7 +301,7 @@ func (a *usersServer) NotifyTrialPendingExpiry(ctx context.Context, req *users.N } // Notify all users - members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID) + members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false) if err != nil { return nil, err } @@ -325,7 +325,7 @@ func (a *usersServer) NotifyTrialExpired(ctx context.Context, req *users.NotifyT } // Notify all users - members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID) + members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false) if err != nil { return nil, err } @@ -349,7 +349,7 @@ func (a *usersServer) NotifyRefuseDataUpload(ctx context.Context, req *users.Not } // Notify all users - members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID) + members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false) if err != nil { return nil, err } @@ -392,7 +392,7 @@ func (a *usersServer) InformOrganizationBillingConfigured(ctx context.Context, r return nil, err } - members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID) + members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false) if err != nil { return nil, err }