Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Pass along raw claims in identity context (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Jun 16, 2022
1 parent 047a80f commit afb3383
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 6 deletions.
2 changes: 1 addition & 1 deletion flyteadmin/auth/authzserver/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}
scopes.Insert(auth.ScopeAll)
}

return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo), nil
return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil
}

// NewProvider creates a new OAuth2 Provider that is able to do OAuth 2-legged and 3-legged flows. It'll lookup
Expand Down
14 changes: 13 additions & 1 deletion flyteadmin/auth/identity_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ var (
emptyIdentityContext = IdentityContext{}
)

type claimsType = map[string]interface{}

// IdentityContext is an abstract entity to enclose the authenticated identity of the user/app. Both gRPC and HTTP
// servers have interceptors to set the IdentityContext on the context.Context.
// To retrieve the current IdentityContext call auth.IdentityContextFromContext(ctx).
Expand All @@ -25,6 +27,8 @@ type IdentityContext struct {
userInfo *service.UserInfoResponse
// Set to pointer just to keep this struct go-simple to support equal operator
scopes *sets.String
// Raw JWT token from the IDP. Set to a pointer to support the equal operator for this struct.
claims *claimsType
}

func (c IdentityContext) Audience() string {
Expand Down Expand Up @@ -59,6 +63,13 @@ func (c IdentityContext) Scopes() sets.String {
return sets.NewString()
}

func (c IdentityContext) Claims() map[string]interface{} {
if c.claims != nil {
return *c.claims
}
return make(map[string]interface{})
}

func (c IdentityContext) WithContext(ctx context.Context) context.Context {
return context.WithValue(ctx, ContextKeyIdentityContext, c)
}
Expand All @@ -68,7 +79,7 @@ func (c IdentityContext) AuthenticatedAt() time.Time {
}

// NewIdentityContext creates a new IdentityContext.
func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse) IdentityContext {
func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse, claims map[string]interface{}) IdentityContext {
// For some reason, google IdP returns a subject in the ID Token but an empty subject in the /user_info endpoint
if userInfo == nil {
userInfo = &service.UserInfoResponse{}
Expand All @@ -85,6 +96,7 @@ func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Tim
userInfo: userInfo,
authenticatedAt: authenticatedAt,
scopes: &scopes,
claims: &claims,
}
}

Expand Down
21 changes: 21 additions & 0 deletions flyteadmin/auth/identity_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package auth

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestGetClaims(t *testing.T) {
noClaims := map[string]interface{}(nil)
noClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, nil)
assert.EqualValues(t, noClaims, noClaimsCtx.Claims())

claims := map[string]interface{}{
"groups": []string{"g1", "g2"},
"something": "else",
}
withClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, claims)
assert.EqualValues(t, claims, withClaimsCtx.Claims())
}
2 changes: 2 additions & 0 deletions flyteadmin/auth/interfaces/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ type IdentityContext interface {
UserInfo() *service.UserInfoResponse
AuthenticatedAt() time.Time
Scopes() sets.String
// Returns the full set of claims in the JWT token provided by the IDP.
Claims() map[string]interface{}

IsEmpty() bool
WithContext(ctx context.Context) context.Context
Expand Down
34 changes: 34 additions & 0 deletions flyteadmin/auth/interfaces/mocks/identity_context.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions flyteadmin/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ func ParseIDTokenAndValidate(ctx context.Context, clientID, rawIDToken string, p

return idToken, flyteErr
}

return idToken, nil
}

Expand Down Expand Up @@ -130,8 +129,12 @@ func IdentityContextFromIDTokenToken(ctx context.Context, tokenStr, clientID str
if err != nil {
return nil, err
}
var claims map[string]interface{}
if err := idToken.Claims(&claims); err != nil {
logger.Infof(ctx, "Failed to unmarshal claims from id token, err: %v", err)
}

// TODO: Document why automatically specify "all" scope
return NewIdentityContext(idToken.Audience[0], idToken.Subject, "", idToken.IssuedAt,
sets.NewString(ScopeAll), userInfo), nil
sets.NewString(ScopeAll), userInfo, claims), nil
}
4 changes: 2 additions & 2 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ func TestCreateExecution(t *testing.T) {
request.Spec.RawOutputDataConfig = &admin.RawOutputDataConfig{OutputLocationPrefix: rawOutput}
request.Spec.ClusterAssignment = &clusterAssignment

identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil)
identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
ctx := identity.WithContext(context.Background())
response, err := execManager.CreateExecution(ctx, request, requestedAt)
assert.Nil(t, err)
Expand Down Expand Up @@ -2834,7 +2834,7 @@ func TestTerminateExecution(t *testing.T) {
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})

identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil)
identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
ctx := identity.WithContext(context.Background())
resp, err := execManager.TerminateExecution(ctx, admin.ExecutionTerminateRequest{
Id: &core.WorkflowExecutionIdentifier{
Expand Down

0 comments on commit afb3383

Please sign in to comment.