Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Inject and Use values from Security Context (#153)
Browse files Browse the repository at this point in the history
* Inject and Use values from Security Context

Signed-off-by: Anand Swaminathan <[email protected]>
  • Loading branch information
anandswaminathan authored Mar 17, 2021
1 parent 5dfe694 commit 93ed208
Show file tree
Hide file tree
Showing 17 changed files with 204 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ type TaskExecutionMetadata interface {
GetMaxAttempts() uint32
GetAnnotations() map[string]string
GetK8sServiceAccount() string
GetSecurityContext() core.SecurityContext
IsInterruptible() bool
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package flytek8s

import (
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
pluginmachinery_core "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
v1 "k8s.io/api/core/v1"
)

Expand All @@ -12,3 +13,18 @@ func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar {
}
return envVars
}

func GetServiceAccountNameFromTaskExecutionMetadata(taskExecutionMetadata pluginmachinery_core.TaskExecutionMetadata) string {
var serviceAccount string
securityContext := taskExecutionMetadata.GetSecurityContext()
if securityContext.GetRunAs() != nil {
serviceAccount = securityContext.GetRunAs().GetK8SServiceAccount()
}

// TO BE DEPRECATED
if len(serviceAccount) == 0 {
serviceAccount = taskExecutionMetadata.GetK8sServiceAccount()
}

return serviceAccount
}
26 changes: 26 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,27 @@
package flytek8s

import (
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"

"github.com/stretchr/testify/assert"
)

func TestGetServiceAccountNameFromTaskExecutionMetadata(t *testing.T) {
mockTaskExecMetadata := mocks.TaskExecutionMetadata{}
mockTaskExecMetadata.OnGetSecurityContext().Return(core.SecurityContext{
RunAs: &core.Identity{K8SServiceAccount: "service-account"},
})
result := GetServiceAccountNameFromTaskExecutionMetadata(&mockTaskExecMetadata)
assert.Equal(t, "service-account", result)
}

func TestGetServiceAccountNameFromServiceAccount(t *testing.T) {
mockTaskExecMetadata := mocks.TaskExecutionMetadata{}
mockTaskExecMetadata.OnGetSecurityContext().Return(core.SecurityContext{})
mockTaskExecMetadata.OnGetK8sServiceAccount().Return("service-account")
result := GetServiceAccountNameFromTaskExecutionMetadata(&mockTaskExecMetadata)
assert.Equal(t, "service-account", result)
}
7 changes: 4 additions & 3 deletions flyteplugins/go/tasks/plugins/array/awsbatch/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ type Config struct {
// Provide additional environment variable pairs that plugin authors will provide to containers
DefaultEnvVars map[string]string `json:"defaultEnvVars" pflag:"-,Additional environment variable that should be injected into every resource"`
MaxErrorStringLength int `json:"maxErrLength" pflag:",Determines the maximum length of the error string returned for the array."`
RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."`
OutputAssembler workqueue.Config `json:"outputAssembler"`
ErrorAssembler workqueue.Config `json:"errorAssembler"`
// This can be deprecated. Just having it for backward compatibility
RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."`
OutputAssembler workqueue.Config `json:"outputAssembler"`
ErrorAssembler workqueue.Config `json:"errorAssembler"`
}

type JobStoreConfig struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte
return nil, errors.Errorf(pluginErrors.BadTaskSpecification, "Tasktemplate does not contain a container image.")
}

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata().GetAnnotations())
role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata())

cacheKey := definition.NewCacheKey(role, containerImage)
if existingArn, found := definitionCache.Get(cacheKey); found {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestEnsureJobDefinition(t *testing.T) {
tMeta.OnGetTaskExecutionID().Return(tID)
tMeta.OnGetOverrides().Return(overrides)
tMeta.OnGetAnnotations().Return(map[string]string{})

tMeta.OnGetSecurityContext().Return(core.SecurityContext{})
tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(tReader)
tCtx.OnTaskExecutionMetadata().Return(tMeta)
Expand Down Expand Up @@ -101,3 +101,70 @@ func TestEnsureJobDefinition(t *testing.T) {
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
})
}

func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {
ctx := context.Background()

tReader := &mocks.TaskReader{}
tReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{
Interface: &core.TypedInterface{
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{"var1": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}},
},
},
Target: &core.TaskTemplate_Container{
Container: createSampleContainerTask(),
},
}, nil)

overrides := &mocks.TaskOverrides{}
overrides.OnGetConfig().Return(&v1.ConfigMap{Data: map[string]string{
DynamicTaskQueueKey: "queue1",
}})

tID := &mocks.TaskExecutionID{}
tID.OnGetGeneratedName().Return("found")

tMeta := &mocks.TaskExecutionMetadata{}
tMeta.OnGetTaskExecutionID().Return(tID)
tMeta.OnGetOverrides().Return(overrides)
tMeta.OnGetAnnotations().Return(map[string]string{})
tMeta.OnGetSecurityContext().Return(core.SecurityContext{
RunAs: &core.Identity{IamRole: "new-role"},
})
tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(tReader)
tCtx.OnTaskExecutionMetadata().Return(tMeta)

