Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Inject and Use values from Security Context
Browse files Browse the repository at this point in the history
  • Loading branch information
anandswaminathan committed Mar 2, 2021
1 parent e84585e commit e1bffe0
Show file tree
Hide file tree
Showing 20 changed files with 219 additions and 52 deletions.
2 changes: 1 addition & 1 deletion copilot/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/gogo/protobuf v1.3.1
github.com/golang/protobuf v1.4.2
github.com/imdario/mergo v0.3.9 // indirect
github.com/lyft/flyteidl v0.18.11
github.com/lyft/flyteidl v0.18.13
github.com/lyft/flyteplugins v0.4.4
github.com/lyft/flytestdlib v0.3.9
github.com/mitchellh/go-ps v1.0.0
Expand Down
1 change: 1 addition & 0 deletions copilot/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnz
github.com/lyft/flyteidl v0.18.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.18.11 h1:24NaFYWxANhRbwKfvkgu8axGTWUcl1tgZBqNJutKNJ8=
github.com/lyft/flyteidl v0.18.11/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.18.13/go.mod h1:JTJC2VqrpEWM/76lPF2Dj9l4FA8FjZh67U0RYuq/Aes=
github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU=
github.com/lyft/flytestdlib v0.3.9 h1:NaKp9xkeWWwhVvqTOcR/FqlASy1N2gu/kN7PVe4S7YI=
github.com/lyft/flytestdlib v0.3.9/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU=
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ require (
github.com/hashicorp/golang-lru v0.5.4
github.com/kubeflow/pytorch-operator v0.6.0
github.com/kubeflow/tf-operator v0.5.3
github.com/lyft/flyteidl v0.18.9
github.com/lyft/flyteidl v0.18.13
github.com/lyft/flytestdlib v0.3.9
github.com/magiconair/properties v1.8.1
github.com/mitchellh/mapstructure v1.1.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ github.com/lyft/flyteidl v0.18.9 h1:p9gLp92whTSSOeMGPtZ4tkgsVHNGuBuXXMQ447s0J9E=
github.com/lyft/flyteidl v0.18.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.18.11 h1:24NaFYWxANhRbwKfvkgu8axGTWUcl1tgZBqNJutKNJ8=
github.com/lyft/flyteidl v0.18.11/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.18.13 h1:YnANrIAZ9xTIuXkafDyIX6IaL8tkNsLOvUwdfFz9TKY=
github.com/lyft/flyteidl v0.18.13/go.mod h1:JTJC2VqrpEWM/76lPF2Dj9l4FA8FjZh67U0RYuq/Aes=
github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU=
github.com/lyft/flytestdlib v0.3.9 h1:NaKp9xkeWWwhVvqTOcR/FqlASy1N2gu/kN7PVe4S7YI=
github.com/lyft/flytestdlib v0.3.9/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU=
Expand Down
1 change: 1 addition & 0 deletions go/tasks/pluginmachinery/core/exec_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ type TaskExecutionMetadata interface {
GetMaxAttempts() uint32
GetAnnotations() map[string]string
GetK8sServiceAccount() string
GetSecurityContext() map[string]string
IsInterruptible() bool
}
34 changes: 34 additions & 0 deletions go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions go/tasks/pluginmachinery/flytek8s/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@ import (
v1 "k8s.io/api/core/v1"
)

var serviceAccountNameKey = "serviceAccountName"

func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar {
envVars := make([]v1.EnvVar, 0, len(env))
for _, kv := range env {
envVars = append(envVars, v1.EnvVar{Name: kv.Key, Value: kv.Value})
}
return envVars
}

func GetServiceAccountNameFromSecurityContext(securityContext map[string]string) string {
return securityContext[serviceAccountNameKey]
}
8 changes: 5 additions & 3 deletions go/tasks/plugins/array/awsbatch/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ type Config struct {
// Provide additional environment variable pairs that plugin authors will provide to containers
DefaultEnvVars map[string]string `json:"defaultEnvVars" pflag:"-,Additional environment variable that should be injected into every resource"`
MaxErrorStringLength int `json:"maxErrLength" pflag:",Determines the maximum length of the error string returned for the array."`
RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."`
OutputAssembler workqueue.Config `json:"outputAssembler"`
ErrorAssembler workqueue.Config `json:"errorAssembler"`
// This can be deprecated. Just having it for backward compatibility
RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."`
RoleSecurityContextKey string `json:"roleSecurityContextKey" pflag:",Map key to use to lookup role from security context."`
OutputAssembler workqueue.Config `json:"outputAssembler"`
ErrorAssembler workqueue.Config `json:"errorAssembler"`
}

type JobStoreConfig struct {
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/array/awsbatch/config/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions go/tasks/plugins/array/awsbatch/config/config_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion go/tasks/plugins/array/awsbatch/job_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte
return nil, errors.Errorf(pluginErrors.BadTaskSpecification, "Tasktemplate does not contain a container image.")
}

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata().GetAnnotations())
role := awsUtils.GetRole(ctx, cfg.RoleSecurityContextKey, tCtx.TaskExecutionMetadata().GetSecurityContext())

