Skip to content

Commit

Permalink
Add ray test for container image overriding
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 committed Feb 9, 2024
1 parent 55b738f commit 7f934cb
Showing 1 changed file with 63 additions and 6 deletions.
69 changes: 63 additions & 6 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func dummyRayTaskTemplate(id string, rayJob *plugins.RayJob) *core.TaskTemplate
}
}

func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext {
func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}
inputReader := &pluginIOMocks.InputReader{}
inputReader.OnGetInputPrefixPath().Return("/input/prefix")
Expand Down Expand Up @@ -140,7 +140,7 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso
overrides := &mocks.TaskOverrides{}
overrides.OnGetResources().Return(resources)
overrides.OnGetExtendedResources().Return(extendedResources)
overrides.OnGetContainerImage().Return("")
overrides.OnGetContainerImage().Return(containerImage)

taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
Expand Down Expand Up @@ -175,7 +175,7 @@ func TestBuildResourceRay(t *testing.T) {
err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
assert.Nil(t, err)

RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil))
RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, ""))
assert.Nil(t, err)

assert.NotNil(t, RayResource)
Expand Down Expand Up @@ -210,6 +210,63 @@ func TestBuildResourceRay(t *testing.T) {
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
}

func TestBuildResourceRayContainerImage(t *testing.T) {
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{}))

fixtures := []struct {
name string
resources *corev1.ResourceRequirements
containerImageOverride string
}{
{
"without overrides",
&corev1.ResourceRequirements{
Limits: corev1.ResourceList{
flytek8s.ResourceNvidiaGPU: resource.MustParse("1"),
},
},
"",
},
{
"with overrides",
&corev1.ResourceRequirements{
Limits: corev1.ResourceList{
flytek8s.ResourceNvidiaGPU: resource.MustParse("1"),
},
},
"container-image-override",
},
}

for _, f := range fixtures {
t.Run(f.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride)
rayJobResourceHandler := rayJobResourceHandler{}
r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
rayJob, ok := r.(*rayv1alpha1.RayJob)
assert.True(t, ok)

var expectedContainerImage string
if len(f.containerImageOverride) > 0 {
expectedContainerImage = f.containerImageOverride
} else {
expectedContainerImage = testImage
}

// Head node
headNodeSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec
assert.Equal(t, expectedContainerImage, headNodeSpec.Containers[0].Image)

// Worker node
workerNodeSpec := rayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec
assert.Equal(t, expectedContainerImage, workerNodeSpec.Containers[0].Image)
})
}
}

func TestBuildResourceRayExtendedResources(t *testing.T) {
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
GpuDeviceNodeLabel: "gpu-node-label",
Expand Down Expand Up @@ -315,7 +372,7 @@ func TestBuildResourceRayExtendedResources(t *testing.T) {
t.Run(p.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
taskTemplate.ExtendedResources = p.extendedResourcesBase
taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride)
taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "")
rayJobResourceHandler := rayJobResourceHandler{}
r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down Expand Up @@ -374,7 +431,7 @@ func TestDefaultStartParameters(t *testing.T) {
err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
assert.Nil(t, err)

RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil))
RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, ""))
assert.Nil(t, err)

assert.NotNil(t, RayResource)
Expand Down Expand Up @@ -582,7 +639,7 @@ func TestInjectLogsSidecar(t *testing.T) {
assert.NoError(t, SetConfig(&Config{
LogsSidecar: p.logsSidecarCfg,
}))
taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil)
taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "")
rayJobResourceHandler := rayJobResourceHandler{}
r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down

0 comments on commit 7f934cb

Please sign in to comment.