Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes bound_claims validation for provider-specific group and user info fetching #149

Merged
merged 1 commit into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
}

// Validate provider_config
if _, err := NewProviderConfig(config, ProviderMap()); err != nil {
if _, err := NewProviderConfig(ctx, config, ProviderMap()); err != nil {
return logical.ErrorResponse("invalid provider_config: %s", err), nil
}

Expand Down
43 changes: 23 additions & 20 deletions path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ func (b *jwtAuthBackend) pathLogin(ctx context.Context, req *logical.Request, d
return nil, errors.New("unhandled case during login")
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

alias, groupAliases, err := b.createIdentity(allClaims, role)
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

tokenMetadata := map[string]string{"role": roleName}
for k, v := range alias.Metadata {
tokenMetadata[k] = v
Expand Down Expand Up @@ -308,7 +308,7 @@ func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig,

// createIdentity creates an alias and set of groups aliases based on the role
// definition and received claims.
func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
func (b *jwtAuthBackend) createIdentity(ctx context.Context, allClaims map[string]interface{}, role *jwtRole) (*logical.Alias, []*logical.Alias, error) {
userClaimRaw, ok := allClaims[role.UserClaim]
if !ok {
return nil, nil, fmt.Errorf("claim %q not found in token", role.UserClaim)
Expand All @@ -318,8 +318,12 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
return nil, nil, fmt.Errorf("claim %q could not be converted to string", role.UserClaim)
}

err := b.fetchUserInfo(allClaims, role)
pConfig, err := NewProviderConfig(ctx, b.cachedConfig, ProviderMap())
if err != nil {
return nil, nil, fmt.Errorf("failed to load custom provider config: %s", err)
}

if err := b.fetchUserInfo(ctx, pConfig, allClaims, role); err != nil {
return nil, nil, err
}

Expand All @@ -339,7 +343,7 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
return alias, groupAliases, nil
}

groupsClaimRaw, err := b.fetchGroups(allClaims, role)
groupsClaimRaw, err := b.fetchGroups(ctx, pConfig, allClaims, role)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch groups: %s", err)
}
Expand All @@ -366,32 +370,31 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
}

// Checks if there's a custom provider_config and calls FetchUserInfo() if implemented.
func (b *jwtAuthBackend) fetchUserInfo(allClaims map[string]interface{}, role *jwtRole) error {
pConfig, err := NewProviderConfig(b.cachedConfig, ProviderMap())
if err != nil {
return fmt.Errorf("failed to load custom provider config: %s", err)
}
func (b *jwtAuthBackend) fetchUserInfo(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole) error {
// Fetch user info from custom provider if it's implemented
if pConfig != nil {
if uif, ok := pConfig.(UserInfoFetcher); ok {
return uif.FetchUserInfo(b, allClaims, role)
return uif.FetchUserInfo(ctx, b, allClaims, role)
}
}

return nil
}

// Checks if there's a custom provider_config and calls FetchGroups() if implemented
func (b *jwtAuthBackend) fetchGroups(allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
pConfig, err := NewProviderConfig(b.cachedConfig, ProviderMap())
if err != nil {
return nil, fmt.Errorf("failed to load custom provider config: %s", err)
}
func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
// If the custom provider implements interface GroupsFetcher, call it,
// otherwise fall through to the default method
if pConfig != nil {
if gf, ok := pConfig.(GroupsFetcher); ok {
return gf.FetchGroups(b, allClaims, role)
groupsRaw, err := gf.FetchGroups(ctx, b, allClaims, role)
if err != nil {
return nil, err
}

// Add groups obtained by provider-specific fetching to the claims
// so that they can be used for bound_claims validation on the role.
allClaims["groups"] = groupsRaw
}
}
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)
Expand Down
10 changes: 5 additions & 5 deletions path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
}
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

alias, groupAliases, err := b.createIdentity(allClaims, role)
alias, groupAliases, err := b.createIdentity(ctx, allClaims, role)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}

if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil {
return logical.ErrorResponse("error validating claims: %s", err.Error()), nil
}

