diff --git a/go/tasks/plugins/k8s/ray/ray.go b/go/tasks/plugins/k8s/ray/ray.go index 614d1af9fe..f3fcaff38b 100644 --- a/go/tasks/plugins/k8s/ray/ray.go +++ b/go/tasks/plugins/k8s/ray/ray.go @@ -9,7 +9,6 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" - "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/logs" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -22,6 +21,7 @@ import ( v1 "k8s.io/api/core/v1" + flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -44,20 +44,35 @@ func (rayJobResourceHandler) GetProperties() k8s.PluginProperties { func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) } else if taskTemplate == nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "nil task specification") + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } rayJob := plugins.RayJob{} err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob) if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - container, err := flytek8s.ToK8sContainer(ctx, taskCtx) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "Unable to create container spec: [%v]", err.Error()) + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) + } + + var container v1.Container + found := false + for _, c := range podSpec.Containers { + if c.Name == primaryContainerName { + container = c + found = true + break + } + } + + if !found { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", err.Error()) } headReplicas := int32(1) @@ -78,7 +93,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC enableIngress := true rayClusterSpec := rayv1alpha1.RayClusterSpec{ HeadGroupSpec: rayv1alpha1.HeadGroupSpec{ - Template: buildHeadPodTemplate(container, taskCtx), + Template: buildHeadPodTemplate(&container, podSpec, objectMeta, taskCtx), ServiceType: v1.ServiceType(GetConfig().ServiceType), Replicas: &headReplicas, EnableIngress: &enableIngress, @@ -88,7 +103,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } for _, spec := range rayJob.RayCluster.WorkerGroupSpec { - workerPodTemplate := buildWorkerPodTemplate(container, taskCtx) + workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx) minReplicas := spec.Replicas maxReplicas := spec.Replicas @@ -139,18 +154,20 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC Kind: KindRayJob, APIVersion: rayv1alpha1.SchemeGroupVersion.String(), }, - Spec: jobSpec, + Spec: jobSpec, + ObjectMeta: *objectMeta, } return &rayJobObject, nil } -func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec { +func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec { // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97 // They should always be the same, so we could hard code here. - primaryContainer := &v1.Container{Name: "ray-head", Image: container.Image} - primaryContainer.Resources = container.Resources - primaryContainer.Env = []v1.EnvVar{ + primaryContainer := container.DeepCopy() + primaryContainer.Name = "ray-head" + + envs := []v1.EnvVar{ { Name: "MY_POD_IP", ValueFrom: &v1.EnvVarSource{ @@ -160,8 +177,12 @@ func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecu }, }, } - primaryContainer.Env = append(primaryContainer.Env, container.Env...) - primaryContainer.Ports = []v1.ContainerPort{ + + primaryContainer.Args = []string{} + + primaryContainer.Env = append(primaryContainer.Env, envs...) + + ports := []v1.ContainerPort{ { Name: "redis", ContainerPort: 6379, @@ -175,20 +196,23 @@ func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecu ContainerPort: 8265, }, } - pod := &v1.PodSpec{ - Containers: []v1.Container{*primaryContainer}, - } - flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod) + + primaryContainer.Ports = append(primaryContainer.Ports, ports...) + + headPodSpec := podSpec.DeepCopy() + + headPodSpec.Containers = []v1.Container{*primaryContainer} podTemplateSpec := v1.PodTemplateSpec{ - Spec: *pod, + Spec: *headPodSpec, + ObjectMeta: *objectMeta, } podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) return podTemplateSpec } -func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec { +func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec { // Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185 // They should always be the same, so we could hard code here. initContainers := []v1.Container{ @@ -203,10 +227,12 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe Resources: container.Resources, }, } + primaryContainer := container.DeepCopy() + primaryContainer.Name = "ray-worker" - primaryContainer := &v1.Container{Name: "ray-worker", Image: container.Image} - primaryContainer.Resources = container.Resources - primaryContainer.Env = []v1.EnvVar{ + primaryContainer.Args = []string{} + + envs := []v1.EnvVar{ { Name: "RAY_DISABLE_DOCKER_CPU_WARNING", Value: "1", @@ -268,7 +294,9 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe }, }, } - primaryContainer.Env = append(primaryContainer.Env, container.Env...) + + primaryContainer.Env = append(primaryContainer.Env, envs...) + primaryContainer.Lifecycle = &v1.Lifecycle{ PreStop: &v1.LifecycleHandler{ Exec: &v1.ExecAction{ @@ -279,7 +307,7 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe }, } - primaryContainer.Ports = []v1.ContainerPort{ + ports := []v1.ContainerPort{ { Name: "redis", ContainerPort: 6379, @@ -293,15 +321,15 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe ContainerPort: 8265, }, } + primaryContainer.Ports = append(primaryContainer.Ports, ports...) - pod := &v1.PodSpec{ - Containers: []v1.Container{*primaryContainer}, - InitContainers: initContainers, - } - flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod) + workerPodSpec := podSpec.DeepCopy() + workerPodSpec.Containers = []v1.Container{*primaryContainer} + workerPodSpec.InitContainers = initContainers podTemplateSpec := v1.PodTemplateSpec{ - Spec: *pod, + Spec: *workerPodSpec, + ObjectMeta: *objectMetadata, } podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))) podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()))) @@ -350,7 +378,7 @@ func (rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s return pluginsCore.PhaseInfoNotReady(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending"), nil case rayv1alpha1.JobStatusFailed: reason := fmt.Sprintf("Failed to create Ray job: %s", rayJob.Name) - return pluginsCore.PhaseInfoFailure(errors.TaskFailedWithError, reason, info), nil + return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil case rayv1alpha1.JobStatusSucceeded: return pluginsCore.PhaseInfoSuccess(info), nil case rayv1alpha1.JobStatusRunning: