Skip to content

Commit

Permalink
Inject user identifier to ExecutionSpec (flyteorg#549)
Browse files Browse the repository at this point in the history
Signed-off-by: byhsu <[email protected]>
This pr provides a middleware to inject user identifier to ExecutionSpec.
By default, the value of the user identifier is userid from access/id token.
Users can customize their own middleware and inject different values.
  • Loading branch information
ByronHsu authored May 15, 2023
1 parent 2fdd399 commit 610451d
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 7 deletions.
15 changes: 15 additions & 0 deletions auth/identity_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ type IdentityContext struct {
scopes *sets.String
// Raw JWT token from the IDP. Set to a pointer to support the equal operator for this struct.
claims *claimsType
// executionIdentity stores a unique string that can be used to identify the user associated with a given task.
// This identifier is passed down to the ExecutionSpec and can be used for various purposes, such as setting the user identifier on a pod label.
// By default, the execution user identifier is filled with the value of IdentityContext.userID. However, you can customize your middleware to assign other values if needed.
// Providing a user identifier can be useful for tracking tasks and associating them with specific users, especially in multi-user environments.
executionIdentity string
}

func (c IdentityContext) Audience() string {
Expand Down Expand Up @@ -81,6 +86,16 @@ func (c IdentityContext) AuthenticatedAt() time.Time {
return c.authenticatedAt
}

func (c IdentityContext) ExecutionIdentity() string {
return c.executionIdentity
}

// WithExecutionUserIdentifier creates a copy of the original identity context and attach ExecutionIdentity
func (c IdentityContext) WithExecutionUserIdentifier(euid string) IdentityContext {
c.executionIdentity = euid
return c
}

// NewIdentityContext creates a new IdentityContext.
func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse, claims map[string]interface{}) (
IdentityContext, error) {
Expand Down
10 changes: 10 additions & 0 deletions auth/identity_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/util/sets"
)

func TestGetClaims(t *testing.T) {
Expand All @@ -23,3 +24,12 @@ func TestGetClaims(t *testing.T) {

assert.NotEmpty(t, withClaimsCtx.UserInfo().AdditionalClaims)
}

func TestWithExecutionUserIdentifier(t *testing.T) {
idctx, err := NewIdentityContext("", "", "", time.Now(), sets.String{}, nil, nil)
assert.NoError(t, err)
newIDCtx := idctx.WithExecutionUserIdentifier("byhsu")
// make sure the original one is intact
assert.Equal(t, "", idctx.ExecutionIdentity())
assert.Equal(t, "byhsu", newIDCtx.ExecutionIdentity())
}
9 changes: 9 additions & 0 deletions auth/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,12 @@ func BlanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnarySer

return handler(ctx, req)
}

// ExecutionUserIdentifierInterceptor injects identityContext.UserID() to identityContext.executionIdentity
func ExecutionUserIdentifierInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
resp interface{}, err error) {
identityContext := IdentityContextFromContext(ctx)
identityContext = identityContext.WithExecutionUserIdentifier(identityContext.UserID())
ctx = identityContext.WithContext(ctx)
return handler(ctx, req)
}
18 changes: 18 additions & 0 deletions auth/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,21 @@ func TestBlanketAuthorization(t *testing.T) {
assert.False(t, handlerCalled)
})
}

func TestGetUserIdentityFromContext(t *testing.T) {
identityContext := IdentityContext{
userID: "yeee",
}

ctx := identityContext.WithContext(context.Background())

handler := func(ctx context.Context, req interface{}) (interface{}, error) {
identityContext := IdentityContextFromContext(ctx)
euid := identityContext.ExecutionIdentity()
assert.Equal(t, euid, "yeee")
return nil, nil
}

_, err := ExecutionUserIdentifierInterceptor(ctx, nil, nil, handler)
assert.NoError(t, err)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/cloudevents/sdk-go/v2 v2.8.0
github.com/coreos/go-oidc v2.2.1+incompatible
github.com/evanphx/json-patch v4.12.0+incompatible
github.com/flyteorg/flyteidl v1.5.3
github.com/flyteorg/flyteidl v1.5.5
github.com/flyteorg/flyteplugins v1.0.56
github.com/flyteorg/flytepropeller v1.1.87
github.com/flyteorg/flytestdlib v1.0.15
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flyteorg/flyteidl v1.5.3 h1:qHyU9kvcxGIkXoloi768ayx9FHrs961dZC3WYziGGZA=
github.com/flyteorg/flyteidl v1.5.3/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I=
github.com/flyteorg/flyteidl v1.5.5 h1:tNNhuXPog4atAMSGE2kyAg6JzYy1TvjqrrQeh1EZVHs=
github.com/flyteorg/flyteidl v1.5.5/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flyteplugins v1.0.56 h1:kBTDgTpdSi7wcptk4cMwz5vfh1MU82VaUMMboe1InXw=
github.com/flyteorg/flyteplugins v1.0.56/go.mod h1:aFCKSn8TPzxSAILIiogHtUnHlUCN9+y6Vf+r9R4KZDU=
github.com/flyteorg/flytepropeller v1.1.87 h1:Px7ASDjrWyeVrUb15qXmhw9QK7xPcFjL5Yetr2P6iGM=
Expand Down
15 changes: 14 additions & 1 deletion pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,18 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi
RunAs: &core.Identity{},
}
}

