diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go index a96536e7ac..2c7fda0f6e 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go @@ -200,6 +200,7 @@ func getMockTaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { taskExecutionMetadata.On("GetAnnotations").Return(map[string]string{"aKey": "aVal"}) taskExecutionMetadata.On("GetLabels").Return(map[string]string{"lKey": "lVal"}) taskExecutionMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{Name: "x"}) + taskExecutionMetadata.On("GetSecurityContext").Return(core.SecurityContext{RunAs: &core.Identity{}}) id := &pluginsCoreMock.TaskExecutionID{} id.On("GetGeneratedName").Return("test") diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context.go b/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context.go index 14984e3f97..17bbce5398 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context.go @@ -7,6 +7,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils/secrets" ) +const executionIdentityVariable = "execution-identity" + // TaskExecutionContext provides a layer on top of core TaskExecutionContext with a custom TaskExecutionMetadata. type TaskExecutionContext struct { pluginsCore.TaskExecutionContext @@ -42,25 +44,28 @@ func (t TaskExecutionMetadata) GetAnnotations() map[string]string { } // newTaskExecutionMetadata creates a TaskExecutionMetadata with secrets serialized as annotations and a label added -// to trigger the flyte pod webhook +// to trigger the flyte pod webhook. If known, the execution identity is injected as a label. func newTaskExecutionMetadata(tCtx pluginsCore.TaskExecutionMetadata, taskTmpl *core.TaskTemplate) (TaskExecutionMetadata, error) { var err error secretsMap := make(map[string]string) - injectSecretsLabel := make(map[string]string) + injectLabels := make(map[string]string) if taskTmpl.SecurityContext != nil && len(taskTmpl.SecurityContext.Secrets) > 0 { secretsMap, err = secrets.MarshalSecretsToMapStrings(taskTmpl.SecurityContext.Secrets) if err != nil { return TaskExecutionMetadata{}, err } - injectSecretsLabel = map[string]string{ - secrets.PodLabel: secrets.PodLabelValue, - } + injectLabels[secrets.PodLabel] = secrets.PodLabelValue + } + + id := tCtx.GetSecurityContext().RunAs.ExecutionIdentity + if len(id) > 0 { + injectLabels[executionIdentityVariable] = id } return TaskExecutionMetadata{ TaskExecutionMetadata: tCtx, annotations: utils.UnionMaps(tCtx.GetAnnotations(), secretsMap), - labels: utils.UnionMaps(tCtx.GetLabels(), injectSecretsLabel), + labels: utils.UnionMaps(tCtx.GetLabels(), injectLabels), }, nil } diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context_test.go b/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context_test.go index 8807236bff..bf9ca1eadb 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context_test.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/task_exec_context_test.go @@ -21,6 +21,7 @@ func Test_newTaskExecutionMetadata(t *testing.T) { "existingLabel": "existingLabelValue", } existingMetadata.OnGetLabels().Return(existingLabels) + existingMetadata.OnGetSecurityContext().Return(core.SecurityContext{RunAs: &core.Identity{}}) actual, err := newTaskExecutionMetadata(existingMetadata, &core.TaskTemplate{}) assert.NoError(t, err) @@ -40,6 +41,7 @@ func Test_newTaskExecutionMetadata(t *testing.T) { "existingLabel": "existingLabelValue", } existingMetadata.OnGetLabels().Return(existingLabels) + existingMetadata.OnGetSecurityContext().Return(core.SecurityContext{RunAs: &core.Identity{}}) actual, err := newTaskExecutionMetadata(existingMetadata, &core.TaskTemplate{ SecurityContext: &core.SecurityContext{ @@ -64,6 +66,26 @@ func Test_newTaskExecutionMetadata(t *testing.T) { "inject-flyte-secrets": "true", }, actual.GetLabels()) }) + + t.Run("Inject exec identity", func(t *testing.T) { + + existingMetadata := &mocks.TaskExecutionMetadata{} + existingAnnotations := map[string]string{} + existingMetadata.OnGetAnnotations().Return(existingAnnotations) + + existingMetadata.OnGetSecurityContext().Return(core.SecurityContext{RunAs: &core.Identity{ExecutionIdentity: "test-exec-identity"}}) + + existingLabels := map[string]string{ + "existingLabel": "existingLabelValue", + } + existingMetadata.OnGetLabels().Return(existingLabels) + + actual, err := newTaskExecutionMetadata(existingMetadata, &core.TaskTemplate{}) + assert.NoError(t, err) + + assert.Equal(t, 2, len(actual.GetLabels())) + assert.Equal(t, "test-exec-identity", actual.GetLabels()[executionIdentityVariable]) + }) } func Test_newTaskExecutionContext(t *testing.T) {