diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go index 11801dcb8..48ad9ee9a 100755 --- a/go/tasks/plugins/k8s/spark/spark.go +++ b/go/tasks/plugins/k8s/spark/spark.go @@ -210,6 +210,10 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo }, } + if val, ok := sparkConfig["spark.batchScheduler"]; ok { + j.Spec.BatchScheduler = &val + } + if sparkJob.MainApplicationFile != "" { j.Spec.MainApplicationFile = &sparkJob.MainApplicationFile } diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go index 394d06d5d..4b6a662c1 100755 --- a/go/tasks/plugins/k8s/spark/spark_test.go +++ b/go/tasks/plugins/k8s/spark/spark_test.go @@ -47,6 +47,7 @@ var ( "spark.flyte.feature1.enabled": "true", "spark.flyteorg.feature2.enabled": "true", "spark.flyteorg.feature3.enabled": "true", + "spark.batchScheduler": "volcano", } dummyEnvVars = []*core.KeyValuePair{ @@ -384,6 +385,7 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, int32(execInstances), *sparkApp.Spec.Executor.Instances) assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) + assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler) // Validate Interruptible Toleration and NodeSelector set for Executor but not Driver. assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))