Skip to content

Commit

Permalink
Enable pod template and Use copy to construct head/worker in ray plug…
Browse files Browse the repository at this point in the history
…in (flyteorg#349)

* Enable pod template and Use copy to construct head/worker in ray plugin

Signed-off-by: byhsu <[email protected]>

* fix lint

Signed-off-by: byhsu <[email protected]>

* wip

Signed-off-by: byhsu <[email protected]>

* fix test

Signed-off-by: byhsu <[email protected]>

---------

Signed-off-by: byhsu <[email protected]>
Co-authored-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu and ByronHsu authored May 19, 2023
1 parent 9a2bbba commit 2afc441
Showing 1 changed file with 61 additions and 33 deletions.
94 changes: 61 additions & 33 deletions go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand All @@ -22,6 +21,7 @@ import (

v1 "k8s.io/api/core/v1"

flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
"sigs.k8s.io/controller-runtime/pkg/client"
)

Expand All @@ -44,20 +44,35 @@ func (rayJobResourceHandler) GetProperties() k8s.PluginProperties {
func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error())
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error())
} else if taskTemplate == nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "nil task specification")
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification")
}

rayJob := plugins.RayJob{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

container, err := flytek8s.ToK8sContainer(ctx, taskCtx)
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)

if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "Unable to create container spec: [%v]", err.Error())
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

var container v1.Container
found := false
for _, c := range podSpec.Containers {
if c.Name == primaryContainerName {
container = c
found = true
break
}
}

if !found {
	// err is always nil at this point (a non-nil err already returned right after
	// ToK8sPodSpec), so calling err.Error() here would panic with a nil dereference.
	// Report the name of the container that could not be found instead.
	return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", primaryContainerName)
}

headReplicas := int32(1)
Expand All @@ -78,7 +93,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
enableIngress := true
rayClusterSpec := rayv1alpha1.RayClusterSpec{
HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
Template: buildHeadPodTemplate(container, taskCtx),
Template: buildHeadPodTemplate(&container, podSpec, objectMeta, taskCtx),
ServiceType: v1.ServiceType(GetConfig().ServiceType),
Replicas: &headReplicas,
EnableIngress: &enableIngress,
Expand All @@ -88,7 +103,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
workerPodTemplate := buildWorkerPodTemplate(container, taskCtx)
workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx)

minReplicas := spec.Replicas
maxReplicas := spec.Replicas
Expand Down Expand Up @@ -139,18 +154,20 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
Kind: KindRayJob,
APIVersion: rayv1alpha1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
Spec: jobSpec,
ObjectMeta: *objectMeta,
}

return &rayJobObject, nil
}

func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
// They should always be the same, so we can hard-code them here.
primaryContainer := &v1.Container{Name: "ray-head", Image: container.Image}
primaryContainer.Resources = container.Resources
primaryContainer.Env = []v1.EnvVar{
primaryContainer := container.DeepCopy()
primaryContainer.Name = "ray-head"

envs := []v1.EnvVar{
{
Name: "MY_POD_IP",
ValueFrom: &v1.EnvVarSource{
Expand All @@ -160,8 +177,12 @@ func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecu
},
},
}
primaryContainer.Env = append(primaryContainer.Env, container.Env...)
primaryContainer.Ports = []v1.ContainerPort{

primaryContainer.Args = []string{}

primaryContainer.Env = append(primaryContainer.Env, envs...)

ports := []v1.ContainerPort{
{
Name: "redis",
ContainerPort: 6379,
Expand All @@ -175,20 +196,23 @@ func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecu
ContainerPort: 8265,
},
}
pod := &v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
}
flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)

primaryContainer.Ports = append(primaryContainer.Ports, ports...)

headPodSpec := podSpec.DeepCopy()

headPodSpec.Containers = []v1.Container{*primaryContainer}

podTemplateSpec := v1.PodTemplateSpec{
Spec: *pod,
Spec: *headPodSpec,
ObjectMeta: *objectMeta,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
return podTemplateSpec
}

func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
// They should always be the same, so we can hard-code them here.
initContainers := []v1.Container{
Expand All @@ -203,10 +227,12 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
Resources: container.Resources,
},
}
primaryContainer := container.DeepCopy()
primaryContainer.Name = "ray-worker"

primaryContainer := &v1.Container{Name: "ray-worker", Image: container.Image}
primaryContainer.Resources = container.Resources
primaryContainer.Env = []v1.EnvVar{
primaryContainer.Args = []string{}

envs := []v1.EnvVar{
{
Name: "RAY_DISABLE_DOCKER_CPU_WARNING",
Value: "1",
Expand Down Expand Up @@ -268,7 +294,9 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
},
},
}
primaryContainer.Env = append(primaryContainer.Env, container.Env...)

primaryContainer.Env = append(primaryContainer.Env, envs...)

primaryContainer.Lifecycle = &v1.Lifecycle{
PreStop: &v1.LifecycleHandler{
Exec: &v1.ExecAction{
Expand All @@ -279,7 +307,7 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
},
}

primaryContainer.Ports = []v1.ContainerPort{
ports := []v1.ContainerPort{
{
Name: "redis",
ContainerPort: 6379,
Expand All @@ -293,15 +321,15 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
ContainerPort: 8265,
},
}
primaryContainer.Ports = append(primaryContainer.Ports, ports...)

pod := &v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
InitContainers: initContainers,
}
flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)
workerPodSpec := podSpec.DeepCopy()
workerPodSpec.Containers = []v1.Container{*primaryContainer}
workerPodSpec.InitContainers = initContainers

podTemplateSpec := v1.PodTemplateSpec{
Spec: *pod,
Spec: *workerPodSpec,
ObjectMeta: *objectMetadata,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
Expand Down Expand Up @@ -350,7 +378,7 @@ func (rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s
return pluginsCore.PhaseInfoNotReady(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending"), nil
case rayv1alpha1.JobStatusFailed:
reason := fmt.Sprintf("Failed to create Ray job: %s", rayJob.Name)
return pluginsCore.PhaseInfoFailure(errors.TaskFailedWithError, reason, info), nil
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1alpha1.JobStatusSucceeded:
return pluginsCore.PhaseInfoSuccess(info), nil
case rayv1alpha1.JobStatusRunning:
Expand Down

0 comments on commit 2afc441

Please sign in to comment.