Skip to content

Commit

Permalink
fix(ray): Use default svc account if not set in task metadata (#335)
Browse files Browse the repository at this point in the history
## Overview
The default service account is not being used if autoscale is enabled in the ray task.

## Test Plan
```python
config = RayJobConfig(
    worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10)],
    runtime_env={"pip": ["numpy"]},
    enable_autoscaling=True,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=20,
)
```

## Rollout Plan (if applicable)
staging -> canary -> production

## Upstream Changes
Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F).
- [x] To be upstreamed to OSS

## Issue
https://linear.app/unionai/issue/DX-791/use-default-svc-account-if-not-set-in-task-metadata
Customer Issue:
https://app.usepylon.com/issues?conversationID=df7bf253-beba-4c3f-9f97-5fc68daddab2

## Checklist
* [x] Added tests
* [x] Ran a deploy dry run and shared the terraform plan
* [ ] Added logging and metrics
* [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list)
* [ ] Updated documentation
  • Loading branch information
pingsutw authored Jun 21, 2024
1 parent 0f50412 commit 8db73eb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
3 changes: 3 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra
}

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
if len(serviceAccountName) == 0 {
serviceAccountName = cfg.ServiceAccount
}

rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName
for index := range rayClusterSpec.WorkerGroupSpecs {
Expand Down
22 changes: 16 additions & 6 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func dummyRayTaskTemplate(id string, rayJob *plugins.RayJob) *core.TaskTemplate
}
}

func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext {
func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage, serviceAccount string) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}
inputReader := &pluginIOMocks.InputReader{}
inputReader.OnGetInputPrefixPath().Return("/input/prefix")
Expand Down Expand Up @@ -176,7 +176,8 @@ func TestBuildResourceRay(t *testing.T) {
err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
assert.Nil(t, err)

RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, ""))
rayCtx := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)
RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), rayCtx)
assert.Nil(t, err)

assert.NotNil(t, RayResource)
Expand Down Expand Up @@ -207,6 +208,15 @@ func TestBuildResourceRay(t *testing.T) {
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)

// Make sure the default service account is being used if SA is not provided in the task context
rayCtx = dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", "")
RayResource, err = rayJobResourceHandler.BuildResource(context.TODO(), rayCtx)
assert.Nil(t, err)
assert.NotNil(t, RayResource)
ray, ok = RayResource.(*rayv1.RayJob)
assert.True(t, ok)
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, GetConfig().ServiceAccount)
}

func TestBuildResourceRayContainerImage(t *testing.T) {
Expand Down Expand Up @@ -240,7 +250,7 @@ func TestBuildResourceRayContainerImage(t *testing.T) {
for _, f := range fixtures {
t.Run(f.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj())
taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride)
taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, serviceAccount)
rayJobResourceHandler := rayJobResourceHandler{}
r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down Expand Up @@ -371,7 +381,7 @@ func TestBuildResourceRayExtendedResources(t *testing.T) {
t.Run(p.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
taskTemplate.ExtendedResources = p.extendedResourcesBase
taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "")
taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "", serviceAccount)
rayJobResourceHandler := rayJobResourceHandler{}
r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down Expand Up @@ -430,7 +440,7 @@ func TestDefaultStartParameters(t *testing.T) {
err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
assert.Nil(t, err)

RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, ""))
RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount))
assert.Nil(t, err)

assert.NotNil(t, RayResource)
Expand Down Expand Up @@ -637,7 +647,7 @@ func TestInjectLogsSidecar(t *testing.T) {
assert.NoError(t, SetConfig(&Config{
LogsSidecar: p.logsSidecarCfg,
}))
taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "")
taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "", serviceAccount)
rayJobResourceHandler := rayJobResourceHandler{}
r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down

0 comments on commit 8db73eb

Please sign in to comment.