Skip to content

Commit

Permalink
Leverage KubeRay v1 instead of v1alpha1 for resources (#4818)
Browse files Browse the repository at this point in the history
* initial

* Clean up

* init

* Clean up

* Add TestGetEventInfo_LogTemplatesV1

* Add more tests

* Fix tests

* Remove dupe

* Fix lint

* add comment

Signed-off-by: peterghaddad <[email protected]>

---------

Signed-off-by: peterghaddad <[email protected]>
Co-authored-by: Neil <[email protected]>
  • Loading branch information
peterghaddad and neilisaur authored Feb 13, 2024
1 parent fe1204c commit d6747c1
Show file tree
Hide file tree
Showing 3 changed files with 438 additions and 14 deletions.
2 changes: 2 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var (
IncludeDashboard: true,
DashboardHost: "0.0.0.0",
EnableUsageStats: false,
KubeRayCrdVersion: "v1alpha1",
Defaults: DefaultConfig{
HeadNode: NodeConfig{
StartParameters: map[string]string{
Expand Down Expand Up @@ -85,6 +86,7 @@ type Config struct {
DashboardURLTemplate *tasklog.TemplateLogPlugin `json:"dashboardURLTemplate" pflag:"-,Template for URL of Ray dashboard running on a head node."`
Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"`
EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"`
KubeRayCrdVersion string `json:"kubeRayCrdVersion" pflag:",Version of the Ray CRD to use when creating RayClusters or RayJobs."`
}

type DefaultConfig struct {
Expand Down
254 changes: 240 additions & 14 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"time"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -28,14 +29,15 @@ import (
)

const (
rayStateMountPath = "/tmp/ray"
defaultRayStateVolName = "system-ray-state"
rayTaskType = "ray"
KindRayJob = "RayJob"
IncludeDashboard = "include-dashboard"
NodeIPAddress = "node-ip-address"
DashboardHost = "dashboard-host"
DisableUsageStatsStartParameter = "disable-usage-stats"
rayStateMountPath = "/tmp/ray"
defaultRayStateVolName = "system-ray-state"
rayTaskType = "ray"
KindRayJob = "RayJob"
IncludeDashboard = "include-dashboard"
NodeIPAddress = "node-ip-address"
DashboardHost = "dashboard-host"
DisableUsageStatsStartParameter = "disable-usage-stats"
DisableUsageStatsStartParameterVal = "true"
)

var logTemplateRegexes = struct {
Expand All @@ -52,7 +54,7 @@ func (rayJobResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

// BuildResource Creates a new ray job resource.
// BuildResource Creates a new ray job resource for v1 or v1alpha1.
func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
Expand Down Expand Up @@ -109,11 +111,22 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

if _, exists := headNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
headNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
headNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal
}

enableIngress := true
headPodSpec := podSpec.DeepCopy()

if cfg.KubeRayCrdVersion == "v1" {
return constructV1Job(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headReplicas, headNodeRayStartParams, primaryContainerIdx, *primaryContainer), nil
}

return constructV1Alpha1Job(taskCtx, rayJob, objectMeta, *podSpec, headPodSpec, headReplicas, headNodeRayStartParams, primaryContainerIdx, *primaryContainer), nil

}

func constructV1Alpha1Job(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headReplicas int32, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) *rayv1alpha1.RayJob {
enableIngress := true
cfg := GetConfig()
rayClusterSpec := rayv1alpha1.RayClusterSpec{
HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
Template: buildHeadPodTemplate(
Expand Down Expand Up @@ -152,7 +165,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
workerNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal
}

minReplicas := spec.MinReplicas
Expand Down Expand Up @@ -198,16 +211,111 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
RuntimeEnv: rayJob.RuntimeEnv,
}

rayJobObject := rayv1alpha1.RayJob{
return &rayv1alpha1.RayJob{
TypeMeta: metav1.TypeMeta{
Kind: KindRayJob,
APIVersion: rayv1alpha1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
ObjectMeta: *objectMeta,
}
}

func constructV1Job(taskCtx pluginsCore.TaskExecutionContext, rayJob plugins.RayJob, objectMeta *metav1.ObjectMeta, podSpec v1.PodSpec, headPodSpec *v1.PodSpec, headReplicas int32, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) *rayv1.RayJob {
enableIngress := true
cfg := GetConfig()
rayClusterSpec := rayv1.RayClusterSpec{
HeadGroupSpec: rayv1.HeadGroupSpec{
Template: buildHeadPodTemplate(
&headPodSpec.Containers[primaryContainerIdx],
headPodSpec,
objectMeta,
taskCtx,
),
ServiceType: v1.ServiceType(cfg.ServiceType),
Replicas: &headReplicas,
EnableIngress: &enableIngress,
RayStartParams: headNodeRayStartParams,
},
WorkerGroupSpecs: []rayv1.WorkerGroupSpec{},
EnableInTreeAutoscaling: &rayJob.RayCluster.EnableAutoscaling,
}

for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
workerPodSpec := podSpec.DeepCopy()
workerPodTemplate := buildWorkerPodTemplate(
&workerPodSpec.Containers[primaryContainerIdx],
workerPodSpec,
objectMeta,
taskCtx,
)

workerNodeRayStartParams := make(map[string]string)
if spec.RayStartParams != nil {
workerNodeRayStartParams = spec.RayStartParams
} else if workerNode := cfg.Defaults.WorkerNode; len(workerNode.StartParameters) > 0 {
workerNodeRayStartParams = workerNode.StartParameters
}

if _, exist := workerNodeRayStartParams[NodeIPAddress]; !exist {
workerNodeRayStartParams[NodeIPAddress] = cfg.Defaults.WorkerNode.IPAddress
}

if _, exists := workerNodeRayStartParams[DisableUsageStatsStartParameter]; !exists && !cfg.EnableUsageStats {
workerNodeRayStartParams[DisableUsageStatsStartParameter] = DisableUsageStatsStartParameterVal
}

return &rayJobObject, nil
minReplicas := spec.MinReplicas
if minReplicas > spec.Replicas {
minReplicas = spec.Replicas
}
maxReplicas := spec.MaxReplicas
if maxReplicas < spec.Replicas {
maxReplicas = spec.Replicas
}

workerNodeSpec := rayv1.WorkerGroupSpec{
GroupName: spec.GroupName,
MinReplicas: &minReplicas,
MaxReplicas: &maxReplicas,
Replicas: &spec.Replicas,
RayStartParams: workerNodeRayStartParams,
Template: workerPodTemplate,
}

rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec)
}

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName
for index := range rayClusterSpec.WorkerGroupSpecs {
rayClusterSpec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName
}

shutdownAfterJobFinishes := cfg.ShutdownAfterJobFinishes
ttlSecondsAfterFinished := &cfg.TTLSecondsAfterFinished
if rayJob.ShutdownAfterJobFinishes {
shutdownAfterJobFinishes = true
ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished
}

jobSpec := rayv1.RayJobSpec{
RayClusterSpec: &rayClusterSpec,
Entrypoint: strings.Join(primaryContainer.Args, " "),
ShutdownAfterJobFinishes: shutdownAfterJobFinishes,
TTLSecondsAfterFinished: ttlSecondsAfterFinished,
RuntimeEnv: rayJob.RuntimeEnv,
}

return &rayv1.RayJob{
TypeMeta: metav1.TypeMeta{
Kind: KindRayJob,
APIVersion: rayv1alpha1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
ObjectMeta: *objectMeta,
}
}

func injectLogsSidecar(primaryContainer *v1.Container, podSpec *v1.PodSpec) {
Expand Down Expand Up @@ -503,7 +611,125 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon
return &pluginsCore.TaskInfo{Logs: taskLogs}, nil
}

func getEventInfoForRayJobV1(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob) (*pluginsCore.TaskInfo, error) {
logPlugin, err := logs.InitializeLogPlugins(&logConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err)
}

var taskLogs []*core.TaskLog

taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
input := tasklog.Input{
Namespace: rayJob.Namespace,
TaskExecutionID: taskExecID,
ExtraTemplateVars: []tasklog.TemplateVar{},
}
if rayJob.Status.JobId != "" {
input.ExtraTemplateVars = append(
input.ExtraTemplateVars,
tasklog.TemplateVar{
Regex: logTemplateRegexes.RayJobID,
Value: rayJob.Status.JobId,
},
)
}
if rayJob.Status.RayClusterName != "" {
input.ExtraTemplateVars = append(
input.ExtraTemplateVars,
tasklog.TemplateVar{
Regex: logTemplateRegexes.RayClusterName,
Value: rayJob.Status.RayClusterName,
},
)
}

// TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs
// RayJob CRD does not include the name of the worker or head pod for now
logOutput, err := logPlugin.GetTaskLogs(input)
if err != nil {
return nil, fmt.Errorf("failed to generate task logs. Error: %w", err)
}
taskLogs = append(taskLogs, logOutput.TaskLogs...)

// Handling for Ray Dashboard
dashboardURLTemplate := GetConfig().DashboardURLTemplate
if dashboardURLTemplate != nil &&
rayJob.Status.DashboardURL != "" &&
rayJob.Status.JobStatus == rayv1.JobStatusRunning {
dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input)
if err != nil {
return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err)
}
taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...)
}

return &pluginsCore.TaskInfo{Logs: taskLogs}, nil
}

func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
crdVersion := GetConfig().KubeRayCrdVersion
if crdVersion == "v1" {
return plugin.GetTaskPhaseV1(ctx, pluginContext, resource)
}

return plugin.GetTaskPhaseV1Alpha1(ctx, pluginContext, resource)
}

func (plugin rayJobResourceHandler) GetTaskPhaseV1(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
rayJob := resource.(*rayv1.RayJob)
info, err := getEventInfoForRayJobV1(GetConfig().Logs, pluginContext, rayJob)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

if len(rayJob.Status.JobDeploymentStatus) == 0 {
return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling"), nil
}

// KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster
switch rayJob.Status.JobDeploymentStatus {
case rayv1.JobDeploymentStatusInitializing:
return pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil
case rayv1.JobDeploymentStatusFailedToGetOrCreateRayCluster:
reason := fmt.Sprintf("Failed to create Ray cluster %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1.JobDeploymentStatusFailedJobDeploy:
reason := fmt.Sprintf("Failed to submit Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
// JobDeploymentStatusSuspended is used when the suspend flag is set in rayJob. The suspend flag allows the temporary suspension of a Job's execution, which can be resumed later.
// Certain versions of KubeRay use a K8s job to submit a Ray job to the Ray cluster. JobDeploymentStatusWaitForK8sJob indicates that the K8s job is under creation.
case rayv1.JobDeploymentStatusWaitForDashboard, rayv1.JobDeploymentStatusFailedToGetJobStatus, rayv1.JobDeploymentStatusWaitForDashboardReady, rayv1.JobDeploymentStatusWaitForK8sJob, rayv1.JobDeploymentStatusSuspended:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
case rayv1.JobDeploymentStatusRunning, rayv1.JobDeploymentStatusComplete:
switch rayJob.Status.JobStatus {
case rayv1.JobStatusFailed:
reason := fmt.Sprintf("Failed to run Ray job %s with error: %s", rayJob.Name, rayJob.Status.Message)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1.JobStatusSucceeded:
return pluginsCore.PhaseInfoSuccess(info), nil
// JobStatusStopped can occur when the suspend flag is set in rayJob.
case rayv1.JobStatusPending, rayv1.JobStatusStopped:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
case rayv1.JobStatusRunning:
phaseInfo := pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info)
if len(info.Logs) > 0 {
phaseInfo = phaseInfo.WithVersion(pluginsCore.DefaultPhaseVersion + 1)
}
return phaseInfo, nil
default:
// We already handle all known job status, so this should never happen unless a future version of ray
// introduced a new job status.
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job status: %s", rayJob.Status.JobStatus)
}
default:
// We already handle all known deployment status, so this should never happen unless a future version of ray
// introduced a new job status.
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus)
}
}

func (plugin rayJobResourceHandler) GetTaskPhaseV1Alpha1(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
rayJob := resource.(*rayv1alpha1.RayJob)
info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob)
if err != nil {
Expand Down
Loading

0 comments on commit d6747c1

Please sign in to comment.