Skip to content

Commit

Permalink
fixes bound_claims validation for provider-specific group and user in…
Browse files Browse the repository at this point in the history
…fo fetching (#149)
  • Loading branch information
austingebauer authored Dec 2, 2020
1 parent 041c525 commit c1176c2
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 69 deletions.
2 changes: 1 addition & 1 deletion path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
}

// Validate provider_config
if _, err := NewProviderConfig(config, ProviderMap()); err != nil {
if _, err := NewProviderConfig(ctx, config, ProviderMap()); err != nil {
return logical.ErrorResponse("invalid provider_config: %s", err), nil
}

Expand Down
43 changes: 23 additions & 20 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
return nil, errors.New("unhandled case during login")
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

alias, groupAliases, err := b.createIdentity(allClaims, role)
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

tokenMetadata := map[string]string{"role": roleName}
for k, v := range alias.Metadata {
tokenMetadata[k] = v
Expand Down Expand Up @@ -308,7 +308,7 @@ func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig,

// createIdentity creates an alias and set of groups aliases based on the role
// definition and received claims.
func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
userClaimRaw, ok := allClaims[role.UserClaim]
if !ok {
return nil, nil, fmt.Errorf("claim %q not found in token", role.UserClaim)
Expand All @@ -318,8 +318,12 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
return nil, nil, fmt.Errorf("claim %q could not be converted to string", role.UserClaim)
}

err := b.fetchUserInfo(allClaims, role)
pConfig, err := NewProviderConfig(ctx, b.cachedConfig, ProviderMap())
if err != nil {
return nil, nil, fmt.Errorf("failed to load custom provider config: %s", err)
}

if err := b.fetchUserInfo(ctx, pConfig, allClaims, role); err != nil {
return nil, nil, err
}

Expand All @@ -339,7 +343,7 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
return alias, groupAliases, nil
}

groupsClaimRaw, err := b.fetchGroups(allClaims, role)
groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch groups: %s", err)
}
Expand All @@ -366,32 +370,31 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
}

// Checks if there's a custom provider_config and calls FetchUserInfo() if implemented.
func (b *jwtAuthBackend) fetchUserInfo(allClaims map[string]interface{}, role *jwtRole) error {
pConfig, err := NewProviderConfig(b.cachedConfig, ProviderMap())
if err != nil {
return fmt.Errorf("failed to load custom provider config: %s", err)
}
func (b *jwtAuthBackend) fetchUserInfo(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole) error {
// Fetch user info from custom provider if it's implemented
if pConfig != nil {
if uif, ok := pConfig.(UserInfoFetcher); ok {
return uif.FetchUserInfo(b, allClaims, role)
return uif.FetchUserInfo(ctx, b, allClaims, role)
}
}

return nil
}

// Checks if there's a custom provider_config and calls FetchGroups() if implemented
func (b *jwtAuthBackend) fetchGroups(allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
pConfig, err := NewProviderConfig(b.cachedConfig, ProviderMap())
if err != nil {
return nil, fmt.Errorf("failed to load custom provider config: %s", err)
}
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
// If the custom provider implements interface GroupsFetcher, call it,
// otherwise fall through to the default method
if pConfig != nil {
if gf, ok := pConfig.(GroupsFetcher); ok {
return gf.FetchGroups(b, allClaims, role)
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role)
if err != nil {
return nil, err
}

// Add groups obtained by provider-specific fetching to the claims
// so that they can be used for bound_claims validation on the role.
allClaims["groups"] = groupsRaw
}
}
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)
Expand Down
10 changes: 5 additions & 5 deletions path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
}
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

alias, groupAliases, err := b.createIdentity(allClaims, role)
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