// Continue this for backward compatibility
if len(role) == 0 {
role = awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata().GetAnnotations())
}

cacheKey := definition.NewCacheKey(role, containerImage)
if existingArn, found := definitionCache.Get(cacheKey); found {
Expand Down
68 changes: 68 additions & 0 deletions go/tasks/plugins/array/awsbatch/job_definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,74 @@ func TestEnsureJobDefinition(t *testing.T) {
tMeta.OnGetTaskExecutionID().Return(tID)
tMeta.OnGetOverrides().Return(overrides)
tMeta.OnGetAnnotations().Return(map[string]string{})
tMeta.OnGetSecurityContext().Return(map[string]string{})

tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(tReader)
tCtx.OnTaskExecutionMetadata().Return(tMeta)

cfg := &config.Config{}
batchClient := NewCustomBatchClient(batchMocks.NewMockAwsBatchClient(), "", "",
utils.NewRateLimiter("", 10, 20),
utils.NewRateLimiter("", 10, 20))

t.Run("Not Found", func(t *testing.T) {
dCache := definition.NewCache(10)

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})

assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "my-arn", nextState.JobDefinitionArn)
p, v := nextState.GetPhase()
assert.Equal(t, arrayCore.PhaseLaunch, p)
assert.Zero(t, v)
})

t.Run("Found", func(t *testing.T) {
dCache := definition.NewCache(10)
assert.NoError(t, dCache.Put(definition.NewCacheKey("", "img1"), "their-arn"))

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})
assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
})
}

func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {
ctx := context.Background()

tReader := &mocks.TaskReader{}
tReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{
Interface: &core.TypedInterface{
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{"var1": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}},
},
},
Target: &core.TaskTemplate_Container{
Container: createSampleContainerTask(),
},
}, nil)

overrides := &mocks.TaskOverrides{}
overrides.OnGetConfig().Return(&v1.ConfigMap{Data: map[string]string{
DynamicTaskQueueKey: "queue1",
"roleSecurityContextKey": "iam/role",
}})

tID := &mocks.TaskExecutionID{}
tID.OnGetGeneratedName().Return("found")

tMeta := &mocks.TaskExecutionMetadata{}
tMeta.OnGetTaskExecutionID().Return(tID)
tMeta.OnGetOverrides().Return(overrides)
tMeta.OnGetAnnotations().Return(map[string]string{})
tMeta.OnGetSecurityContext().Return(map[string]string{"iam/role": "new-role"})

tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(tReader)
Expand Down
11 changes: 9 additions & 2 deletions go/tasks/plugins/k8s/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,15 @@ func (Plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecuti

pod := flytek8s.BuildPodWithSpec(podSpec)

// We want to Also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
serviceAccountName := flytek8s.GetServiceAccountNameFromSecurityContext(taskCtx.TaskExecutionMetadata().GetSecurityContext())

// TO BE DEPRECATED
if len(serviceAccountName) == 0 {
// We want to Also update the serviceAccount to the serviceaccount of the workflow
serviceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
}

pod.Spec.ServiceAccountName = serviceAccountName

return pod, nil
}
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/container/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.
Kind: "node",
Name: "blah",
})
taskMetadata.On("GetSecurityContext").Return(map[string]string{})
taskMetadata.On("GetK8sServiceAccount").Return("service-account")
taskMetadata.On("GetOwnerID").Return(types.NamespacedName{
Namespace: "test-namespace",
Expand Down
11 changes: 9 additions & 2 deletions go/tasks/plugins/k8s/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,15 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins
// CrashLoopBackoff after the initial job completion.
pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever

// We want to Also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
serviceAccountName := flytek8s.GetServiceAccountNameFromSecurityContext(taskCtx.TaskExecutionMetadata().GetSecurityContext())

// TO BE DEPRECATED
if len(serviceAccountName) == 0 {
// We want to Also update the serviceAccount to the serviceaccount of the workflow
serviceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()
}

pod.Spec.ServiceAccountName = serviceAccountName

pod, err = validateAndFinalizePod(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions go/tasks/plugins/k8s/sidecar/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.
taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{}
taskMetadata.On("GetNamespace").Return("test-namespace")
taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"})
taskMetadata.On("GetSecurityContext").Return(map[string]string{})

taskMetadata.On("GetLabels").Return(map[string]string{"label-1": "val1"})
taskMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{
Kind: "node",
Expand Down Expand Up @@ -179,6 +181,7 @@ func TestBuildSidecarResource(t *testing.T) {
assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1)
assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name)

assert.Equal(t, "service-account", res.(*v1.Pod).Spec.ServiceAccountName)
// Assert user-specified tolerations don't get overridden
assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 2)
for _, tol := range res.(*v1.Pod).Spec.Tolerations {
Expand Down
9 changes: 7 additions & 2 deletions go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,19 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))

serviceAccountName := flytek8s.GetServiceAccountNameFromSecurityContext(taskCtx.TaskExecutionMetadata().GetSecurityContext())

if len(serviceAccountName) == 0 {
serviceAccountName = sparkTaskType
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
},
ServiceAccount: &sparkTaskType,
ServiceAccount: &serviceAccountName,
}

executorSpec := sparkOp.ExecutorSpec{
Expand Down Expand Up @@ -182,7 +187,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
APIVersion: sparkOp.SchemeGroupVersion.String(),
},
Spec: sparkOp.SparkApplicationSpec{
ServiceAccount: &sparkTaskType,
ServiceAccount: &serviceAccountName,
Type: getApplicationType(sparkJob.GetApplicationType()),
Mode: sparkOp.ClusterMode,
Image: &container.Image,
Expand Down
2 changes: 2 additions & 0 deletions go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool)
Kind: "node",
Name: "blah",
})
taskExecutionMetadata.On("GetSecurityContext").Return(map[string]string{"serviceAccountName": "new-val"})
taskExecutionMetadata.On("IsInterruptible").Return(interruptible)
taskExecutionMetadata.On("GetMaxAttempts").Return(uint32(1))
taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata)
Expand Down Expand Up @@ -369,6 +370,7 @@ func TestBuildResourceSpark(t *testing.T) {
execCores, _ := strconv.Atoi(dummySparkConf["spark.executor.cores"])
execInstances, _ := strconv.Atoi(dummySparkConf["spark.executor.instances"])

assert.Equal(t, "new-val", *sparkApp.Spec.ServiceAccount)
assert.Equal(t, int32(driverCores), *sparkApp.Spec.Driver.Cores)
assert.Equal(t, int32(execCores), *sparkApp.Spec.Executor.Cores)
assert.Equal(t, int32(execInstances), *sparkApp.Spec.Executor.Instances)
Expand Down
16 changes: 8 additions & 8 deletions go/tasks/plugins/webapi/athena/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e1bffe0

Please sign in to comment.