Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Make generate
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Jul 18, 2022
1 parent a2cc4d9 commit 9492c58
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 19 deletions.
4 changes: 4 additions & 0 deletions go/tasks/pluginmachinery/core/exec_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ type TaskExecutionContext interface {
// Returns a reader that retrieves previously stored plugin internal state. the state itself is immutable
PluginStateReader() PluginStateReader

ResourcePluginStateReader() PluginStateReader

// Returns a TaskReader, to retrieve task details
TaskReader() TaskReader

Expand All @@ -59,6 +61,8 @@ type TaskExecutionContext interface {
// These mutation will be visible in the next round
PluginStateWriter() PluginStateWriter

ResourcePluginStateWriter() PluginStateWriter

// Get a handle to catalog client
Catalog() catalog.AsyncClient

Expand Down
68 changes: 68 additions & 0 deletions go/tasks/pluginmachinery/core/mocks/task_execution_context.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 9 additions & 8 deletions go/tasks/pluginmachinery/core/phase_enumer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

260 changes: 260 additions & 0 deletions go/tasks/plugins/cluster_resource/ray/ray_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
package ray

import (
"context"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
"github.com/stretchr/testify/mock"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"time"
)

const testImage = "image://"
const serviceAccount = "ray_sa"

var (
dummyEnvVars = []*core.KeyValuePair{
{Key: "Env_Var", Value: "Env_Val"},
}

testArgs = []string{
"test-args",
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
corev1.ResourceMemory: resource.MustParse("1Gi"),
flytek8s.ResourceNvidiaGPU: resource.MustParse("1"),
},
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("100m"),
corev1.ResourceMemory: resource.MustParse("512Mi"),
flytek8s.ResourceNvidiaGPU: resource.MustParse("1"),
},
}
)

func dummyRayCustomObj() *core.RayCluster {
return &core.RayCluster{
Name: "testRayCluster",
ClusterSpec: &core.ClusterSpec{
HeadGroupSpec: &core.HeadGroupSpec{Image: "rayproject/ray:1.8.0", ServiceType: "NodePort"},
WorkerGroupSpec: []*core.WorkerGroupSpec{{GroupName: "test-group", Replicas: 3}},
},
}
}

func dummyTaskTemplate(id string, rayCustomObj *core.RayCluster) *core.TaskTemplate {

rayObjJSON, err := utils.MarshalToString(rayCustomObj)
if err != nil {
panic(err)
}

structObj := structpb.Struct{}

err = jsonpb.UnmarshalString(rayObjJSON, &structObj)
if err != nil {
panic(err)
}

return &core.TaskTemplate{
Id: &core.Identifier{Name: id},
Type: "container",
Target: &core.TaskTemplate_Container{
Container: &core.Container{
Image: testImage,
Args: testArgs,
Env: dummyEnvVars,
},
},
Custom: &structObj,
}
}

func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}
inputReader := &pluginIOMocks.InputReader{}
inputReader.OnGetInputPrefixPath().Return("/input/prefix")
inputReader.OnGetInputPath().Return("/input")
inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil)
taskCtx.OnInputReader().Return(inputReader)

outputReader := &pluginIOMocks.OutputWriter{}
outputReader.OnGetOutputPath().Return("/data/outputs.pb")
outputReader.OnGetOutputPrefixPath().Return("/data/")
outputReader.OnGetRawOutputPrefix().Return("")
outputReader.OnGetCheckpointPrefix().Return("/checkpoint")
outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev")
taskCtx.OnOutputWriter().Return(outputReader)

taskReader := &mocks.TaskReader{}
taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil)
taskCtx.OnTaskReader().Return(taskReader)

tID := &mocks.TaskExecutionID{}
tID.OnGetID().Return(core.TaskExecutionIdentifier{
NodeExecutionId: &core.NodeExecutionIdentifier{
ExecutionId: &core.WorkflowExecutionIdentifier{
Name: "my_name",
Project: "my_project",
Domain: "my_domain",
},
},
})
tID.OnGetGeneratedName().Return("some-acceptable-name")

resources := &mocks.TaskOverrides{}
resources.OnGetResources().Return(resourceRequirements)

taskExecutionMetadata := &mocks.TaskExecutionMetadata{}
taskExecutionMetadata.OnGetTaskExecutionID().Return(tID)
taskExecutionMetadata.OnGetNamespace().Return("test-namespace")
taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"})
taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"})
taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{
Kind: "node",
Name: "blah",
})
taskExecutionMetadata.OnIsInterruptible().Return(true)
taskExecutionMetadata.OnGetOverrides().Return(resources)
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}

func dummyRayJobResource(rayResourceHandler rayClusterResourceHandler,
workers int32, psReplicas int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *rayv1alpha1.RayCluster {
var jobConditions []commonOp.JobCondition

now := time.Now()

jobCreated := commonOp.JobCondition{
Type: commonOp.JobCreated,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobCreated",
Message: "TensorFlowJob the-job is created.",
LastUpdateTime: v1.Time{
Time: now,
},
LastTransitionTime: v1.Time{
Time: now,
},
}
jobRunningActive := commonOp.JobCondition{
Type: commonOp.JobRunning,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobRunning",
Message: "TensorFlowJob the-job is running.",
LastUpdateTime: v1.Time{
Time: now.Add(time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(time.Minute),
},
}
jobRunningInactive := *jobRunningActive.DeepCopy()
jobRunningInactive.Status = corev1.ConditionFalse
jobSucceeded := commonOp.JobCondition{
Type: commonOp.JobSucceeded,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobSucceeded",
Message: "TensorFlowJob the-job is successfully completed.",
LastUpdateTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
}
jobFailed := commonOp.JobCondition{
Type: commonOp.JobFailed,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobFailed",
Message: "TensorFlowJob the-job is failed.",
LastUpdateTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(2 * time.Minute),
},
}
jobRestarting := commonOp.JobCondition{
Type: commonOp.JobRestarting,
Status: corev1.ConditionTrue,
Reason: "TensorFlowJobRestarting",
Message: "TensorFlowJob the-job is restarting because some replica(s) failed.",
LastUpdateTime: v1.Time{
Time: now.Add(3 * time.Minute),
},
LastTransitionTime: v1.Time{
Time: now.Add(3 * time.Minute),
},
}

switch conditionType {
case commonOp.JobCreated:
jobConditions = []commonOp.JobCondition{
jobCreated,
}
case commonOp.JobRunning:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningActive,
}
case commonOp.JobSucceeded:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningInactive,
jobSucceeded,
}
case commonOp.JobFailed:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningInactive,
jobFailed,
}
case commonOp.JobRestarting:
jobConditions = []commonOp.JobCondition{
jobCreated,
jobRunningInactive,
jobFailed,
jobRestarting,
}
}

rayObj := dummyRayCustomObj()
taskTemplate := dummyTaskTemplate("the job", rayObj)
rayResource, err := rayResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
if err != nil {
panic(err)
}

return &kubeflowv1.TFJob{
ObjectMeta: v1.ObjectMeta{
Name: jobName,
Namespace: jobNamespace,
},
Spec: resource.(*kubeflowv1.TFJob).Spec,
Status: commonOp.JobStatus{
Conditions: jobConditions,
ReplicaStatuses: nil,
StartTime: nil,
CompletionTime: nil,
LastReconcileTime: nil,
},
}
}
Loading

0 comments on commit 9492c58

Please sign in to comment.