// Patch summary (text before the first "diff --git" header is ignored by patch tools).
//
// Purpose: let k8s-based plugins schedule "interruptible" tasks onto dedicated nodes.
//
// What the hunks below do:
//   * go.mod / go.sum: add a requirement on golang.org/x/tools
//     v0.0.0-20200124170513-3f4d10fc73b4 and its h1 checksum.
//   * core.TaskExecutionMetadata: new method IsInterruptible() bool; the mock gains
//     OnIsInterruptible / OnIsInterruptibleMatch / IsInterruptible.
//   * mocks.ResourceRegistrar: RegisterResourceNamespaceQuotaProportionCap is removed.
//   * config.K8sPluginConfig: new fields InterruptibleTolerations
//     (json "interruptible-tolerations") and InterruptibleNodeSelector
//     (json "interruptible-node-selector").
//   * flytek8s.GetTolerationsForResources is renamed to
//     GetPodTolerations(interruptible bool, resourceRequirements ...) and appends the
//     configured InterruptibleTolerations when interruptible is true.
//   * flytek8s.ToK8sPodSpec and sidecar.validateAndFinalizeContainers set the pod's
//     NodeSelector to the configured InterruptibleNodeSelector when the task is
//     interruptible and the selector is non-empty.
//   * flytek8s.ConvertPodFailureToError: results become named (code, message); a
//     terminated container whose reason does not contain "OOMKilled" and whose exit
//     code equals the new constant SIGKILL (137) is reported with code "Interrupted".
//   * Tests and testdata: every mocked TaskExecutionMetadata now answers
//     IsInterruptible with true; a new flytek8s/testdata/config.yaml is added and
//     go/tasks/testdata/config.yaml gains the two interruptible-* entries.
//
// NOTE(review): the constant is named SIGKILL but holds 137 and is compared against
//   ExitCode; 137 is presumably the 128+9 exit status, not the signal number.
//   Consider a name that says "exit code" - confirm intent.
// NOTE(review): the new test is named TestToK8sPodIterruptible - looks like a typo
//   for "Interruptible".
// NOTE(review): NodeSelector is assigned the map held by the plugin config directly
//   (no copy), and in the sidecar path it replaces pod.Spec.NodeSelector outright.
//   Confirm nothing mutates the pod's selector later and that no user-supplied
//   selector is expected to survive.
// NOTE(review): ToK8sPodSpec now returns one of two nearly identical PodSpec
//   literals; building the spec once and setting NodeSelector conditionally would
//   avoid the duplication.
// NOTE(review): the OOMKilled check reads c.State.Terminated.Reason while the new
//   branch reads containerState.Terminated.ExitCode; where containerState comes from
//   is not visible in this hunk - confirm both refer to the same state.
// NOTE(review): no use of golang.org/x/tools is visible in this patch - confirm the
//   direct requirement in go.mod is intentional.
//
diff --git a/go.mod b/go.mod index e2952d6433..bd4d424637 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( golang.org/x/crypto v0.0.0-20200204104054-c9f3fb736b72 // indirect golang.org/x/net v0.0.0-20200202094626-16171245cfb2 golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 // indirect + golang.org/x/tools v0.0.0-20200124170513-3f4d10fc73b4 google.golang.org/api v0.16.0 // indirect google.golang.org/genproto v0.0.0-20200205142000-a86caf926a67 // indirect google.golang.org/grpc v1.27.1 diff --git a/go.sum b/go.sum index 0e2ec3d71e..9910eb12f9 100644 --- a/go.sum +++ b/go.sum @@ -611,6 +611,7 @@ golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200124170513-3f4d10fc73b4 h1:BPUNhs1Rsd9Ly0hbjDwBxaNBrAyo/CKpkMcA3pkTwgg= golang.org/x/tools v0.0.0-20200124170513-3f4d10fc73b4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA= diff --git a/go/tasks/config_load_test.go b/go/tasks/config_load_test.go index 0c1627a279..57d17d6fac 100755 --- a/go/tasks/config_load_test.go +++ b/go/tasks/config_load_test.go @@ -69,6 +69,9 @@ func TestLoadConfig(t *testing.T) { assert.Equal(t, []v1.Toleration{tolStorage}, k8sConfig.ResourceTolerations[v1.ResourceStorage]) assert.Equal(t, "1000m", k8sConfig.DefaultCPURequest) assert.Equal(t, "1024Mi", k8sConfig.DefaultMemoryRequest) + assert.Equal(t, map[string]string{"x/interruptible": "true"}, k8sConfig.InterruptibleNodeSelector) + 
assert.Equal(t, "x/flyte", k8sConfig.InterruptibleTolerations[0].Key) + assert.Equal(t, "interruptible", k8sConfig.InterruptibleTolerations[0].Value) }) t.Run("logs-config-test", func(t *testing.T) { @@ -80,5 +83,4 @@ func TestLoadConfig(t *testing.T) { assert.NotNil(t, spark.GetSparkConfig()) assert.NotNil(t, spark.GetSparkConfig().DefaultSparkConfig) }) - } diff --git a/go/tasks/pluginmachinery/core/exec_metadata.go b/go/tasks/pluginmachinery/core/exec_metadata.go index 22a3fec14e..1f697b2e1d 100644 --- a/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/go/tasks/pluginmachinery/core/exec_metadata.go @@ -32,4 +32,5 @@ type TaskExecutionMetadata interface { GetLabels() map[string]string GetAnnotations() map[string]string GetK8sServiceAccount() string + IsInterruptible() bool } diff --git a/go/tasks/pluginmachinery/core/mocks/resource_registrar.go b/go/tasks/pluginmachinery/core/mocks/resource_registrar.go index fdfaf11a9c..c2707c49d5 100644 --- a/go/tasks/pluginmachinery/core/mocks/resource_registrar.go +++ b/go/tasks/pluginmachinery/core/mocks/resource_registrar.go @@ -14,11 +14,6 @@ type ResourceRegistrar struct { mock.Mock } -// RegisterResourceNamespaceQuotaProportionCap provides a mock function with given fields: ctx, proportionCap -func (_m *ResourceRegistrar) RegisterResourceNamespaceQuotaProportionCap(ctx context.Context, proportionCap float64) { - _m.Called(ctx, proportionCap) -} - type ResourceRegistrar_RegisterResourceQuota struct { *mock.Call } diff --git a/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go b/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go index c7d7a0b886..be337961c1 100644 --- a/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go +++ b/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go @@ -279,3 +279,35 @@ func (_m *TaskExecutionMetadata) GetTaskExecutionID() core.TaskExecutionID { return r0 } + +type TaskExecutionMetadata_IsInterruptible struct { + *mock.Call +} + +func (_m 
TaskExecutionMetadata_IsInterruptible) Return(_a0 bool) *TaskExecutionMetadata_IsInterruptible { + return &TaskExecutionMetadata_IsInterruptible{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionMetadata) OnIsInterruptible() *TaskExecutionMetadata_IsInterruptible { + c := _m.On("IsInterruptible") + return &TaskExecutionMetadata_IsInterruptible{Call: c} +} + +func (_m *TaskExecutionMetadata) OnIsInterruptibleMatch(matchers ...interface{}) *TaskExecutionMetadata_IsInterruptible { + c := _m.On("IsInterruptible", matchers...) + return &TaskExecutionMetadata_IsInterruptible{Call: c} +} + +// IsInterruptible provides a mock function with given fields: +func (_m *TaskExecutionMetadata) IsInterruptible() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/go/tasks/pluginmachinery/flytek8s/config/config.go b/go/tasks/pluginmachinery/flytek8s/config/config.go index f703120a6b..5aba439c18 100755 --- a/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -43,6 +43,12 @@ type K8sPluginConfig struct { DefaultCPURequest string `json:"default-cpus" pflag:",Defines a default value for cpu for containers if not specified."` // default memory requests for a container DefaultMemoryRequest string `json:"default-memory" pflag:",Defines a default value for memory for containers if not specified."` + // Tolerations for interruptible k8s pods: These tolerations are added to the pods that can tolerate getting evicted from a node. We + // can leverage this for better bin-packing and using low-reliability cheaper machines. 
+ InterruptibleTolerations []v1.Toleration `json:"interruptible-tolerations" pflag:"-,Tolerations to be applied for interruptible pods"` + // Node Selector Labels for interruptible pods: Similar to InterruptibleTolerations, these node selector labels are added for pods that can tolerate + // eviction. + InterruptibleNodeSelector map[string]string `json:"interruptible-node-selector" pflag:"-,Defines a set of node selector labels to add to the interruptible pods."` } // Retrieves the current k8s plugin config or default. diff --git a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go index 8d758237ee..6195da8e43 100755 --- a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go +++ b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go @@ -124,7 +124,8 @@ func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, id pluginsCore.Ta return envVars } -func GetTolerationsForResources(resourceRequirements ...v1.ResourceRequirements) []v1.Toleration { +func GetPodTolerations(interruptible bool, resourceRequirements ...v1.ResourceRequirements) []v1.Toleration { + // 1. Get the tolerations for the resources requested var tolerations []v1.Toleration resourceNames := sets.NewString() for _, resources := range resourceRequirements { @@ -141,5 +142,11 @@ func GetTolerationsForResources(resourceRequirements ...v1.ResourceRequirements) tolerations = append(tolerations, v...) } } + + // 2. Get the tolerations for interruptible pods + if interruptible && len(config.GetK8sPluginConfig().InterruptibleTolerations) > 0 { + tolerations = append(tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...) 
+ } + return tolerations } diff --git a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go index da1f7edcca..6e4d799ac1 100755 --- a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go +++ b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go @@ -171,8 +171,8 @@ func TestGetTolerationsForResources(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ResourceTolerations: tt.setVal})) - if got := GetTolerationsForResources(tt.args.resources); len(got) != len(tt.want) { - t.Errorf("GetTolerationsForResources() = %v, want %v", got, tt.want) + if got := GetPodTolerations(true, tt.args.resources); len(got) != len(tt.want) { + t.Errorf("GetPodTolerations() = %v, want %v", got, tt.want) } else { for _, tol := range tt.want { assert.Contains(t, got, tol) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 2a21ca62e0..589f2c5694 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -11,11 +11,14 @@ import ( v12 "k8s.io/apimachinery/pkg/apis/meta/v1" pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" ) const PodKind = "pod" const OOMKilled = "OOMKilled" +const Interrupted = "Interrupted" +const SIGKILL = 137 func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskReader pluginsCore.TaskReader, inputs io.InputReader, outputPaths io.OutputFilePaths) (*v1.PodSpec, error) { @@ -32,13 +35,24 @@ func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExe containers := []v1.Container{ *c, } + if taskExecutionMetadata.IsInterruptible() && 
len(config.GetK8sPluginConfig().InterruptibleNodeSelector) > 0 { + return &v1.PodSpec{ + // We could specify Scheduler, Affinity, nodename etc + RestartPolicy: v1.RestartPolicyNever, + Containers: containers, + Tolerations: GetPodTolerations(taskExecutionMetadata.IsInterruptible(), c.Resources), + ServiceAccountName: taskExecutionMetadata.GetK8sServiceAccount(), + NodeSelector: config.GetK8sPluginConfig().InterruptibleNodeSelector, + }, nil + } return &v1.PodSpec{ // We could specify Scheduler, Affinity, nodename etc RestartPolicy: v1.RestartPolicyNever, Containers: containers, - Tolerations: GetTolerationsForResources(c.Resources), + Tolerations: GetPodTolerations(taskExecutionMetadata.IsInterruptible(), c.Resources), ServiceAccountName: taskExecutionMetadata.GetK8sServiceAccount(), }, nil + } func BuildPodWithSpec(podSpec *v1.PodSpec) *v1.Pod { @@ -180,9 +194,9 @@ func DemystifySuccess(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCo return pluginsCore.PhaseInfoSuccess(&info), nil } -func ConvertPodFailureToError(status v1.PodStatus) (string, string) { - code := "UnknownError" - message := "Container/Pod failed. No message received from kubernetes." +func ConvertPodFailureToError(status v1.PodStatus) (code, message string) { + code = "UnknownError" + message = "Container/Pod failed. No message received from kubernetes." if len(status.Reason) > 0 { code = status.Reason } @@ -202,7 +216,12 @@ func ConvertPodFailureToError(status v1.PodStatus) (string, string) { if containerState.Terminated != nil { if strings.Contains(c.State.Terminated.Reason, OOMKilled) { code = OOMKilled + } else if containerState.Terminated.ExitCode == SIGKILL { + // in some setups, node termination sends SIGKILL to all the containers running on that node. Capturing and + // tagging that correctly. + code = Interrupted } + message += fmt.Sprintf("\r\nContainer [%v] terminated with exit code (%v). Reason [%v]. 
Message: [%v].", c.Name, containerState.Terminated.ExitCode, diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index e729d22eff..22b3ae363c 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -4,6 +4,9 @@ import ( "context" "testing" + config1 "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/config/viper" + "github.com/lyft/flytestdlib/storage" "github.com/stretchr/testify/mock" @@ -47,7 +50,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements) pluginsCore. to := &pluginsCoreMock.TaskOverrides{} to.On("GetResources").Return(resources) taskExecutionMetadata.On("GetOverrides").Return(to) - + taskExecutionMetadata.On("IsInterruptible").Return(true) return taskExecutionMetadata } @@ -74,6 +77,39 @@ func dummyInputReader() io.InputReader { return inputReader } +func TestToK8sPodIterruptible(t *testing.T) { + ctx := context.TODO() + configAccessor := viper.NewAccessor(config1.Options{ + StrictMode: true, + SearchPaths: []string{"testdata/config.yaml"}, + }) + err := configAccessor.UpdateConfig(context.TODO()) + assert.NoError(t, err) + + op := &pluginsIOMock.OutputFilePaths{} + op.On("GetOutputPrefixPath").Return(storage.DataReference("")) + + x := dummyTaskExecutionMetadata(&v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + }, + }) + + p, err := ToK8sPodSpec(ctx, x, dummyTaskReader(), dummyInputReader(), op) + assert.NoError(t, err) + assert.Equal(t, 2, len(p.Tolerations)) + assert.Equal(t, "x/flyte", p.Tolerations[1].Key) + assert.Equal(t, "interruptible", p.Tolerations[1].Value) + assert.Equal(t, 1, 
len(p.NodeSelector)) + assert.Equal(t, "true", p.NodeSelector["x/interruptible"]) +} + func TestToK8sPod(t *testing.T) { ctx := context.TODO() diff --git a/go/tasks/pluginmachinery/flytek8s/testdata/config.yaml b/go/tasks/pluginmachinery/flytek8s/testdata/config.yaml new file mode 100644 index 0000000000..c47a4a05ac --- /dev/null +++ b/go/tasks/pluginmachinery/flytek8s/testdata/config.yaml @@ -0,0 +1,37 @@ +# Sample plugins config +plugins: + # Set of enabled plugins at root level + enabled-plugins: + - container + # All k8s plugins default configuration + k8s: + default-annotations: + - annotationKey1: annotationValue1 + - annotationKey2: annotationValue2 + default-labels: + - label1: labelValue1 + - label2: labelValue2 + resource-tolerations: + nvidia.com/gpu: + key: flyte/gpu + value: dedicated + operator: Equal + effect: NoSchedule + storage: + - key: storage + value: special + operator: Equal + effect: PreferNoSchedule + interruptible-node-selector: + - x/interruptible: "true" + interruptible-tolerations: + - key: x/flyte + value: interruptible + operator: Equal + effect: NoSchedule + default-env-vars: + - AWS_METADATA_SERVICE_TIMEOUT: 5 + - AWS_METADATA_SERVICE_NUM_ATTEMPTS: 20 + - FLYTE_AWS_ENDPOINT: "http://minio.flyte:9000" + - FLYTE_AWS_ACCESS_KEY_ID: minio + - FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage diff --git a/go/tasks/plugins/hive/test_helpers.go b/go/tasks/plugins/hive/test_helpers.go index 3429f5777c..c97a417dd5 100644 --- a/go/tasks/plugins/hive/test_helpers.go +++ b/go/tasks/plugins/hive/test_helpers.go @@ -68,6 +68,7 @@ func GetMockTaskExecutionMetadata() core.TaskExecutionMetadata { Kind: "node", Name: "blah", }) + taskMetadata.On("IsInterruptible").Return(true) taskMetadata.On("GetK8sServiceAccount").Return("service-account") taskMetadata.On("GetOwnerID").Return(types.NamespacedName{ Namespace: "test-namespace", diff --git a/go/tasks/plugins/k8s/container/container_test.go b/go/tasks/plugins/k8s/container/container_test.go index 
3a4364f8cf..9f7288aadf 100755 --- a/go/tasks/plugins/k8s/container/container_test.go +++ b/go/tasks/plugins/k8s/container/container_test.go @@ -60,7 +60,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. to := &pluginsCoreMock.TaskOverrides{} to.On("GetResources").Return(resources) taskMetadata.On("GetOverrides").Return(to) - + taskMetadata.On("IsInterruptible").Return(true) return taskMetadata } diff --git a/go/tasks/plugins/k8s/sidecar/sidecar.go b/go/tasks/plugins/k8s/sidecar/sidecar.go index f86ddc726e..397c65c487 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -8,6 +8,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" @@ -60,7 +61,11 @@ func validateAndFinalizeContainers( } pod.Spec.Containers = finalizedContainers - pod.Spec.Tolerations = flytek8s.GetTolerationsForResources(resReqs...) + pod.Spec.Tolerations = flytek8s.GetPodTolerations(taskCtx.TaskExecutionMetadata().IsInterruptible(), resReqs...) + if taskCtx.TaskExecutionMetadata().IsInterruptible() && len(config.GetK8sPluginConfig().InterruptibleNodeSelector) > 0 { + pod.Spec.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector + } + return &pod, nil } diff --git a/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/go/tasks/plugins/k8s/sidecar/sidecar_test.go index f6ed1d4543..ff4ca9efad 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar_test.go @@ -64,6 +64,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. 
Kind: "node", Name: "blah", }) + taskMetadata.On("IsInterruptible").Return(true) taskMetadata.On("GetK8sServiceAccount").Return("service-account") taskMetadata.On("GetOwnerID").Return(types.NamespacedName{ Namespace: "test-namespace", diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go index c74cc3cc97..cd23231ab4 100755 --- a/go/tasks/plugins/k8s/spark/spark_test.go +++ b/go/tasks/plugins/k8s/spark/spark_test.go @@ -251,7 +251,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExec Kind: "node", Name: "blah", }) - + taskExecutionMetadata.On("IsInterruptible").Return(true) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) return taskCtx } diff --git a/go/tasks/testdata/config.yaml b/go/tasks/testdata/config.yaml index fa1aef75d8..6b8cad1a9b 100755 --- a/go/tasks/testdata/config.yaml +++ b/go/tasks/testdata/config.yaml @@ -23,6 +23,13 @@ plugins: value: special operator: Equal effect: PreferNoSchedule + interruptible-node-selector: + - x/interruptible: "true" + interruptible-tolerations: + - key: x/flyte + value: interruptible + operator: Equal + effect: NoSchedule default-env-vars: - AWS_METADATA_SERVICE_TIMEOUT: 5 - AWS_METADATA_SERVICE_NUM_ATTEMPTS: 20 diff --git a/tests/end_to_end.go b/tests/end_to_end.go index cd9c64edf2..fe56c4df74 100644 --- a/tests/end_to_end.go +++ b/tests/end_to_end.go @@ -158,6 +158,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i tMeta.OnGetNamespace().Return("fake-development") tMeta.OnGetLabels().Return(map[string]string{}) tMeta.OnGetAnnotations().Return(map[string]string{}) + tMeta.OnIsInterruptible().Return(true) tMeta.OnGetOwnerReference().Return(v12.OwnerReference{}) tMeta.OnGetOwnerID().Return(types.NamespacedName{ Namespace: "fake-development",