Skip to content

Commit

Permalink
Move unique functionality into getGroups to reduce calls to google
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Hoey <[email protected]>
  • Loading branch information
snuggie12 committed Oct 29, 2022
1 parent b010071 commit 1c544f5
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 31 deletions.
43 changes: 19 additions & 24 deletions connector/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector

var groups []string
if s.Groups && c.adminSrv != nil {
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership)
checkedGroups := make(map[string]struct{})
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
}
Expand All @@ -253,7 +254,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector

// getGroups creates a connection to the admin directory service and lists
// all groups the user is a member of
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool) ([]string, error) {
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &admin.Groups{}
Expand All @@ -265,26 +266,33 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership
}

for _, group := range groupsList.Groups {
if _, exists := checkedGroups[group.Email]; exists {
continue
}

checkedGroups[group.Email] = struct{}{}
// TODO (joelspeed): Make desired group key configurable
userGroups = append(userGroups, group.Email)

// getGroups takes a user's email/alias as well as a group's email/alias
if fetchTransitiveGroupMembership {
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}
if !fetchTransitiveGroupMembership {
continue
}

userGroups = append(userGroups, transitiveGroups...)
// getGroups takes a user's email/alias as well as a group's email/alias
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}

userGroups = append(userGroups, transitiveGroups...)
}

if groupsList.NextPageToken == "" {
break
}
}

return uniqueGroups(userGroups), nil
return userGroups, nil
}

// createDirectoryService sets up super user impersonation and creates an admin client for calling
Expand Down Expand Up @@ -316,7 +324,7 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log
}
config, err := google.JWTConfigFromJSON(jsonCredentials, admin.AdminDirectoryGroupReadonlyScope)
if err != nil {
return nil, fmt.Errorf("unable to parse credentials to config: %v", err)
return nil, fmt.Errorf("unable to parse client secret file to config: %v", err)
}

// Only attempt impersonation when there is a user configured
Expand All @@ -326,16 +334,3 @@ func createDirectoryService(serviceAccountFilePath, email string, logger log.Log

return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx)))
}

// uniqueGroups returns the unique groups of a slice
func uniqueGroups(groups []string) []string {
keys := make(map[string]struct{})
unique := []string{}
for _, group := range groups {
if _, exists := keys[group]; !exists {
keys[group] = struct{}{}
unique = append(unique, group)
}
}
return unique
}
106 changes: 99 additions & 7 deletions connector/google/google_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package google

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -10,17 +11,38 @@ import (

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/option"
)

var (
// groups_0
// ┌───────┤
// groups_2 groups_1
// │ ├────────┐
// └── user_1 user_2
testGroups = map[string][]*admin.Group{
"[email protected]": {{Email: "[email protected]"}, {Email: "[email protected]"}},
"[email protected]": {{Email: "[email protected]"}},
"[email protected]": {{Email: "[email protected]"}},
"[email protected]": {{Email: "[email protected]"}},
"[email protected]": {},
}
callCounter = make(map[string]int)
)

func testSetup(t *testing.T) *httptest.Server {
mux := http.NewServeMux()
// TODO: mock calls
// mux.HandleFunc("/admin/directory/v1/groups", func(w http.ResponseWriter, r *http.Request) {
// w.Header().Add("Content-Type", "application/json")
// json.NewEncoder(w).Encode(&admin.Groups{
// Groups: []*admin.Group{},
// })
// })

mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
userKey := r.URL.Query().Get("userKey")
if groups, ok := testGroups[userKey]; ok {
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
callCounter[userKey]++
}
})

return httptest.NewServer(mux)
}

Expand Down Expand Up @@ -144,3 +166,73 @@ func TestOpen(t *testing.T) {
})
}
}

func TestGetGroups(t *testing.T) {
ts := testSetup(t)
defer ts.Close()

serviceAccountFilePath, err := tempServiceAccountKey()
assert.Nil(t, err)

os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "groups"},
AdminEmail: "[email protected]",
}, ts.URL)
assert.Nil(t, err)

conn.adminSrv, err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
fetchTransitiveGroupMembership bool
shouldErr bool
expectedGroups []string
}

for name, testCase := range map[string]testCase{
"user1_non_transitive_lookup": {
userKey: "[email protected]",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"[email protected]", "[email protected]"},
},
"user1_transitive_lookup": {
userKey: "[email protected]",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"[email protected]", "[email protected]", "[email protected]"},
},
"user2_non_transitive_lookup": {
userKey: "[email protected]",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"[email protected]"},
},
"user2_transitive_lookup": {
userKey: "[email protected]",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"[email protected]", "[email protected]"},
},
} {
testCase := testCase
callCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})

groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
if testCase.shouldErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
})
}
}

0 comments on commit 1c544f5

Please sign in to comment.