Skip to content

Commit

Permalink
Forward all claims in userinfo response (flyteorg#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Jan 10, 2023
1 parent 82935b9 commit 0e60fbb
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 18 deletions.
2 changes: 1 addition & 1 deletion flyteadmin/auth/authzserver/claims_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}
scopes.Insert(auth.ScopeAll)
}

return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil
return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw)
}
12 changes: 11 additions & 1 deletion flyteadmin/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,18 @@ func GetUserInfoForwardResponseHandler() UserInfoForwardResponseHandler {
return func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error {
info, ok := m.(*service.UserInfoResponse)
if ok {
if info.AdditionalClaims != nil {
for k, v := range info.AdditionalClaims.GetFields() {
jsonBytes, err := v.MarshalJSON()
if err != nil {
logger.Warningf(ctx, "failed to marshal claim [%s] to json: %v", k, err)
continue
}
header := fmt.Sprintf("X-User-Claim-%s", strings.ReplaceAll(k, "_", "-"))
w.Header().Set(header, string(jsonBytes))
}
}
w.Header().Set("X-User-Subject", info.Subject)
w.Header().Set("X-User-Name", info.Name)
}
return nil
}
Expand Down
20 changes: 14 additions & 6 deletions flyteadmin/auth/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strings"
"testing"

"google.golang.org/protobuf/types/known/structpb"

"github.com/flyteorg/flyteadmin/auth/config"
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
"github.com/flyteorg/flyteadmin/pkg/common"
Expand Down Expand Up @@ -300,20 +302,26 @@ func TestUserInfoForwardResponseHander(t *testing.T) {
ctx := context.Background()
handler := GetUserInfoForwardResponseHandler()
w := httptest.NewRecorder()
additionalClaims := map[string]interface{}{
"cid": "cid-id",
"ver": 1,
}
additionalClaimsStruct, err := structpb.NewStruct(additionalClaims)
assert.NoError(t, err)
resp := service.UserInfoResponse{
Subject: "user-id",
Name: "User Name",
Subject: "user-id",
AdditionalClaims: additionalClaimsStruct,
}
assert.NoError(t, handler(ctx, w, &resp))
assert.Contains(t, w.Result().Header, "X-User-Subject")
assert.Equal(t, w.Result().Header["X-User-Subject"], []string{"user-id"})

assert.Contains(t, w.Result().Header, "X-User-Name")
assert.Equal(t, w.Result().Header["X-User-Name"], []string{"User Name"})
assert.Contains(t, w.Result().Header, "X-User-Claim-Cid")
assert.Equal(t, w.Result().Header["X-User-Claim-Cid"], []string{"\"cid-id\""})
assert.Contains(t, w.Result().Header, "X-User-Claim-Ver")
assert.Equal(t, w.Result().Header["X-User-Claim-Ver"], []string{"1"})

w = httptest.NewRecorder()
unrelatedResp := service.OAuth2MetadataResponse{}
assert.NoError(t, handler(ctx, w, &unrelatedResp))
assert.NotContains(t, w.Result().Header, "X-User-Subject")
assert.NotContains(t, w.Result().Header, "X-User-Name")
}
16 changes: 14 additions & 2 deletions flyteadmin/auth/identity_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package auth

import (
"context"
"fmt"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"

"k8s.io/apimachinery/pkg/util/sets"
Expand Down Expand Up @@ -79,7 +82,8 @@ func (c IdentityContext) AuthenticatedAt() time.Time {
}

// NewIdentityContext creates a new IdentityContext.
func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse, claims map[string]interface{}) IdentityContext {
func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse, claims map[string]interface{}) (
IdentityContext, error) {
// For some reason, google IdP returns a subject in the ID Token but an empty subject in the /user_info endpoint
if userInfo == nil {
userInfo = &service.UserInfoResponse{}
Expand All @@ -89,6 +93,14 @@ func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Tim
userInfo.Subject = userID
}

if len(claims) > 0 {
claimsStruct, err := utils.MarshalObjToStruct(claims)
if err != nil {
return IdentityContext{}, fmt.Errorf("failed to marshal claims [%+v] to struct: %w", claims, err)
}
userInfo.AdditionalClaims = claimsStruct
}

return IdentityContext{
audience: audience,
userID: userID,
Expand All @@ -97,7 +109,7 @@ func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Tim
authenticatedAt: authenticatedAt,
scopes: &scopes,
claims: &claims,
}
}, nil
}

