diff --git a/go/tasks/config_load_test.go b/go/tasks/config_load_test.go
index 8c2d6420d..74ecfff18 100755
--- a/go/tasks/config_load_test.go
+++ b/go/tasks/config_load_test.go
@@ -84,6 +84,12 @@ func TestLoadConfig(t *testing.T) {
 	t.Run("spark-config-test", func(t *testing.T) {
 		assert.NotNil(t, spark.GetSparkConfig())
 		assert.NotNil(t, spark.GetSparkConfig().DefaultSparkConfig)
+		assert.Equal(t, 2, len(spark.GetSparkConfig().Features))
+		assert.Equal(t, "feature1", spark.GetSparkConfig().Features[0].Name)
+		assert.Equal(t, "feature2", spark.GetSparkConfig().Features[1].Name)
+		assert.Equal(t, 2, len(spark.GetSparkConfig().Features[0].SparkConfig))
+		assert.Equal(t, 2, len(spark.GetSparkConfig().Features[1].SparkConfig))
+
 	})
 
 	t.Run("sagemaker-config-test", func(t *testing.T) {
diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go
index 4b962902c..84a3c06d1 100755
--- a/go/tasks/plugins/k8s/spark/spark.go
+++ b/go/tasks/plugins/k8s/spark/spark.go
@@ -3,17 +3,17 @@ package spark
 import (
 	"context"
 	"fmt"
-	"time"
 
 	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery"
 	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s"
 	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
 
+	"github.com/lyft/flyteplugins/go/tasks/errors"
 	"github.com/lyft/flyteplugins/go/tasks/logs"
 	pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
+	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s"
 	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils"
-
 	"k8s.io/client-go/kubernetes/scheme"
 
 	sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta1"
 
@@ -22,7 +22,9 @@ import (
 	"github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 
-	"github.com/lyft/flyteplugins/go/tasks/errors"
+	"regexp"
+	"strings"
+	"time"
 
 	pluginsConfig "github.com/lyft/flyteplugins/go/tasks/config"
 )
@@ -31,12 +33,21 @@ const KindSparkApplication = "SparkApplication"
 const sparkDriverUI = "sparkDriverUI"
 const sparkHistoryUI = "sparkHistoryUI"
 
+var featureRegex = regexp.MustCompile(`^spark\.((lyft)|(flyte))\.(.+)\.enabled$`)
+
 var sparkTaskType = "spark"
 
 // Spark-specific configs
 type Config struct {
 	DefaultSparkConfig    map[string]string `json:"spark-config-default" pflag:",Key value pairs of default spark configuration that should be applied to every SparkJob"`
 	SparkHistoryServerURL string            `json:"spark-history-server-url" pflag:",URL for SparkHistory Server that each job will publish the execution history to."`
+	Features              []Feature         `json:"features" pflag:",List of optional features supported."`
+}
+
+// Optional feature with name and corresponding spark-config to use.
+type Feature struct {
+	Name        string            `json:"name"`
+	SparkConfig map[string]string `json:"spark-config"`
 }
 
 var (
@@ -139,7 +150,12 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
 	}
 
 	for k, v := range sparkJob.GetSparkConf() {
-		sparkConfig[k] = v
+		// Add optional features if present.
+		if featureRegex.MatchString(k) {
+			addConfig(sparkConfig, k, v)
+		} else {
+			sparkConfig[k] = v
+		}
 	}
 
 	// Set pod limits.
@@ -184,6 +200,29 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
 	return j, nil
 }
 
+func addConfig(sparkConfig map[string]string, key string, value string) {
+
+	if strings.ToLower(strings.TrimSpace(value)) != "true" {
+		return
+	}
+
+	matches := featureRegex.FindAllStringSubmatch(key, -1)
+	if len(matches) == 0 || len(matches[0]) == 0 {
+		return
+	}
+	featureName := matches[0][len(matches[0])-1]
+	// Use the first matching feature in case of duplicates.
+	for _, feature := range GetSparkConfig().Features {
+		if feature.Name == featureName {
+			for k, v := range feature.SparkConfig {
+				sparkConfig[k] = v
+			}
+			break
+		}
+
+	}
+}
+
 // Convert SparkJob ApplicationType to Operator CRD ApplicationType
 func getApplicationType(applicationType plugins.SparkApplication_Type) sparkOp.SparkApplicationType {
 	switch applicationType {
diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go
index 5707738b8..655940639 100755
--- a/go/tasks/plugins/k8s/spark/spark_test.go
+++ b/go/tasks/plugins/k8s/spark/spark_test.go
@@ -34,11 +34,13 @@ const sparkUIAddress = "spark-ui.flyte"
 
 var (
 	dummySparkConf = map[string]string{
-		"spark.driver.memory":      "500M",
-		"spark.driver.cores":       "1",
-		"spark.executor.cores":     "1",
-		"spark.executor.instances": "3",
-		"spark.executor.memory":    "500M",
+		"spark.driver.memory":          "500M",
+		"spark.driver.cores":           "1",
+		"spark.executor.cores":         "1",
+		"spark.executor.instances":     "3",
+		"spark.executor.memory":        "500M",
+		"spark.flyte.feature1.enabled": "true",
+		"spark.lyft.feature2.enabled":  "true",
 	}
 
 	dummyEnvVars = []*core.KeyValuePair{
@@ -271,6 +273,19 @@ func TestBuildResourceSpark(t *testing.T) {
 
 	// Case1: Valid Spark Task-Template
 	taskTemplate := dummySparkTaskTemplate("blah-1")
+	// Set spark custom feature config.
+	assert.NoError(t, setSparkConfig(&Config{
+		Features: []Feature{
+			{
+				Name:        "feature1",
+				SparkConfig: map[string]string{"spark.hadoop.feature1": "true"},
+			},
+			{
+				Name:        "feature2",
+				SparkConfig: map[string]string{"spark.hadoop.feature2": "true"},
+			},
+		},
+	}))
 
 	resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate))
 	assert.Nil(t, err)
@@ -285,10 +300,25 @@ func TestBuildResourceSpark(t *testing.T) {
 
 	for confKey, confVal := range dummySparkConf {
 		exists := false
-		for k, v := range sparkApp.Spec.SparkConf {
-			if k == confKey {
-				assert.Equal(t, v, confVal)
-				exists = true
+
+		if featureRegex.MatchString(confKey) {
+			match := featureRegex.FindAllStringSubmatch(confKey, -1)
+			feature := match[0][len(match[0])-1]
+			assert.True(t, feature == "feature1" || feature == "feature2")
+			for k, v := range sparkApp.Spec.SparkConf {
+				key := "spark.hadoop." + feature
+				if k == key {
+					assert.Equal(t, v, "true")
+					exists = true
+				}
+			}
+		} else {
+			for k, v := range sparkApp.Spec.SparkConf {
+
+				if k == confKey {
+					assert.Equal(t, v, confVal)
+					exists = true
+				}
 			}
 		}
 		assert.True(t, exists)
diff --git a/go/tasks/testdata/config.yaml b/go/tasks/testdata/config.yaml
index 283ef9143..f33d3245c 100755
--- a/go/tasks/testdata/config.yaml
+++ b/go/tasks/testdata/config.yaml
@@ -65,6 +65,15 @@ plugins:
       - spark.hadoop.fs.s3a.multipart.threshold: "536870912"
      - spark.blacklist.enabled: "true"
      - spark.blacklist.timeout: "5m"
+    features:
+      - name: "feature1"
+        spark-config:
+          - spark.hadoop.feature1: "true"
+          - spark.sql.feature1: "true"
+      - name: "feature2"
+        spark-config:
+          - spark.hadoop.feature2: "true"
+          - spark.sql.feature2: "true"
   # Logging configuration
   logs:
     kubernetes-enabled: true