tokenMetadata := map[string]string{"role": roleName}
for k, v := range alias.Metadata {
tokenMetadata[k] = v
Expand Down
4 changes: 2 additions & 2 deletions provider_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type AzureProvider struct {
}

// Initialize anything in the AzureProvider struct - satisfying the CustomProvider interface
func (a *AzureProvider) Initialize(jc *jwtConfig) error {
func (a *AzureProvider) Initialize(_ context.Context, _ *jwtConfig) error {
return nil
}

Expand All @@ -45,7 +45,7 @@ func (a *AzureProvider) SensitiveKeys() []string {
}

// FetchGroups - custom groups fetching for azure - satisfying GroupsFetcher interface
func (a *AzureProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw == nil {
Expand Down
12 changes: 10 additions & 2 deletions provider_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func TestLogin_fetchGroups(t *testing.T) {
require.NoError(t, err)

b, storage := getBackend(t)
ctx := context.Background()

data := map[string]interface{}{
"oidc_discovery_url": aServer.server.URL,
Expand Down Expand Up @@ -156,12 +157,19 @@ func TestLogin_fetchGroups(t *testing.T) {
}

// Ensure b.cachedConfig is populated
_, err = b.(*jwtAuthBackend).config(context.Background(), storage)
config, err := b.(*jwtAuthBackend).config(ctx, storage)
if err != nil {
t.Fatal(err)
}

groupsResp, err := b.(*jwtAuthBackend).fetchGroups(allClaims, role)
// Initialize the azure provider
provider, err := NewProviderConfig(ctx, config, ProviderMap())
if err != nil {
t.Fatal(err)
}

// Ensure groups are as expected
groupsResp, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role)
assert.NoError(t, err)
assert.Equal(t, []interface{}{"group1", "group2"}, groupsResp)
}
Expand Down
11 changes: 6 additions & 5 deletions provider_config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwtauth

import (
"context"
"fmt"
)

Expand All @@ -21,7 +22,7 @@ type CustomProvider interface {
// Initialize should validate jwtConfig.ProviderConfig, set internal values
// and run any initialization necessary for subsequent calls to interface
// functions the provider implements
Initialize(*jwtConfig) error
Initialize(context.Context, *jwtConfig) error

// SensitiveKeys returns any fields in a provider's jwtConfig.ProviderConfig
// that should be masked or omitted when output
Expand All @@ -31,7 +32,7 @@ type CustomProvider interface {
// NewProviderConfig - returns appropriate provider struct if provider_config is
// specified in jwtConfig. The provider map is provider name -to- instance of a
// CustomProvider.
func NewProviderConfig(jc *jwtConfig, providerMap map[string]CustomProvider) (CustomProvider, error) {
func NewProviderConfig(ctx context.Context, jc *jwtConfig, providerMap map[string]CustomProvider) (CustomProvider, error) {
if len(jc.ProviderConfig) == 0 {
return nil, nil
}
Expand All @@ -43,19 +44,19 @@ func NewProviderConfig(jc *jwtConfig, providerMap map[string]CustomProvider) (Cu
if !ok {
return nil, fmt.Errorf("provider %q not found in custom providers", provider)
}
if err := newCustomProvider.Initialize(jc); err != nil {
if err := newCustomProvider.Initialize(ctx, jc); err != nil {
return nil, fmt.Errorf("error initializing %q provider_config: %s", provider, err)
}
return newCustomProvider, nil
}

// UserInfoFetcher - Optional support for custom user info handling
type UserInfoFetcher interface {
FetchUserInfo(*jwtAuthBackend, map[string]interface{}, *jwtRole) error
FetchUserInfo(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole) error
}

// GroupsFetcher - Optional support for custom groups handling
type GroupsFetcher interface {
// FetchGroups queries for groups claims during login
FetchGroups(*jwtAuthBackend, map[string]interface{}, *jwtRole) (interface{}, error)
FetchGroups(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole) (interface{}, error)
}
15 changes: 8 additions & 7 deletions provider_config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwtauth

import (
"context"
"fmt"
"testing"

Expand All @@ -12,7 +13,7 @@ type testProviderConfig struct {
throwError bool
}

func (t *testProviderConfig) Initialize(jc *jwtConfig) error {
func (t *testProviderConfig) Initialize(_ context.Context, jc *jwtConfig) error {
if t.throwError {
return fmt.Errorf("i'm throwing an error")
}
Expand All @@ -37,7 +38,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.NoError(t, err)
assert.Equal(t, "yes", theProvider.(*testProviderConfig).initialized)

Expand All @@ -51,7 +52,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.NoError(t, err)
assert.Nil(t, theProvider)
})
Expand All @@ -66,7 +67,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "'provider' field not found in provider_config")
assert.Nil(t, theProvider)
})
Expand All @@ -82,7 +83,7 @@ func TestNewProviderConfig(t *testing.T) {
"not-test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "provider \"test\" not found in custom providers")
assert.Nil(t, theProvider)
})
Expand All @@ -97,7 +98,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "'provider' field not found in provider_config")
assert.Nil(t, theProvider)
})
Expand All @@ -113,7 +114,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{throwError: true},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "error initializing \"test\" provider_config: i'm throwing an error")
assert.Nil(t, theProvider)
})
Expand Down
33 changes: 14 additions & 19 deletions provider_gsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type GSuiteProviderConfig struct {
}

