Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Inject and Use values from Security Context (#153)
Browse files Browse the repository at this point in the history
* Inject and Use values from Security Context

Signed-off-by: Anand Swaminathan <[email protected]>
  • Loading branch information
anandswaminathan authored Mar 17, 2021
1 parent 030cdef commit d0a6ee2
Show file tree
Hide file tree
Showing 17 changed files with 204 additions and 19 deletions.
1 change: 1 addition & 0 deletions go/tasks/pluginmachinery/core/exec_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ type TaskExecutionMetadata interface {
GetMaxAttempts() uint32
GetAnnotations() map[string]string
GetK8sServiceAccount() string
GetSecurityContext() core.SecurityContext
IsInterruptible() bool
}
34 changes: 34 additions & 0 deletions go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions go/tasks/pluginmachinery/flytek8s/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package flytek8s

import (
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
pluginmachinery_core "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
v1 "k8s.io/api/core/v1"
)

Expand All @@ -12,3 +13,18 @@ func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar {
}
return envVars
}

func GetServiceAccountNameFromTaskExecutionMetadata(taskExecutionMetadata pluginmachinery_core.TaskExecutionMetadata) string {
var serviceAccount string
securityContext := taskExecutionMetadata.GetSecurityContext()
if securityContext.GetRunAs() != nil {
serviceAccount = securityContext.GetRunAs().GetK8SServiceAccount()
}

// TO BE DEPRECATED
if len(serviceAccount) == 0 {
serviceAccount = taskExecutionMetadata.GetK8sServiceAccount()
}

return serviceAccount
}
26 changes: 26 additions & 0 deletions go/tasks/pluginmachinery/flytek8s/utils_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,27 @@
package flytek8s

import (
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"

"github.com/stretchr/testify/assert"
)

func TestGetServiceAccountNameFromTaskExecutionMetadata(t *testing.T) {
mockTaskExecMetadata := mocks.TaskExecutionMetadata{}
mockTaskExecMetadata.OnGetSecurityContext().Return(core.SecurityContext{
RunAs: &core.Identity{K8SServiceAccount: "service-account"},
})
result := GetServiceAccountNameFromTaskExecutionMetadata(&mockTaskExecMetadata)
assert.Equal(t, "service-account", result)
}

func TestGetServiceAccountNameFromServiceAccount(t *testing.T) {
mockTaskExecMetadata := mocks.TaskExecutionMetadata{}
mockTaskExecMetadata.OnGetSecurityContext().Return(core.SecurityContext{})
mockTaskExecMetadata.OnGetK8sServiceAccount().Return("service-account")
result := GetServiceAccountNameFromTaskExecutionMetadata(&mockTaskExecMetadata)
assert.Equal(t, "service-account", result)
}
7 changes: 4 additions & 3 deletions go/tasks/plugins/array/awsbatch/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ type Config struct {
// Provide additional environment variable pairs that plugin authors will provide to containers
DefaultEnvVars map[string]string `json:"defaultEnvVars" pflag:"-,Additional environment variable that should be injected into every resource"`
MaxErrorStringLength int `json:"maxErrLength" pflag:",Determines the maximum length of the error string returned for the array."`
RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."`
OutputAssembler workqueue.Config `json:"outputAssembler"`
ErrorAssembler workqueue.Config `json:"errorAssembler"`
// This can be deprecated. Just having it for backward compatibility
RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."`
OutputAssembler workqueue.Config `json:"outputAssembler"`
ErrorAssembler workqueue.Config `json:"errorAssembler"`
}

type JobStoreConfig struct {
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/awsbatch/job_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte
return nil, errors.Errorf(pluginErrors.BadTaskSpecification, "Tasktemplate does not contain a container image.")
}

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata().GetAnnotations())
role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata())

cacheKey := definition.NewCacheKey(role, containerImage)
if existingArn, found := definitionCache.Get(cacheKey); found {
Expand Down
69 changes: 68 additions & 1 deletion go/tasks/plugins/array/awsbatch/job_definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestEnsureJobDefinition(t *testing.T) {
tMeta.OnGetTaskExecutionID().Return(tID)
tMeta.OnGetOverrides().Return(overrides)
tMeta.OnGetAnnotations().Return(map[string]string{})

tMeta.OnGetSecurityContext().Return(core.SecurityContext{})
tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(tReader)
tCtx.OnTaskExecutionMetadata().Return(tMeta)
Expand Down Expand Up @@ -101,3 +101,70 @@ func TestEnsureJobDefinition(t *testing.T) {
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
})
}

func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {
ctx := context.Background()

tReader := &mocks.TaskReader{}
tReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{
Interface: &core.TypedInterface{
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{"var1": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}},
},
},
Target: &core.TaskTemplate_Container{
Container: createSampleContainerTask(),
},
}, nil)

overrides := &mocks.TaskOverrides{}
overrides.OnGetConfig().Return(&v1.ConfigMap{Data: map[string]string{
DynamicTaskQueueKey: "queue1",
}})

tID := &mocks.TaskExecutionID{}
tID.OnGetGeneratedName().Return("found")

tMeta := &mocks.TaskExecutionMetadata{}
tMeta.OnGetTaskExecutionID().Return(tID)
tMeta.OnGetOverrides().Return(overrides)
tMeta.OnGetAnnotations().Return(map[string]string{})
tMeta.OnGetSecurityContext().Return(core.SecurityContext{
RunAs: &core.Identity{IamRole: "new-role"},
})
tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(tReader)
tCtx.OnTaskExecutionMetadata().Return(tMeta)

