Skip to content

Commit

Permalink
Don't override sidecar tolerations (flyteorg#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Jul 20, 2020
1 parent ab09d3e commit 603aff9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
10 changes: 7 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type sidecarResourceHandler struct{}

// This method handles templatizing primary container input args, env variables and adds a GPU toleration to the pod
// spec if necessary.
func validateAndFinalizeContainers(
func validateAndFinalizePod(
ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, primaryContainerName string, pod k8sv1.Pod) (*k8sv1.Pod, error) {
var hasPrimaryContainer bool

Expand Down Expand Up @@ -61,7 +61,11 @@ func validateAndFinalizeContainers(

}
pod.Spec.Containers = finalizedContainers
pod.Spec.Tolerations = flytek8s.GetPodTolerations(taskCtx.TaskExecutionMetadata().IsInterruptible(), resReqs...)
if pod.Spec.Tolerations == nil {
pod.Spec.Tolerations = make([]k8sv1.Toleration, 0)
}
pod.Spec.Tolerations = append(
flytek8s.GetPodTolerations(taskCtx.TaskExecutionMetadata().IsInterruptible(), resReqs...), pod.Spec.Tolerations...)
if taskCtx.TaskExecutionMetadata().IsInterruptible() && len(config.GetK8sPluginConfig().InterruptibleNodeSelector) > 0 {
pod.Spec.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector
}
Expand Down Expand Up @@ -90,7 +94,7 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins
// We want to Also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()

pod, err = validateAndFinalizeContainers(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod)
pod, err = validateAndFinalizePod(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod)
if err != nil {
return nil, err
}
Expand Down
23 changes: 23 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,29 @@ func TestBuildSidecarResource(t *testing.T) {
actualGpuLimit, ok := res.(*v1.Pod).Spec.Containers[0].Resources.Limits[ResourceNvidiaGPU]
assert.True(t, ok)
assert.True(t, expectedGpuLimit.Equal(actualGpuLimit))

// Assert volumes & volume mounts are preserved
assert.Len(t, res.(*v1.Pod).Spec.Volumes, 1)
assert.Equal(t, "dshm", res.(*v1.Pod).Spec.Volumes[0].Name)

assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1)
assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name)

// Assert user-specified tolerations don't get overridden
assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 2)
expectedTolerations := []v1.Toleration{
{
Key: "flyte/gpu",
Operator: "Equal",
Value: "dedicated",
Effect: "NoSchedule",
},
{
Key: "my toleration key",
Value: "my toleration value",
},
}
assert.EqualValues(t, expectedTolerations, res.(*v1.Pod).Spec.Tolerations)
}

func TestBuildSidecarResourceMissingPrimary(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@
}
},
"name": "dshm"
}]
}],
"tolerations": [{
"key": "my toleration key",
"value": "my toleration value"
}]
},
"primaryContainerName": "a container"
}

0 comments on commit 603aff9

Please sign in to comment.