// Initialize initializes the GSuiteProvider by validating and creating configuration.
func (g *GSuiteProvider) Initialize(jc *jwtConfig) error {
func (g *GSuiteProvider) Initialize(ctx context.Context, jc *jwtConfig) error {
// Decode the provider config
var config GSuiteProviderConfig
if err := mapstructure.Decode(jc.ProviderConfig, &config); err != nil {
Expand All @@ -60,10 +60,10 @@ func (g *GSuiteProvider) Initialize(jc *jwtConfig) error {
}
config.serviceAccountKeyJSON = keyJSON

return g.initialize(config)
return g.initialize(ctx, config)
}

func (g *GSuiteProvider) initialize(config GSuiteProviderConfig) error {
func (g *GSuiteProvider) initialize(ctx context.Context, config GSuiteProviderConfig) error {
var err error

// Validate configuration
Expand All @@ -88,6 +88,13 @@ func (g *GSuiteProvider) initialize(config GSuiteProviderConfig) error {
// Set the subject to impersonate and config
g.jwtConfig.Subject = config.AdminImpersonateEmail
g.config = config

// Create a new admin service for requests to Google admin APIs
g.adminSvc, err = admin.NewService(ctx, option.WithHTTPClient(g.jwtConfig.Client(ctx)))
if err != nil {
return err
}

return nil
}

Expand All @@ -97,7 +104,7 @@ func (g *GSuiteProvider) SensitiveKeys() []string {
}

// FetchGroups fetches and returns groups from G Suite.
func (g *GSuiteProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
func (g *GSuiteProvider) FetchGroups(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
if !g.config.FetchGroups {
return nil, nil
}
Expand All @@ -107,15 +114,9 @@ func (g *GSuiteProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]int
return nil, err
}

// Set context and create a new admin service for requests to Google admin APIs
g.adminSvc, err = admin.NewService(b.providerCtx, option.WithHTTPClient(g.jwtConfig.Client(b.providerCtx)))
if err != nil {
return nil, err
}

// Get the G Suite groups
userGroupsMap := make(map[string]bool)
if err := g.search(b.providerCtx, userGroupsMap, userName, g.config.GroupsRecurseMaxDepth); err != nil {
if err := g.search(ctx, userGroupsMap, userName, g.config.GroupsRecurseMaxDepth); err != nil {
return nil, err
}

Expand Down Expand Up @@ -157,7 +158,7 @@ func (g *GSuiteProvider) search(ctx context.Context, visited map[string]bool, us
}

// FetchUserInfo fetches additional user information from G Suite using custom schemas.
func (g *GSuiteProvider) FetchUserInfo(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) error {
func (g *GSuiteProvider) FetchUserInfo(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) error {
if !g.config.FetchUserInfo || g.config.UserCustomSchemas == "" {
if g.config.UserCustomSchemas != "" {
b.Logger().Warn(fmt.Sprintf("must set 'fetch_user_info=true' to fetch 'user_custom_schemas': %s", g.config.UserCustomSchemas))
Expand All @@ -171,13 +172,7 @@ func (g *GSuiteProvider) FetchUserInfo(b *jwtAuthBackend, allClaims map[string]i
return err
}

// Set context and create a new admin service for requests to Google admin APIs
g.adminSvc, err = admin.NewService(b.providerCtx, option.WithHTTPClient(g.jwtConfig.Client(b.providerCtx)))
if err != nil {
return err
}

return g.fillCustomSchemas(b.providerCtx, userName, allClaims)
return g.fillCustomSchemas(ctx, userName, allClaims)
}

// fillCustomSchemas fetches G Suite user information associated with the custom schemas
Expand Down
Loading

0 comments on commit c1176c2

Please sign in to comment.