cfg := &config.Config{}
batchClient := NewCustomBatchClient(batchMocks.NewMockAwsBatchClient(), "", "",
utils.NewRateLimiter("", 10, 20),
utils.NewRateLimiter("", 10, 20))

t.Run("Not Found", func(t *testing.T) {
dCache := definition.NewCache(10)

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})

assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "my-arn", nextState.JobDefinitionArn)
p, v := nextState.GetPhase()
assert.Equal(t, arrayCore.PhaseLaunch, p)
assert.Zero(t, v)
})

t.Run("Found", func(t *testing.T) {
dCache := definition.NewCache(10)
assert.NoError(t, dCache.Put(definition.NewCacheKey("new-role", "img1"), "their-arn"))

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})
assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
})
}
24 changes: 20 additions & 4 deletions flyteplugins/go/tasks/plugins/awsutils/awsutils.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
package awsutils

import "context"
import (
core2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
)

func GetRole(_ context.Context, roleAnnotationKey string, annotations map[string]string) string {
if len(roleAnnotationKey) > 0 {
return annotations[roleAnnotationKey]
func GetRoleFromSecurityContext(roleKey string, taskExecutionMetadata core2.TaskExecutionMetadata) string {
var role string
securityContext := taskExecutionMetadata.GetSecurityContext()
if securityContext.GetRunAs() != nil {
role = securityContext.GetRunAs().GetIamRole()
}

// Continue this for backward compatibility
if len(role) == 0 {
role = getRole(roleKey, taskExecutionMetadata.GetAnnotations())
}
return role
}

func getRole(roleKey string, keyValueMap map[string]string) string {
if len(roleKey) > 0 {
return keyValueMap[roleKey]
}

return ""
Expand Down
3 changes: 1 addition & 2 deletions flyteplugins/go/tasks/plugins/k8s/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ func (Plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecuti

pod := flytek8s.BuildPodWithSpec(podSpec)

// We want to Also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

return pod, nil
}
Expand Down
3 changes: 3 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/container/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.
Name: "blah",
})
taskMetadata.On("GetK8sServiceAccount").Return("service-account")
taskMetadata.On("GetSecurityContext").Return(core.SecurityContext{
RunAs: &core.Identity{K8SServiceAccount: "service-account"},
})
taskMetadata.On("GetOwnerID").Return(types.NamespacedName{
Namespace: "test-namespace",
Name: "test-owner-name",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ func (m awsSagemakerPlugin) buildResourceForTrainingJob(

inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata())

if len(role) == 0 {
role = cfg.RoleArn
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob(
tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String()))
trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata())

if len(role) == 0 {
role = cfg.RoleArn
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ func generateMockCustomTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTem
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"})
taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{
RunAs: &flyteIdlCore.Identity{IamRole: "new-role"},
})

taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Expand Down Expand Up @@ -270,6 +274,7 @@ func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate,
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"})
taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Expand Down Expand Up @@ -353,6 +358,7 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T
outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb"))
outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/"))
outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("/raw/"))

taskCtx.OnOutputWriter().Return(outputReader)

taskReader := &mocks.TaskReader{}
Expand Down Expand Up @@ -384,6 +390,9 @@ func genMockTaskExecutionMetadata() *mocks.TaskExecutionMetadata {
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"})
taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{
RunAs: &flyteIdlCore.Identity{IamRole: "default_role"},
})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Expand Down
3 changes: 1 addition & 2 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins
// CrashLoopBackoff after the initial job completion.
pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever

// We want to also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

pod, err = validateAndFinalizePod(ctx, taskCtx, primaryContainerName, *pod)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.
taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{}
taskMetadata.On("GetNamespace").Return("test-namespace")
taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"})

taskMetadata.On("GetLabels").Return(map[string]string{"label-1": "val1"})
taskMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{
Kind: "node",
Name: "blah",
})
taskMetadata.On("IsInterruptible").Return(true)
taskMetadata.On("GetSecurityContext").Return(core.SecurityContext{})
taskMetadata.On("GetK8sServiceAccount").Return("service-account")
taskMetadata.On("GetOwnerID").Return(types.NamespacedName{
Namespace: "test-namespace",
Expand Down Expand Up @@ -319,6 +321,7 @@ func TestBuildSidecarResource(t *testing.T) {
assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1)
assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name)

assert.Equal(t, "service-account", res.(*v1.Pod).Spec.ServiceAccountName)
// Assert user-specified tolerations don't get overridden
assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 1)
for _, tol := range res.(*v1.Pod).Spec.Tolerations {
Expand Down
9 changes: 7 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,19 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

if len(serviceAccountName) == 0 {
serviceAccountName = sparkTaskType
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
},
ServiceAccount: &sparkTaskType,
ServiceAccount: &serviceAccountName,
}

executorSpec := sparkOp.ExecutorSpec{
Expand Down Expand Up @@ -184,7 +189,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
APIVersion: sparkOp.SchemeGroupVersion.String(),
},
Spec: sparkOp.SparkApplicationSpec{
ServiceAccount: &sparkTaskType,
ServiceAccount: &serviceAccountName,
Type: getApplicationType(sparkJob.GetApplicationType()),
Image: &container.Image,
Arguments: modifiedArgs,
Expand Down
Loading

0 comments on commit 93ed208

Please sign in to comment.