tokenMetadata := map[string]string{"role": roleName}
for k, v := range alias.Metadata {
tokenMetadata[k] = v
Expand Down
4 changes: 2 additions & 2 deletions provider_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type AzureProvider struct {
}

// Initialize anything in the AzureProvider struct - satisfying the CustomProvider interface
func (a *AzureProvider) Initialize(jc *jwtConfig) error {
func (a *AzureProvider) Initialize(_ context.Context, _ *jwtConfig) error {
return nil
}

Expand All @@ -45,7 +45,7 @@ func (a *AzureProvider) SensitiveKeys() []string {
}

// FetchGroups - custom groups fetching for azure - satisfying GroupsFetcher interface
func (a *AzureProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
func (a *AzureProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw == nil {
Expand Down
12 changes: 10 additions & 2 deletions provider_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func TestLogin_fetchGroups(t *testing.T) {
require.NoError(t, err)

b, storage := getBackend(t)
ctx := context.Background()

data := map[string]interface{}{
"oidc_discovery_url": aServer.server.URL,
Expand Down Expand Up @@ -156,12 +157,19 @@ func TestLogin_fetchGroups(t *testing.T) {
}

// Ensure b.cachedConfig is populated
_, err = b.(*jwtAuthBackend).config(context.Background(), storage)
config, err := b.(*jwtAuthBackend).config(ctx, storage)
if err != nil {
t.Fatal(err)
}

groupsResp, err := b.(*jwtAuthBackend).fetchGroups(allClaims, role)
// Initialize the azure provider
provider, err := NewProviderConfig(ctx, config, ProviderMap())
if err != nil {
t.Fatal(err)
}

// Ensure groups are as expected
groupsResp, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role)
assert.NoError(t, err)
assert.Equal(t, []interface{}{"group1", "group2"}, groupsResp)
}
Expand Down
11 changes: 6 additions & 5 deletions provider_config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwtauth

import (
"context"
"fmt"
)

Expand All @@ -21,7 +22,7 @@ type CustomProvider interface {
// Initialize should validate jwtConfig.ProviderConfig, set internal values
// and run any initialization necessary for subsequent calls to interface
// functions the provider implements
Initialize(*jwtConfig) error
Initialize(context.Context, *jwtConfig) error

// SensitiveKeys returns any fields in a provider's jwtConfig.ProviderConfig
// that should be masked or omitted when output
Expand All @@ -31,7 +32,7 @@ type CustomProvider interface {
// NewProviderConfig - returns appropriate provider struct if provider_config is
// specified in jwtConfig. The provider map is provider name -to- instance of a
// CustomProvider.
func NewProviderConfig(jc *jwtConfig, providerMap map[string]CustomProvider) (CustomProvider, error) {
func NewProviderConfig(ctx context.Context, jc *jwtConfig, providerMap map[string]CustomProvider) (CustomProvider, error) {
if len(jc.ProviderConfig) == 0 {
return nil, nil
}
Expand All @@ -43,19 +44,19 @@ func NewProviderConfig(jc *jwtConfig, providerMap map[string]CustomProvider) (Cu
if !ok {
return nil, fmt.Errorf("provider %q not found in custom providers", provider)
}
if err := newCustomProvider.Initialize(jc); err != nil {
if err := newCustomProvider.Initialize(ctx, jc); err != nil {
return nil, fmt.Errorf("error initializing %q provider_config: %s", provider, err)
}
return newCustomProvider, nil
}

// UserInfoFetcher - Optional support for custom user info handling
type UserInfoFetcher interface {
FetchUserInfo(*jwtAuthBackend, map[string]interface{}, *jwtRole) error
FetchUserInfo(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole) error
}

// GroupsFetcher - Optional support for custom groups handling
type GroupsFetcher interface {
// FetchGroups queries for groups claims during login
FetchGroups(*jwtAuthBackend, map[string]interface{}, *jwtRole) (interface{}, error)
FetchGroups(context.Context, *jwtAuthBackend, map[string]interface{}, *jwtRole) (interface{}, error)
}
15 changes: 8 additions & 7 deletions provider_config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwtauth

import (
"context"
"fmt"
"testing"

Expand All @@ -12,7 +13,7 @@ type testProviderConfig struct {
throwError bool
}

func (t *testProviderConfig) Initialize(jc *jwtConfig) error {
func (t *testProviderConfig) Initialize(_ context.Context, jc *jwtConfig) error {
calvn marked this conversation as resolved.
Show resolved Hide resolved
if t.throwError {
return fmt.Errorf("i'm throwing an error")
}
Expand All @@ -37,7 +38,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.NoError(t, err)
assert.Equal(t, "yes", theProvider.(*testProviderConfig).initialized)

Expand All @@ -51,7 +52,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.NoError(t, err)
assert.Nil(t, theProvider)
})
Expand All @@ -66,7 +67,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "'provider' field not found in provider_config")
assert.Nil(t, theProvider)
})
Expand All @@ -82,7 +83,7 @@ func TestNewProviderConfig(t *testing.T) {
"not-test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "provider \"test\" not found in custom providers")
assert.Nil(t, theProvider)
})
Expand All @@ -97,7 +98,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "'provider' field not found in provider_config")
assert.Nil(t, theProvider)
})
Expand All @@ -113,7 +114,7 @@ func TestNewProviderConfig(t *testing.T) {
"test": &testProviderConfig{throwError: true},
}

theProvider, err := NewProviderConfig(jc, pMap)
theProvider, err := NewProviderConfig(context.Background(), jc, pMap)
assert.EqualError(t, err, "error initializing \"test\" provider_config: i'm throwing an error")
assert.Nil(t, theProvider)
})
Expand Down
33 changes: 14 additions & 19 deletions provider_gsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type GSuiteProviderConfig struct {
}

// Initialize initializes the GSuiteProvider by validating and creating configuration.
func (g *GSuiteProvider) Initialize(jc *jwtConfig) error {
func (g *GSuiteProvider) Initialize(ctx context.Context, jc *jwtConfig) error {
// Decode the provider config
var config GSuiteProviderConfig
if err := mapstructure.Decode(jc.ProviderConfig, &config); err != nil {
Expand All @@ -60,10 +60,10 @@ func (g *GSuiteProvider) Initialize(jc *jwtConfig) error {
}
config.serviceAccountKeyJSON = keyJSON

return g.initialize(config)
return g.initialize(ctx, config)
}

func (g *GSuiteProvider) initialize(config GSuiteProviderConfig) error {
func (g *GSuiteProvider) initialize(ctx context.Context, config GSuiteProviderConfig) error {
var err error

// Validate configuration
Expand All @@ -88,6 +88,13 @@ func (g *GSuiteProvider) initialize(config GSuiteProviderConfig) error {
// Set the subject to impersonate and config
g.jwtConfig.Subject = config.AdminImpersonateEmail
g.config = config

// Create a new admin service for requests to Google admin APIs
g.adminSvc, err = admin.NewService(ctx, option.WithHTTPClient(g.jwtConfig.Client(ctx)))
if err != nil {
return err
}

return nil
}

Expand All @@ -97,7 +104,7 @@ func (g *GSuiteProvider) SensitiveKeys() []string {
}

// FetchGroups fetches and returns groups from G Suite.
func (g *GSuiteProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
func (g *GSuiteProvider) FetchGroups(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
if !g.config.FetchGroups {
return nil, nil
}
Expand All @@ -107,15 +114,9 @@ func (g *GSuiteProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]int
return nil, err
}

// Set context and create a new admin service for requests to Google admin APIs
g.adminSvc, err = admin.NewService(b.providerCtx, option.WithHTTPClient(g.jwtConfig.Client(b.providerCtx)))
if err != nil {
return nil, err
}

// Get the G Suite groups
userGroupsMap := make(map[string]bool)
if err := g.search(b.providerCtx, userGroupsMap, userName, g.config.GroupsRecurseMaxDepth); err != nil {
if err := g.search(ctx, userGroupsMap, userName, g.config.GroupsRecurseMaxDepth); err != nil {
return nil, err
}

Expand Down Expand Up @@ -157,7 +158,7 @@ func (g *GSuiteProvider) search(ctx context.Context, visited map[string]bool, us
}

// FetchUserInfo fetches additional user information from G Suite using custom schemas.
func (g *GSuiteProvider) FetchUserInfo(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) error {
func (g *GSuiteProvider) FetchUserInfo(ctx context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) error {
if !g.config.FetchUserInfo || g.config.UserCustomSchemas == "" {
if g.config.UserCustomSchemas != "" {
b.Logger().Warn(fmt.Sprintf("must set 'fetch_user_info=true' to fetch 'user_custom_schemas': %s", g.config.UserCustomSchemas))
Expand All @@ -171,13 +172,7 @@ func (g *GSuiteProvider) FetchUserInfo(b *jwtAuthBackend, allClaims map[string]i
return err
}

// Set context and create a new admin service for requests to Google admin APIs
g.adminSvc, err = admin.NewService(b.providerCtx, option.WithHTTPClient(g.jwtConfig.Client(b.providerCtx)))
if err != nil {
return err
}

return g.fillCustomSchemas(b.providerCtx, userName, allClaims)
return g.fillCustomSchemas(ctx, userName, allClaims)
}

// fillCustomSchemas fetches G Suite user information associated with the custom schemas
Expand Down
Loading