From 8db73eb641905b8639d00f3d1336e531a4bb3a51 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 22 Jun 2024 04:26:53 +0800 Subject: [PATCH] fix(ray): Use default svc account if not set in task metadata (#335) ## Overview The default service account is not being used if autoscale is enabled in the ray task. ## Test Plan ```python config = RayJobConfig( worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10)], runtime_env={"pip": ["numpy"]}, enable_autoscaling=True, shutdown_after_job_finishes=True, ttl_seconds_after_finished=20, ) ``` ## Rollout Plan (if applicable) staging -> canary -> production ## Upstream Changes Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F). - [x] To be upstreamed to OSS ## Issue https://linear.app/unionai/issue/DX-791/use-default-svc-account-if-not-set-in-task-metadata Customer Issue: https://app.usepylon.com/issues?conversationID=df7bf253-beba-4c3f-9f97-5fc68daddab2 ## Checklist * [x] Added tests * [x] Ran a deploy dry run and shared the terraform plan * [ ] Added logging and metrics * [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list) * [ ] Updated documentation --- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 3 +++ .../go/tasks/plugins/k8s/ray/ray_test.go | 22 ++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index ff0cfc6cd32..e39436002c0 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -190,6 +190,9 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.Ra } serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + if len(serviceAccountName) == 0 { + serviceAccountName = cfg.ServiceAccount + } rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName for index := range rayClusterSpec.WorkerGroupSpecs { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 253bfe8980b..acb8247f578 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -105,7 +105,7 @@ func dummyRayTaskTemplate(id string, rayJob *plugins.RayJob) *core.TaskTemplate } } -func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { +func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage, serviceAccount string) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -176,7 +176,8 @@ func TestBuildResourceRay(t *testing.T) { err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) assert.Nil(t, err) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "")) + rayCtx := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) assert.Nil(t, err) assert.NotNil(t, RayResource) @@ -207,6 +208,15 @@ func TestBuildResourceRay(t *testing.T) { assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"}) assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"}) assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration) + + // Make sure the default service account is being used if SA is not provided in the task context + rayCtx = dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", "") + RayResource, err = rayJobResourceHandler.BuildResource(context.TODO(), rayCtx) + assert.Nil(t, err) + assert.NotNil(t, RayResource) + ray, ok = RayResource.(*rayv1.RayJob) + assert.True(t, ok) + assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, GetConfig().ServiceAccount) } func TestBuildResourceRayContainerImage(t *testing.T) { @@ -240,7 +250,7 @@ func TestBuildResourceRayContainerImage(t *testing.T) { for _, f := range fixtures { t.Run(f.name, func(t *testing.T) { taskTemplate := dummyRayTaskTemplate("id", dummyRayCustomObj()) - taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride) + taskContext := dummyRayTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, serviceAccount) rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -371,7 +381,7 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { t.Run(p.name, func(t *testing.T) { taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) taskTemplate.ExtendedResources = p.extendedResourcesBase - taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "") + taskContext := dummyRayTaskContext(taskTemplate, p.resources, p.extendedResourcesOverride, "", serviceAccount) rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -430,7 +440,7 @@ func TestDefaultStartParameters(t *testing.T) { err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration}) assert.Nil(t, err) - RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "")) + RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)) assert.Nil(t, err) assert.NotNil(t, RayResource) @@ -637,7 +647,7 @@ func TestInjectLogsSidecar(t *testing.T) { assert.NoError(t, SetConfig(&Config{ LogsSidecar: p.logsSidecarCfg, })) - taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "") + taskContext := dummyRayTaskContext(&p.taskTemplate, resourceRequirements, nil, "", serviceAccount) rayJobResourceHandler := rayJobResourceHandler{} r, err := rayJobResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err)