// IdentityContextFromContext retrieves the authenticated identity from context.Context.
Expand Down
8 changes: 6 additions & 2 deletions flyteadmin/auth/identity_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ import (

func TestGetClaims(t *testing.T) {
noClaims := map[string]interface{}(nil)
noClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, nil)
noClaimsCtx, err := NewIdentityContext("", "", "", time.Now(), nil, nil, nil)
assert.NoError(t, err)
assert.EqualValues(t, noClaims, noClaimsCtx.Claims())

claims := map[string]interface{}{
"groups": []string{"g1", "g2"},
"something": "else",
}
withClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, claims)
withClaimsCtx, err := NewIdentityContext("", "", "", time.Now(), nil, nil, claims)
assert.NoError(t, err)
assert.EqualValues(t, claims, withClaimsCtx.Claims())

assert.NotEmpty(t, withClaimsCtx.UserInfo().AdditionalClaims)
}
2 changes: 1 addition & 1 deletion flyteadmin/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,5 @@ func IdentityContextFromIDTokenToken(ctx context.Context, tokenStr, clientID str

// TODO: Document why automatically specify "all" scope
return NewIdentityContext(idToken.Audience[0], idToken.Subject, "", idToken.IssuedAt,
sets.NewString(ScopeAll), userInfo, claims), nil
sets.NewString(ScopeAll), userInfo, claims)
}
2 changes: 1 addition & 1 deletion flyteadmin/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/cloudevents/sdk-go/v2 v2.8.0
github.com/coreos/go-oidc v2.2.1+incompatible
github.com/evanphx/json-patch v4.12.0+incompatible
github.com/flyteorg/flyteidl v1.2.5
github.com/flyteorg/flyteidl v1.3.3
github.com/flyteorg/flyteplugins v1.0.20
github.com/flyteorg/flytepropeller v1.1.51
github.com/flyteorg/flytestdlib v1.0.14
Expand Down
4 changes: 2 additions & 2 deletions flyteadmin/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flyteorg/flyteidl v1.2.5 h1:oPs0PX9opR9JtWjP5ZH2YMChkbGGL45PIy+90FlaxYc=
github.com/flyteorg/flyteidl v1.2.5/go.mod h1:OJAq333OpInPnMhvVz93AlEjmlQ+t0FAD4aakIYE4OU=
github.com/flyteorg/flyteidl v1.3.3 h1:WOhTcFTr6r67gelUyxU9OH0SIJG7Pj+VtNFObj3/pnc=
github.com/flyteorg/flyteidl v1.3.3/go.mod h1:OJAq333OpInPnMhvVz93AlEjmlQ+t0FAD4aakIYE4OU=
github.com/flyteorg/flyteplugins v1.0.20 h1:8ZGN2c0iaZa3d/UmN2VYozLBRhthAIO48aD5g8Wly7s=
github.com/flyteorg/flyteplugins v1.0.20/go.mod h1:ZbZVBxEWh8Icj1AgfNKg0uPzHHGd9twa4eWcY2Yt6xE=
github.com/flyteorg/flytepropeller v1.1.51 h1:ITPH2Fqx+/1hKBFnfb6Rawws3VbEJ3tQ/1tQXSIXvcQ=
Expand Down
6 changes: 4 additions & 2 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ func TestCreateExecution(t *testing.T) {
request.Spec.RawOutputDataConfig = &admin.RawOutputDataConfig{OutputLocationPrefix: rawOutput}
request.Spec.ClusterAssignment = &clusterAssignment

identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
assert.NoError(t, err)
ctx := identity.WithContext(context.Background())
response, err := execManager.CreateExecution(ctx, request, requestedAt)
assert.Nil(t, err)
Expand Down Expand Up @@ -3122,7 +3123,8 @@ func TestTerminateExecution(t *testing.T) {
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})

identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
assert.NoError(t, err)
ctx := identity.WithContext(context.Background())
resp, err := execManager.TerminateExecution(ctx, admin.ExecutionTerminateRequest{
Id: &core.WorkflowExecutionIdentifier{
Expand Down

0 comments on commit 0e60fbb

Please sign in to comment.