diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go index 9669e11f8b..18b11f27fd 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -29,7 +29,7 @@ type sidecarResourceHandler struct{} // This method handles templatizing primary container input args, env variables and adds a GPU toleration to the pod // spec if necessary. -func validateAndFinalizeContainers( +func validateAndFinalizePod( ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string, pod k8sv1.Pod) (*k8sv1.Pod, error) { var hasPrimaryContainer bool @@ -61,7 +61,11 @@ func validateAndFinalizeContainers( } pod.Spec.Containers = finalizedContainers - pod.Spec.Tolerations = flytek8s.GetPodTolerations(taskCtx.TaskExecutionMetadata().IsInterruptible(), resReqs...) + if pod.Spec.Tolerations == nil { + pod.Spec.Tolerations = make([]k8sv1.Toleration, 0) + } + pod.Spec.Tolerations = append( + flytek8s.GetPodTolerations(taskCtx.TaskExecutionMetadata().IsInterruptible(), resReqs...), pod.Spec.Tolerations...) if taskCtx.TaskExecutionMetadata().IsInterruptible() && len(config.GetK8sPluginConfig().InterruptibleNodeSelector) > 0 { pod.Spec.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector } @@ -90,7 +94,7 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins // We want to Also update the serviceAccount to the serviceaccount of the workflow pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount() - pod, err = validateAndFinalizeContainers(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod) + pod, err = validateAndFinalizePod(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod) if err != nil { return nil, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go index 8ddc2cccc4..57fc71d819 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go @@ -170,6 +170,29 @@ func TestBuildSidecarResource(t *testing.T) { actualGpuLimit, ok := res.(*v1.Pod).Spec.Containers[0].Resources.Limits[ResourceNvidiaGPU] assert.True(t, ok) assert.True(t, expectedGpuLimit.Equal(actualGpuLimit)) + + // Assert volumes & volume mounts are preserved + assert.Len(t, res.(*v1.Pod).Spec.Volumes, 1) + assert.Equal(t, "dshm", res.(*v1.Pod).Spec.Volumes[0].Name) + + assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1) + assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name) + + // Assert user-specified tolerations don't get overridden + assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 2) + expectedTolerations := []v1.Toleration{ + { + Key: "flyte/gpu", + Operator: "Equal", + Value: "dedicated", + Effect: "NoSchedule", + }, + { + Key: "my toleration key", + Value: "my toleration value", + }, + } + assert.EqualValues(t, expectedTolerations, res.(*v1.Pod).Spec.Tolerations) } func TestBuildSidecarResourceMissingPrimary(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/testdata/sidecar_custom b/flyteplugins/go/tasks/plugins/k8s/sidecar/testdata/sidecar_custom index 499fc765ef..b2bf2c574c 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/testdata/sidecar_custom +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/testdata/sidecar_custom @@ -56,7 +56,11 @@ } }, "name": "dshm" - }] + }], + "tolerations": [{ + "key": "my toleration key", + "value": "my toleration value" + }] }, "primaryContainerName": "a container" }