Skip to content

Commit

Permalink
Add a flag to find users list for deleted orgs
Browse files Browse the repository at this point in the history
To avoid any unexpected behaviour
  • Loading branch information
Marcus Cobden committed Sep 25, 2018
1 parent 0926f68 commit 9853ced
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 36 deletions.
4 changes: 3 additions & 1 deletion users-sync/attrsync/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ func (c *AttributeSyncer) EnqueueOrgsSync(ctx context.Context, orgExternalIDs []
// we could do this DB lookup later and remove it from the Enqueue call path
// but it should be relatively cheap.
for _, externalID := range orgExternalIDs {
users, err := c.db.ListOrganizationUsers(ctx, externalID)
// We request memberships for deleted orgs too so that we update users who
// were connected to the deleted orgs
users, err := c.db.ListOrganizationUsers(ctx, externalID, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions users/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (a *API) adminListUsersForOrganization(w http.ResponseWriter, r *http.Reque
renderError(w, r, users.ErrNotFound)
return
}
us, err := a.db.ListOrganizationUsers(r.Context(), orgID)
us, err := a.db.ListOrganizationUsers(r.Context(), orgID, true)
if err != nil {
renderError(w, r, err)
return
Expand All @@ -130,7 +130,7 @@ func (a *API) adminRemoveUserFromOrganization(w http.ResponseWriter, r *http.Req
orgExternalID := vars["orgExternalID"]
userID := vars["userID"]

if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID); err != nil {
if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID, false); err != nil {
renderError(w, r, err)
return
} else if len(members) == 1 {
Expand Down
6 changes: 3 additions & 3 deletions users/api/org.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ func (a *API) listOrganizationUsers(currentUser *users.User, w http.ResponseWrit
return
}

users, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID)
users, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID, false)
if err != nil {
renderError(w, r, err)
return
Expand Down Expand Up @@ -492,7 +492,7 @@ func (a *API) removeUser(currentUser *users.User, w http.ResponseWriter, r *http
return
}

if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID); err != nil {
if members, err := a.db.ListOrganizationUsers(r.Context(), orgExternalID, false); err != nil {
renderError(w, r, err)
return
} else if len(members) == 1 {
Expand Down Expand Up @@ -599,7 +599,7 @@ func (a *API) extendOrgTrialPeriod(ctx context.Context, org *users.Organization,
return nil
}

members, err := a.db.ListOrganizationUsers(ctx, org.ExternalID)
members, err := a.db.ListOrganizationUsers(ctx, org.ExternalID, false)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion users/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type DB interface {
ListUsers(ctx context.Context, f filter.User, page uint64) ([]*users.User, error)
ListOrganizations(ctx context.Context, f filter.Organization, page uint64) ([]*users.Organization, error)
ListAllOrganizations(ctx context.Context, f filter.Organization, page uint64) ([]*users.Organization, error)
ListOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error)
ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error)

// ListOrganizationsForUserIDs lists all organizations these users have
// access to.
Expand Down
8 changes: 4 additions & 4 deletions users/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ func TestDB_RemoveOtherUsersAccess(t *testing.T) {
require.NoError(t, err)
require.Len(t, otherUserOrganizations, 1)

orgUsers, err := db.ListOrganizationUsers(context.Background(), org.ExternalID)
orgUsers, err := db.ListOrganizationUsers(context.Background(), org.ExternalID, false)
require.NoError(t, err)
require.Len(t, orgUsers, 2)

err = db.RemoveUserFromOrganization(context.Background(), org.ExternalID, otherUser.Email)
require.NoError(t, err)

orgUsers, err = db.ListOrganizationUsers(context.Background(), org.ExternalID)
orgUsers, err = db.ListOrganizationUsers(context.Background(), org.ExternalID, false)
require.NoError(t, err)
require.Len(t, orgUsers, 1)
}
Expand Down Expand Up @@ -73,7 +73,7 @@ func TestDB_RemoveOtherUsersAccessWithTeams(t *testing.T) {
require.Len(t, otherUserTeams, 1)
require.Equal(t, team.ID, otherUserTeams[0].ID)

orgUsers, err := db.ListOrganizationUsers(ctx, org.ExternalID)
orgUsers, err := db.ListOrganizationUsers(ctx, org.ExternalID, false)
require.NoError(t, err)
require.Len(t, orgUsers, 2)

Expand All @@ -84,7 +84,7 @@ func TestDB_RemoveOtherUsersAccessWithTeams(t *testing.T) {
err = db.RemoveUserFromOrganization(ctx, org.ExternalID, otherUser.Email)
require.NoError(t, err)

orgUsers, err = db.ListOrganizationUsers(ctx, org.ExternalID)
orgUsers, err = db.ListOrganizationUsers(ctx, org.ExternalID, false)
require.NoError(t, err)
require.Len(t, orgUsers, 1)

Expand Down
12 changes: 8 additions & 4 deletions users/db/memory/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,26 @@ func (d *DB) ListAllOrganizations(_ context.Context, f filter.Organization, page
}

// ListOrganizationUsers lists all the users in an organization
func (d *DB) ListOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) {
func (d *DB) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) {
d.mtx.Lock()
defer d.mtx.Unlock()
return d.listOrganizationUsers(ctx, orgExternalID)
return d.listOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs)
}

// listOrganizationUsers lists all the users in an organization
// This is a lock-free version of the above, in order to be able to re-use the actual logic
// in other methods as otherwise, calling mtx.Lock() twice on the same goroutine deadlocks it.
func (d *DB) listOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) {
func (d *DB) listOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) {
o, err := d.findOrganizationByExternalID(orgExternalID)
if err != nil {
return nil, err
}

var users []*users.User
if !o.DeletedAt.IsZero() && !includeDeletedOrgs {
return users, nil
}

for _, m := range d.memberships[o.ID] {
u, err := d.findUserByID(m)
if err != nil {
Expand Down Expand Up @@ -808,7 +812,7 @@ func (d *DB) GetSummary(ctx context.Context) ([]*users.SummaryEntry, error) {
entries := []*users.SummaryEntry{}
for _, org := range d.organizations {
team := d.teams[org.TeamID]
orgUsers, err := d.listOrganizationUsers(ctx, org.ExternalID)
orgUsers, err := d.listOrganizationUsers(ctx, org.ExternalID, false)
if err != nil {
return nil, err
}
Expand Down
35 changes: 22 additions & 13 deletions users/db/postgres/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,13 @@ func (d DB) ListAllOrganizations(ctx context.Context, f filter.Organization, pag
}

// ListOrganizationUsers lists all the users in an organization
func (d DB) ListOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) {
orgUsers, err := d.listDirectOrganizationUsers(ctx, orgExternalID)
// it will still return a user list for 'deleted' organizations
func (d DB) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) {
orgUsers, err := d.listDirectOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs)
if err != nil {
return nil, err
}
teamUsers, err := d.listTeamOrganizationUsers(ctx, orgExternalID)
teamUsers, err := d.listTeamOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs)
if err != nil {
return nil, err
}
Expand All @@ -206,14 +207,18 @@ func (d DB) ListOrganizationUsers(ctx context.Context, orgExternalID string) ([]
return users, nil
}

func (d DB) listDirectOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) {
func (d DB) listDirectOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) {
filter := squirrel.Eq{
"organizations.external_id": orgExternalID,
"memberships.deleted_at": nil,
}
if !includeDeletedOrgs {
filter["organizations.deleted_at"] = nil
}
rows, err := d.usersQuery().
Join("memberships on (memberships.user_id = users.id)").
Join("organizations on (memberships.organization_id = organizations.id)").
Where(squirrel.Eq{
"organizations.external_id": orgExternalID,
"memberships.deleted_at": nil,
}).
Where(filter).
OrderBy("users.created_at").
QueryContext(ctx)
if err != nil {
Expand All @@ -223,14 +228,18 @@ func (d DB) listDirectOrganizationUsers(ctx context.Context, orgExternalID strin
return d.scanUsers(rows)
}

func (d DB) listTeamOrganizationUsers(ctx context.Context, orgExternalID string) ([]*users.User, error) {
func (d DB) listTeamOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) ([]*users.User, error) {
filter := squirrel.Eq{
"organizations.external_id": orgExternalID,
"team_memberships.deleted_at": nil,
}
if !includeDeletedOrgs {
filter["organizations.deleted_at"] = nil
}
rows, err := d.usersQuery().
Join("team_memberships on (team_memberships.user_id = users.id)").
Join("organizations on (team_memberships.team_id = organizations.team_id)").
Where(squirrel.Eq{
"organizations.external_id": orgExternalID,
"team_memberships.deleted_at": nil,
}).
Where(filter).
QueryContext(ctx)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions users/db/timed.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ func (t timed) ListAllOrganizations(ctx context.Context, f filter.Organization,
return
}

func (t timed) ListOrganizationUsers(ctx context.Context, orgExternalID string) (us []*users.User, err error) {
func (t timed) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) (us []*users.User, err error) {
t.timeRequest(ctx, "ListOrganizationUsers", func(ctx context.Context) error {
us, err = t.d.ListOrganizationUsers(ctx, orgExternalID)
us, err = t.d.ListOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs)
return err
})
return
Expand Down
4 changes: 2 additions & 2 deletions users/db/traced.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ func (t traced) ListAllOrganizations(ctx context.Context, f filter.Organization,
return t.d.ListAllOrganizations(ctx, f, page)
}

func (t traced) ListOrganizationUsers(ctx context.Context, orgExternalID string) (us []*users.User, err error) {
func (t traced) ListOrganizationUsers(ctx context.Context, orgExternalID string, includeDeletedOrgs bool) (us []*users.User, err error) {
defer t.trace("ListOrganizationUsers", orgExternalID, us, err)
return t.d.ListOrganizationUsers(ctx, orgExternalID)
return t.d.ListOrganizationUsers(ctx, orgExternalID, includeDeletedOrgs)
}

func (t traced) ListOrganizationsForUserIDs(ctx context.Context, userIDs ...string) (os []*users.Organization, err error) {
Expand Down
8 changes: 4 additions & 4 deletions users/grpc/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func (a *usersServer) NotifyTrialPendingExpiry(ctx context.Context, req *users.N
}

// Notify all users
members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID)
members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false)
if err != nil {
return nil, err
}
Expand All @@ -325,7 +325,7 @@ func (a *usersServer) NotifyTrialExpired(ctx context.Context, req *users.NotifyT
}

// Notify all users
members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID)
members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false)
if err != nil {
return nil, err
}
Expand All @@ -349,7 +349,7 @@ func (a *usersServer) NotifyRefuseDataUpload(ctx context.Context, req *users.Not
}

// Notify all users
members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID)
members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -392,7 +392,7 @@ func (a *usersServer) InformOrganizationBillingConfigured(ctx context.Context, r
return nil, err
}

members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID)
members, err := a.db.ListOrganizationUsers(ctx, req.ExternalID, false)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 9853ced

Please sign in to comment.