Skip to content

Commit

Permalink
#2911: adjust code after code review, add more test coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Viacheslav Sychov <[email protected]>
  • Loading branch information
vsychov committed May 24, 2023
1 parent 27d7eb5 commit 3f5e6c7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
8 changes: 4 additions & 4 deletions connector/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type Config struct {

// Open returns a connector which can be used to login users through Google.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
if len(c.AdminEmail) != 0 {
if c.AdminEmail != "" {
log.Deprecated(logger, `google: use "domainToAdminEmail.*: %s" option instead of "adminEmail: %s".`, c.AdminEmail, c.AdminEmail)
if c.DomainToAdminEmail == nil {
c.DomainToAdminEmail = make(map[string]string)
Expand Down Expand Up @@ -91,11 +91,11 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
// TODO: or is it?
if len(c.DomainToAdminEmail) == 0 && c.ServiceAccountFilePath != "" {
cancel()
return nil, fmt.Errorf("directory service requires adminEmail")
return nil, fmt.Errorf("directory service requires the domainToAdminEmail option to be configured")
}

// Fixing a regression caused by default config fallback: https://github.com/dexidp/dex/issues/2699
if (c.ServiceAccountFilePath != "" && len(c.DomainToAdminEmail) != 0) || slices.Contains(scopes, "groups") {
if (c.ServiceAccountFilePath != "" && len(c.DomainToAdminEmail) > 0) || slices.Contains(scopes, "groups") {
for domain, adminEmail := range c.DomainToAdminEmail {
srv, err := createDirectoryService(c.ServiceAccountFilePath, adminEmail, logger)
if err != nil {
Expand Down Expand Up @@ -248,7 +248,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
}

var groups []string
if s.Groups && len(c.adminSrv) != 0 {
if s.Groups && len(c.adminSrv) > 0 {
checkedGroups := make(map[string]struct{})
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
Expand Down
57 changes: 56 additions & 1 deletion connector/google/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestOpen(t *testing.T) {
Scopes: []string{"openid", "groups"},
ServiceAccountFilePath: serviceAccountFilePath,
},
expectedErr: "requires adminEmail",
expectedErr: "requires the domainToAdminEmail",
},
"service_account_key_not_found": {
config: &Config{
Expand Down Expand Up @@ -236,3 +236,58 @@ func TestGetGroups(t *testing.T) {
})
}
}

func TestDomainToAdminEmailConfig(t *testing.T) {
ts := testSetup()
defer ts.Close()

serviceAccountFilePath, err := tempServiceAccountKey()
assert.Nil(t, err)

os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "groups"},
DomainToAdminEmail: map[string]string{"dexidp.com": "[email protected]"},
})
assert.Nil(t, err)

conn.adminSrv["dexidp.com"], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
expectedErr string
}

for name, testCase := range map[string]testCase{
"correct_user_request": {
userKey: "[email protected]",
expectedErr: "",
},
"wrong_user_request": {
userKey: "[email protected]",
expectedErr: "unable to find super admin email",
},
"wrong_connector_response": {
userKey: "user_1_foo.bar",
expectedErr: "unable to find super admin email",
},
} {
testCase := testCase
callCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})

_, err := conn.getGroups(testCase.userKey, true, lookup)
if testCase.expectedErr != "" {
assert.ErrorContains(err, testCase.expectedErr)
} else {
assert.Nil(err)
}
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
})
}
}

0 comments on commit 3f5e6c7

Please sign in to comment.