
[Spark plugin] Fix environment variable ValueFrom for pod templates (#4532)

* Update tests

Signed-off-by: Thomas Newton <[email protected]>

* Copy Container.Env to SparkPodSpec.Env instead of SparkPodSpec.EnvVars

Signed-off-by: Thomas Newton <[email protected]>

* Fix list initialisation

Signed-off-by: Thomas Newton <[email protected]>

---------

Signed-off-by: Thomas Newton <[email protected]>
Co-authored-by: Dan Rammer <[email protected]>
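
The gist of the fix: SparkPodSpec.EnvVars is a map[string]string, so only literal values survive the copy and any ValueFrom reference (for example a secretKeyRef declared in a pod template) is silently dropped; copying the full []v1.EnvVar into SparkPodSpec.Env preserves it. Below is a minimal illustrative sketch, not part of this commit, assuming the spark-on-k8s-operator v1beta2 API; the import path shown is an assumption and may differ from the copy vendored by flyteplugins.

package main

import (
	"fmt"

	// Assumed import path for the spark operator API types.
	sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
	v1 "k8s.io/api/core/v1"
)

func main() {
	// An env var defined via ValueFrom, as a pod template would declare it.
	secretVar := v1.EnvVar{
		Name: "SECRET",
		ValueFrom: &v1.EnvVarSource{
			SecretKeyRef: &v1.SecretKeySelector{
				Key:                  "key",
				LocalObjectReference: v1.LocalObjectReference{Name: "secret-name"},
			},
		},
	}

	// Old behaviour: EnvVars is map[string]string, so only the literal Value
	// (empty here) can be stored and the secret reference is lost.
	oldSpec := sparkOp.SparkPodSpec{EnvVars: map[string]string{secretVar.Name: secretVar.Value}}
	fmt.Printf("old: %q\n", oldSpec.EnvVars["SECRET"]) // old: ""

	// New behaviour: Env carries the full v1.EnvVar, ValueFrom included.
	newSpec := sparkOp.SparkPodSpec{Env: []v1.EnvVar{*secretVar.DeepCopy()}}
	fmt.Println("new:", newSpec.Env[0].ValueFrom.SecretKeyRef.Name) // new: secret-name
}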
Tom-Newton and hamersaw authored Dec 6, 2023
1 parent aff1150 commit 94d79f5
Showing 2 changed files with 40 additions and 15 deletions.
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
@@ -145,17 +145,17 @@ func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.Po
 	annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))
 	labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))
 
-	sparkEnvVars := make(map[string]string)
+	sparkEnv := make([]v1.EnvVar, 0)
 	for _, envVar := range container.Env {
-		sparkEnvVars[envVar.Name] = envVar.Value
+		sparkEnv = append(sparkEnv, *envVar.DeepCopy())
 	}
-	sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))
+	sparkEnv = append(sparkEnv, v1.EnvVar{Name: "FLYTE_MAX_ATTEMPTS", Value: strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))})
 
 	spec := sparkOp.SparkPodSpec{
 		Affinity:    podSpec.Affinity,
 		Annotations: annotations,
 		Labels:      labels,
-		EnvVars:     sparkEnvVars,
+		Env:         sparkEnv,
 		Image:       &container.Image,
 		SecurityContenxt: podSpec.SecurityContext.DeepCopy(),
 		DNSConfig:        podSpec.DNSConfig.DeepCopy(),
47 changes: 36 additions & 11 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
@@ -49,6 +49,18 @@ var (
 		{Key: "Env_Var", Value: "Env_Val"},
 	}
 
+	dummyEnvVarsWithSecretRef = []corev1.EnvVar{
+		{Name: "Env_Var", Value: "Env_Val"},
+		{Name: "SECRET", ValueFrom: &corev1.EnvVarSource{
+			SecretKeyRef: &corev1.SecretKeySelector{
+				Key: "key",
+				LocalObjectReference: corev1.LocalObjectReference{
+					Name: "secret-name",
+				},
+			},
+		}},
+	}
+
 	testArgs = []string{
 		"execute-spark-task",
 	}
@@ -261,7 +273,7 @@ func dummyPodSpec() *corev1.PodSpec {
 				Name:  "primary",
 				Image: testImage,
 				Args:  testArgs,
-				Env:   flytek8s.ToK8sEnvVar(dummyEnvVars),
+				Env:   dummyEnvVarsWithSecretRef,
 			},
 			{
 				Name: "secondary",
@@ -512,6 +524,15 @@ func defaultPluginConfig() *config.K8sPluginConfig {
 	return config
 }
 
+func findEnvVarByName(envVars []corev1.EnvVar, name string) *corev1.EnvVar {
+	for _, envVar := range envVars {
+		if envVar.Name == name {
+			return &envVar
+		}
+	}
+	return nil
+}
+
 func TestBuildResourceContainer(t *testing.T) {
 	sparkResourceHandler := sparkResourceHandler{}

@@ -639,11 +660,11 @@ func TestBuildResourceContainer(t *testing.T) {
 	assert.Greater(t, len(sparkApp.Spec.SparkConf["spark.kubernetes.driverEnv.FLYTE_START_TIME"]), 1)
 	assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"])
 
-	assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
-	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"])
-	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"])
-	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"])
-	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"])
+	assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
 
 	assert.Equal(t, &corev1.NodeAffinity{
 		RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
@@ -789,9 +810,11 @@ func TestBuildResourcePodTemplate(t *testing.T) {
 	// Driver
 	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations)
 	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels)
-	assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
-	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"])
-	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"])
+	assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
+	assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET"))
+	assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env))
 	assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image)
 	assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount)
 	assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt)
@@ -825,8 +848,10 @@ func TestBuildResourcePodTemplate(t *testing.T) {
 	// Executor
 	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations)
 	assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels)
-	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"])
-	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"])
+	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
+	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
+	assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
+	assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env))
 	assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image)
 	assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt)
 	assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)