cfg := &config.Config{}
batchClient := NewCustomBatchClient(batchMocks.NewMockAwsBatchClient(), "", "",
utils.NewRateLimiter("", 10, 20),
utils.NewRateLimiter("", 10, 20))

t.Run("Not Found", func(t *testing.T) {
dCache := definition.NewCache(10)

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})

assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "my-arn", nextState.JobDefinitionArn)
p, v := nextState.GetPhase()
assert.Equal(t, arrayCore.PhaseLaunch, p)
assert.Zero(t, v)
})

t.Run("Found", func(t *testing.T) {
dCache := definition.NewCache(10)
assert.NoError(t, dCache.Put(definition.NewCacheKey("new-role", "img1"), "their-arn"))

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})
assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
})
}
24 changes: 20 additions & 4 deletions go/tasks/plugins/awsutils/awsutils.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
package awsutils

import "context"
import (
core2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
)

func GetRole(_ context.Context, roleAnnotationKey string, annotations map[string]string) string {
if len(roleAnnotationKey) > 0 {
return annotations[roleAnnotationKey]
func GetRoleFromSecurityContext(roleKey string, taskExecutionMetadata core2.TaskExecutionMetadata) string {
var role string
securityContext := taskExecutionMetadata.GetSecurityContext()
if securityContext.GetRunAs() != nil {
role = securityContext.GetRunAs().GetIamRole()
}

// Continue this for backward compatibility
if len(role) == 0 {
role = getRole(roleKey, taskExecutionMetadata.GetAnnotations())
}
return role
}

func getRole(roleKey string, keyValueMap map[string]string) string {
if len(roleKey) > 0 {
return keyValueMap[roleKey]
}

return ""
Expand Down
3 changes: 1 addition & 2 deletions go/tasks/plugins/k8s/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ func (Plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecuti

pod := flytek8s.BuildPodWithSpec(podSpec)

// We want to Also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

return pod, nil
}
Expand Down
3 changes: 3 additions & 0 deletions go/tasks/plugins/k8s/container/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.
Name: "blah",
})
taskMetadata.On("GetK8sServiceAccount").Return("service-account")
taskMetadata.On("GetSecurityContext").Return(core.SecurityContext{
RunAs: &core.Identity{K8SServiceAccount: "service-account"},
})
taskMetadata.On("GetOwnerID").Return(types.NamespacedName{
Namespace: "test-namespace",
Name: "test-owner-name",
Expand Down
5 changes: 3 additions & 2 deletions go/tasks/plugins/k8s/sagemaker/builtin_training.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ func (m awsSagemakerPlugin) buildResourceForTrainingJob(

inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata())

if len(role) == 0 {
role = cfg.RoleArn
}

Expand Down
5 changes: 3 additions & 2 deletions go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob(
tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String()))
trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata())

if len(role) == 0 {
role = cfg.RoleArn
}

Expand Down
9 changes: 9 additions & 0 deletions go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ func generateMockCustomTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTem
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"})
taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{
RunAs: &flyteIdlCore.Identity{IamRole: "new-role"},
})

taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Expand Down Expand Up @@ -270,6 +274,7 @@ func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate,
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"})
taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Expand Down Expand Up @@ -353,6 +358,7 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T
outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb"))
outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/"))
outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("/raw/"))

taskCtx.OnOutputWriter().Return(outputReader)

taskReader := &mocks.TaskReader{}
Expand Down Expand Up @@ -384,6 +390,9 @@ func genMockTaskExecutionMetadata() *mocks.TaskExecutionMetadata {
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"})
taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{
RunAs: &flyteIdlCore.Identity{IamRole: "default_role"},
})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Expand Down
3 changes: 1 addition & 2 deletions go/tasks/plugins/k8s/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins
// CrashLoopBackoff after the initial job completion.
pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever

// We want to also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

pod, err = validateAndFinalizePod(ctx, taskCtx, primaryContainerName, *pod)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions go/tasks/plugins/k8s/sidecar/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.
taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{}
taskMetadata.On("GetNamespace").Return("test-namespace")
taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"})

taskMetadata.On("GetLabels").Return(map[string]string{"label-1": "val1"})
taskMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{
Kind: "node",
Name: "blah",
})
taskMetadata.On("IsInterruptible").Return(true)
taskMetadata.On("GetSecurityContext").Return(core.SecurityContext{})
taskMetadata.On("GetK8sServiceAccount").Return("service-account")
taskMetadata.On("GetOwnerID").Return(types.NamespacedName{
Namespace: "test-namespace",
Expand Down Expand Up @@ -319,6 +321,7 @@ func TestBuildSidecarResource(t *testing.T) {
assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1)
assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name)

assert.Equal(t, "service-account", res.(*v1.Pod).Spec.ServiceAccountName)
// Assert user-specified tolerations don't get overridden
assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 1)
for _, tol := range res.(*v1.Pod).Spec.Tolerations {
Expand Down
9 changes: 7 additions & 2 deletions go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,19 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

if len(serviceAccountName) == 0 {
serviceAccountName = sparkTaskType
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
},
ServiceAccount: &sparkTaskType,
ServiceAccount: &serviceAccountName,
}

executorSpec := sparkOp.ExecutorSpec{
Expand Down Expand Up @@ -184,7 +189,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
APIVersion: sparkOp.SchemeGroupVersion.String(),
},
Spec: sparkOp.SparkApplicationSpec{
ServiceAccount: &sparkTaskType,
ServiceAccount: &serviceAccountName,
Type: getApplicationType(sparkJob.GetApplicationType()),
Image: &container.Image,
Arguments: modifiedArgs,
Expand Down
Loading

0 comments on commit d0a6ee2

Please sign in to comment.