From a30fe259a3600bc0d3e3c776e6c0d5f7d8450a3e Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 4 Dec 2023 15:34:40 +0000 Subject: [PATCH 1/3] Update tests Signed-off-by: Thomas Newton --- .../go/tasks/plugins/k8s/spark/spark_test.go | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index be2f9477b6..b513ea3c33 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -49,6 +49,18 @@ var ( {Key: "Env_Var", Value: "Env_Val"}, } + dummyEnvVarsWithSecretRef = []corev1.EnvVar{ + {Name: "Env_Var", Value: "Env_Val"}, + {Name: "SECRET", ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + Key: "key", + LocalObjectReference: corev1.LocalObjectReference{ + Name: "secret-name", + }, + }, + }}, + } + testArgs = []string{ "execute-spark-task", } @@ -261,7 +273,7 @@ func dummyPodSpec() *corev1.PodSpec { Name: "primary", Image: testImage, Args: testArgs, - Env: flytek8s.ToK8sEnvVar(dummyEnvVars), + Env: dummyEnvVarsWithSecretRef, }, { Name: "secondary", @@ -512,6 +524,15 @@ func defaultPluginConfig() *config.K8sPluginConfig { return config } +func findEnvVarByName(envVars []corev1.EnvVar, name string) *corev1.EnvVar { + for _, envVar := range envVars { + if envVar.Name == name { + return &envVar + } + } + return nil +} + func TestBuildResourceContainer(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} @@ -639,11 +660,11 @@ func TestBuildResourceContainer(t *testing.T) { assert.Greater(t, len(sparkApp.Spec.SparkConf["spark.kubernetes.driverEnv.FLYTE_START_TIME"]), 1) assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"]) - assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) assert.Equal(t, &corev1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ @@ -789,9 +810,10 @@ func TestBuildResourcePodTemplate(t *testing.T) { // Driver assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) - assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) + assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) + assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET")) assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) @@ -825,8 +847,9 @@ func TestBuildResourcePodTemplate(t *testing.T) { // Executor assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations) assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels) - assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) - assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) + assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET")) assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig) From 2531a2d185edaa9316ac237f7026c517cecb954c Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 4 Dec 2023 15:35:23 +0000 Subject: [PATCH 2/3] Use Copy Container.Env to SparkPodSpec.Env instead of SparkPodSpec.EnvVars Signed-off-by: Thomas Newton --- flyteplugins/go/tasks/plugins/k8s/spark/spark.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index defcb275fe..7a6f5e9538 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -145,17 +145,17 @@ func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.Po annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - sparkEnvVars := make(map[string]string) + sparkEnv := make([]v1.EnvVar, len(container.Env)+1) for _, envVar := range container.Env { - sparkEnvVars[envVar.Name] = envVar.Value + sparkEnv = append(sparkEnv, *envVar.DeepCopy()) } - sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) + sparkEnv = append(sparkEnv, v1.EnvVar{Name: "FLYTE_MAX_ATTEMPTS", Value: strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))}) spec := sparkOp.SparkPodSpec{ Affinity: podSpec.Affinity, Annotations: annotations, Labels: labels, - EnvVars: sparkEnvVars, + Env: sparkEnv, Image: &container.Image, SecurityContenxt: podSpec.SecurityContext.DeepCopy(), DNSConfig: podSpec.DNSConfig.DeepCopy(), From 9231d602d207485f7460fdd69033dcffc8cb1dac Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Mon, 4 Dec 2023 20:24:02 +0000 Subject: [PATCH 3/3] Fix list initialisation Signed-off-by: Thomas Newton --- flyteplugins/go/tasks/plugins/k8s/spark/spark.go | 2 +- flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 7a6f5e9538..22240e9e45 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -145,7 +145,7 @@ func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.Po annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - sparkEnv := make([]v1.EnvVar, len(container.Env)+1) + sparkEnv := make([]v1.EnvVar, 0) for _, envVar := range container.Env { sparkEnv = append(sparkEnv, *envVar.DeepCopy()) } diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index b513ea3c33..b07ab0ef33 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -814,6 +814,7 @@ func TestBuildResourcePodTemplate(t *testing.T) { assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET")) + assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env)) assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) @@ -850,6 +851,7 @@ func TestBuildResourcePodTemplate(t *testing.T) { assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET")) + assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env)) assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)