Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Jul 20, 2022
1 parent 9492c58 commit c122723
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 144 deletions.
7 changes: 4 additions & 3 deletions go/tasks/plugins/cluster_resource/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package ray
import (
"context"
"fmt"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"
v1 "k8s.io/api/core/v1"
"time"

"sigs.k8s.io/controller-runtime/pkg/client"

Expand Down Expand Up @@ -102,8 +103,8 @@ func (rayClusterResourceHandler) BuildResource(ctx context.Context, taskCtx plug

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
rayCluster.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName
for _, worker := range rayCluster.Spec.WorkerGroupSpecs {
worker.Template.Spec.ServiceAccountName = serviceAccountName
for index, _ := range rayCluster.Spec.WorkerGroupSpecs {
rayCluster.Spec.WorkerGroupSpecs[index].Template.Spec.ServiceAccountName = serviceAccountName
}

return rayCluster, nil
Expand Down
189 changes: 48 additions & 141 deletions go/tasks/plugins/cluster_resource/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@ package ray

import (
"context"
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
corev1 "k8s.io/api/core/v1"
k8sV1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"time"
)

const testImage = "image://"
Expand All @@ -44,6 +43,10 @@ var (
flytek8s.ResourceNvidiaGPU: resource.MustParse("1"),
},
}

clusterName = "testRayCluster"
rayImage = "rayproject/ray:1.8.0"
workerGroupName = "worker-group"
)

func dummyRayCustomObj() *core.RayCluster {
Expand All @@ -56,20 +59,7 @@ func dummyRayCustomObj() *core.RayCluster {
}
}

func dummyTaskTemplate(id string, rayCustomObj *core.RayCluster) *core.TaskTemplate {

rayObjJSON, err := utils.MarshalToString(rayCustomObj)
if err != nil {
panic(err)
}

structObj := structpb.Struct{}

err = jsonpb.UnmarshalString(rayObjJSON, &structObj)
if err != nil {
panic(err)
}

func dummyRayTaskTemplate(id string) *core.TaskTemplate {
return &core.TaskTemplate{
Id: &core.Identifier{Name: id},
Type: "container",
Expand All @@ -80,7 +70,13 @@ func dummyTaskTemplate(id string, rayCustomObj *core.RayCluster) *core.TaskTempl
Env: dummyEnvVars,
},
},
Custom: &structObj,
Resources: map[string]*core.Resource{id: {Value: &core.Resource_Ray{Ray: &core.RayCluster{
Name: clusterName,
ClusterSpec: &core.ClusterSpec{
HeadGroupSpec: &core.HeadGroupSpec{Image: rayImage, ServiceType: "ClusterIP"},
WorkerGroupSpec: []*core.WorkerGroupSpec{{Image: rayImage, GroupName: workerGroupName, Replicas: 3}},
},
}}}},
}
}

Expand Down Expand Up @@ -132,129 +128,40 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
taskExecutionMetadata.OnGetOverrides().Return(resources)
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{RunAs: &core.Identity{K8SServiceAccount: serviceAccount}})
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}

func dummyRayJobResource(rayResourceHandler rayClusterResourceHandler,
workers int32, psReplicas int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *rayv1alpha1.RayCluster {
var jobConditions []commonOp.JobCondition

now := time.Now()

jobCreated := commonOp.JobCondition{
Type: commonOp.JobCreated,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobCreated",
Message: "TensorFlowJob the-job is created.",
LastUpdateTime: v1.Time{
Time: now,
},
LastTransitionTime: v1.Time{
Time: now,
},
}
jobRunningActive := commonOp.JobCondition{
Type: commonOp.JobRunning,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobRunning",
Message: "TensorFlowJob the-job is running.",
LastUpdateTime: v1.Time{
Time: now.Add(time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(time.Minute),
},
}
jobRunningInactive := *jobRunningActive.DeepCopy()
jobRunningInactive.Status = corev1.ConditionFalse
jobSucceeded := commonOp.JobCondition{
Type: commonOp.JobSucceeded,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobSucceeded",
Message: "TensorFlowJob the-job is successfully completed.",
LastUpdateTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
}
jobFailed := commonOp.JobCondition{
Type: commonOp.JobFailed,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobFailed",
Message: "TensorFlowJob the-job is failed.",
LastUpdateTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
}
jobRestarting := commonOp.JobCondition{
Type: commonOp.JobRestarting,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobRestarting",
Message: "TensorFlowJob the-job is restarting because some replica(s) failed.",
LastUpdateTime: v1.Time{
Time: now.Add(3 * time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(3 * time.Minute),
},
}

switch conditionType {
case commonOp.JobCreated:
jobConditions = []commonOp.JobCondition{
jobCreated,
}
case commonOp.JobRunning:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningActive,
}
case commonOp.JobSucceeded:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningInactive,
jobSucceeded,
}
case commonOp.JobFailed:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningInactive,
jobFailed,
}
case commonOp.JobRestarting:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningInactive,
jobFailed,
jobRestarting,
}
}

rayObj := dummyRayCustomObj()
taskTemplate := dummyTaskTemplate("the job", rayObj)
rayResource, err := rayResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
if err != nil {
panic(err)
}
func TestBuildResourceRay(t *testing.T) {
rayResourceHandler := rayClusterResourceHandler{}
taskTemplate := dummyRayTaskTemplate("ray-id")

RayResource, err := rayResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
assert.Nil(t, err)

assert.NotNil(t, RayResource)
ray, ok := RayResource.(*rayv1alpha1.RayCluster)
assert.True(t, ok)
assert.Equal(t, ray.Name, clusterName)

headReplica := int32(1)
assert.Equal(t, ray.Spec.HeadGroupSpec.Replicas, &headReplica)
assert.Equal(t, ray.Spec.HeadGroupSpec.ServiceType, k8sV1.ServiceType("ClusterIP"))
assert.Equal(t, ray.Spec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount)
assert.Equal(t, ray.Spec.HeadGroupSpec.RayStartParams, map[string]string{"node-ip-address": "$MY_POD_IP"})

workerReplica := int32(3)
assert.Equal(t, ray.Spec.WorkerGroupSpecs[0].Replicas, &workerReplica)
assert.Equal(t, ray.Spec.WorkerGroupSpecs[0].MinReplicas, &workerReplica)
assert.Equal(t, ray.Spec.WorkerGroupSpecs[0].MaxReplicas, &workerReplica)
assert.Equal(t, ray.Spec.WorkerGroupSpecs[0].GroupName, workerGroupName)
assert.Equal(t, ray.Spec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount)
assert.Equal(t, ray.Spec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"node-ip-address": "$MY_POD_IP"})
}

return &kubeflowv1.TFJob{
ObjectMeta: v1.ObjectMeta{
Name: jobName,
Namespace: jobNamespace,
},
Spec: resource.(*kubeflowv1.TFJob).Spec,
Status: commonOp.JobStatus{
Conditions: jobConditions,
ReplicaStatuses: nil,
StartTime: nil,
CompletionTime: nil,
LastReconcileTime: nil,
},
}
func TestGetPropertiesRay(t *testing.T) {
rayResourceHandler := rayClusterResourceHandler{}
expected := k8s.PluginProperties{}
assert.Equal(t, expected, rayResourceHandler.GetProperties())
}

0 comments on commit c122723

Please sign in to comment.