Skip to content

Commit

Permalink
Update Spark Plugin (flyteorg#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
akhurana001 authored Nov 6, 2020
1 parent 76c110c commit 3421687
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
32 changes: 32 additions & 0 deletions go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,24 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
sparkConfig["spark.kubernetes.executor.podNamePrefix"] = taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

// Add driver/executor defaults to CRD Driver/Executor Spec as well.
cores, err := strconv.Atoi(sparkConfig["spark.driver.cores"])
if err == nil {
driverSpec.Cores = intPtr(int32(cores))
}
driverSpec.Memory = strPtr(sparkConfig["spark.driver.memory"])

execCores, err := strconv.Atoi(sparkConfig["spark.executor.cores"])
if err == nil {
executorSpec.Cores = intPtr(int32(execCores))
}

execCount, err := strconv.Atoi(sparkConfig["spark.executor.instances"])
if err == nil {
executorSpec.Instances = intPtr(int32(execCount))
}
executorSpec.Memory = strPtr(sparkConfig["spark.executor.memory"])

j := &sparkOp.SparkApplication{
TypeMeta: metav1.TypeMeta{
Kind: KindSparkApplication,
Expand Down Expand Up @@ -392,3 +410,17 @@ func init() {
DefaultForTaskTypes: []pluginsCore.TaskType{sparkTaskType},
})
}

// strPtr returns a pointer to str, treating the empty string as "unset"
// and mapping it to nil so the corresponding optional CRD field is omitted
// rather than populated with an empty value.
func strPtr(str string) *string {
	if len(str) == 0 {
		return nil
	}
	out := str
	return &out
}

// intPtr returns a pointer to val, treating zero as "unset" and mapping it
// to nil so the corresponding optional CRD field is omitted rather than
// populated with a zero value.
func intPtr(val int32) *int32 {
	if val != 0 {
		return &val
	}
	return nil
}
18 changes: 16 additions & 2 deletions go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package spark
import (
"context"
"fmt"
"strconv"
"testing"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
Expand Down Expand Up @@ -37,9 +38,9 @@ const sparkUIAddress = "spark-ui.flyte"

var (
dummySparkConf = map[string]string{
"spark.driver.memory": "500M",
"spark.driver.memory": "200M",
"spark.driver.cores": "1",
"spark.executor.cores": "1",
"spark.executor.cores": "2",
"spark.executor.instances": "3",
"spark.executor.memory": "500M",
"spark.flyte.feature1.enabled": "true",
Expand Down Expand Up @@ -316,6 +317,19 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, sj.PythonApplicationType, sparkApp.Spec.Type)
assert.Equal(t, testArgs, sparkApp.Spec.Arguments)
assert.Equal(t, testImage, *sparkApp.Spec.Image)

//Validate Driver/Executor Spec.

driverCores, _ := strconv.Atoi(dummySparkConf["spark.driver.cores"])
execCores, _ := strconv.Atoi(dummySparkConf["spark.executor.cores"])
execInstances, _ := strconv.Atoi(dummySparkConf["spark.executor.instances"])

assert.Equal(t, int32(driverCores), *sparkApp.Spec.Driver.Cores)
assert.Equal(t, int32(execCores), *sparkApp.Spec.Executor.Cores)
assert.Equal(t, int32(execInstances), *sparkApp.Spec.Executor.Instances)
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)

// Validate Interruptible Toleration and NodeSelector set for Executor but not Driver.
assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))
Expand Down

0 comments on commit 3421687

Please sign in to comment.