if workflowExecConfig.GetSecurityContext().GetRunAs() == nil {
workflowExecConfig.SecurityContext.RunAs = &core.Identity{}
}

// In the case of reference_launch_plan subworkflow, the context comes from flytepropeller instead of the user side, so user auth is missing.
// We skip getUserIdentityFromContext but can still get ExecUserId because flytepropeller passes it in the execution request.
// https://github.com/flyteorg/flytepropeller/blob/03a6672960ed04e7687ba4f790fee9a02a4057fb/pkg/controller/nodes/subworkflow/launchplan/admin.go#L114
if workflowExecConfig.GetSecurityContext().GetRunAs().GetExecutionIdentity() == "" {
workflowExecConfig.SecurityContext.RunAs.ExecutionIdentity = auth.IdentityContextFromContext(ctx).ExecutionIdentity()
}

logger.Infof(ctx, "getting the workflow execution config from application configuration")
// Defaults to one from the application config
return &workflowExecConfig, nil
Expand Down Expand Up @@ -676,7 +688,8 @@ func resolveSecurityCtx(ctx context.Context, executionConfigSecurityCtx *core.Se
// Use security context from the executionConfigSecurityCtx if its set and non empty or else resolve from authRole
if executionConfigSecurityCtx != nil && executionConfigSecurityCtx.RunAs != nil &&
(len(executionConfigSecurityCtx.RunAs.K8SServiceAccount) > 0 ||
len(executionConfigSecurityCtx.RunAs.IamRole) > 0) {
len(executionConfigSecurityCtx.RunAs.IamRole) > 0 ||
len(executionConfigSecurityCtx.RunAs.ExecutionIdentity) > 0) {
return executionConfigSecurityCtx
}
logger.Warn(ctx, "Setting security context from auth Role")
Expand Down
7 changes: 6 additions & 1 deletion pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4540,7 +4540,11 @@ func TestGetExecutionConfigOverrides(t *testing.T) {
Envs: &admin.Envs{Values: requestEnvironmentVariables},
},
}
execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, nil)
identityContext, err := auth.NewIdentityContext("", "", "", time.Now(), sets.String{}, nil, nil)
assert.NoError(t, err)
identityContext = identityContext.WithExecutionUserIdentifier("yeee")
ctx := identityContext.WithContext(context.Background())
execConfig, err := executionManager.getExecutionConfig(ctx, request, nil)
assert.NoError(t, err)
assert.Equal(t, requestMaxParallelism, execConfig.MaxParallelism)
assert.Equal(t, requestK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount)
Expand All @@ -4549,6 +4553,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) {
assert.Equal(t, requestOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix)
assert.Equal(t, requestLabels, execConfig.GetLabels().Values)
assert.Equal(t, requestAnnotations, execConfig.GetAnnotations().Values)
assert.Equal(t, "yeee", execConfig.GetSecurityContext().GetRunAs().GetExecutionIdentity())
assert.Equal(t, requestEnvironmentVariables, execConfig.GetEnvs().Values)
})
t.Run("request with partial config", func(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion pkg/manager/impl/util/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec
if workflowExecConfig.GetSecurityContext() == nil && spec.GetSecurityContext() != nil {
if spec.GetSecurityContext().GetRunAs() != nil &&
(len(spec.GetSecurityContext().GetRunAs().GetK8SServiceAccount()) > 0 ||
len(spec.GetSecurityContext().GetRunAs().GetIamRole()) > 0) {
len(spec.GetSecurityContext().GetRunAs().GetIamRole()) > 0 ||
len(spec.GetSecurityContext().GetRunAs().GetExecutionIdentity()) > 0) {
workflowExecConfig.SecurityContext = spec.GetSecurityContext()
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) {

logger.Infof(ctx, "Registering default middleware with blanket auth validation")
pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization))
pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor))

// Not yet implemented for streaming
var chainedUnaryInterceptors grpc.UnaryServerInterceptor
Expand Down

0 comments on commit 610451d

Please sign in to comment.