From 61da5cf1a7bcf48e726a6b234cecb75baff6adc2 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Wed, 6 Dec 2023 16:20:45 +0000 Subject: [PATCH] [Spark plugin] Fix environment variable ValueFrom for pod templates (#4532) * Update tests Signed-off-by: Thomas Newton * Use Copy Container.Env to SparkPodSpec.Env instead of SparkPodSpec.EnvVars Signed-off-by: Thomas Newton * Fix list initialisation Signed-off-by: Thomas Newton --------- Signed-off-by: Thomas Newton Co-authored-by: Dan Rammer Signed-off-by: Paul Dittamo --- .../go/tasks/plugins/k8s/spark/spark.go | 8 ++-- .../go/tasks/plugins/k8s/spark/spark_test.go | 47 ++++++++++++++----- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index defcb275fe..22240e9e45 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -145,17 +145,17 @@ func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.Po annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - sparkEnvVars := make(map[string]string) + sparkEnv := make([]v1.EnvVar, 0) for _, envVar := range container.Env { - sparkEnvVars[envVar.Name] = envVar.Value + sparkEnv = append(sparkEnv, *envVar.DeepCopy()) } - sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) + sparkEnv = append(sparkEnv, v1.EnvVar{Name: "FLYTE_MAX_ATTEMPTS", Value: strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))}) spec := sparkOp.SparkPodSpec{ Affinity: podSpec.Affinity, Annotations: annotations, Labels: labels, - EnvVars: sparkEnvVars, + Env: sparkEnv, Image: &container.Image, SecurityContenxt: podSpec.SecurityContext.DeepCopy(), DNSConfig: podSpec.DNSConfig.DeepCopy(), diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index be2f9477b6..b07ab0ef33 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -49,6 +49,18 @@ var ( {Key: "Env_Var", Value: "Env_Val"}, } + dummyEnvVarsWithSecretRef = []corev1.EnvVar{ + {Name: "Env_Var", Value: "Env_Val"}, + {Name: "SECRET", ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + Key: "key", + LocalObjectReference: corev1.LocalObjectReference{ + Name: "secret-name", + }, + }, + }}, + } + testArgs = []string{ "execute-spark-task", } @@ -261,7 +273,7 @@ func dummyPodSpec() *corev1.PodSpec { Name: "primary", Image: testImage, Args: testArgs, - Env: flytek8s.ToK8sEnvVar(dummyEnvVars), + Env: dummyEnvVarsWithSecretRef, }, { Name: "secondary", @@ -512,6 +524,15 @@ func defaultPluginConfig() *config.K8sPluginConfig { return config } +func findEnvVarByName(envVars []corev1.EnvVar, name string) *corev1.EnvVar { + for _, envVar := range envVars { + if envVar.Name == name { + return &envVar + } + } + return nil +} + func TestBuildResourceContainer(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} @@ -639,11 +660,11 @@ func TestBuildResourceContainer(t *testing.T) { assert.Greater(t, len(sparkApp.Spec.SparkConf["spark.kubernetes.driverEnv.FLYTE_START_TIME"]), 1) assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"]) - assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) assert.Equal(t, &corev1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ @@ -789,9 +810,11 @@ func TestBuildResourcePodTemplate(t *testing.T) { // Driver assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) - assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) + assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) + assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET")) + assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env)) assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) @@ -825,8 +848,10 @@ func TestBuildResourcePodTemplate(t *testing.T) { // Executor assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations) assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) + assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET")) + assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env)) assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)