From 9492c58a8d73b50f55c2fa88ddac4c885a377d4c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 16 Jul 2022 01:40:20 +0800 Subject: [PATCH] Make generate Signed-off-by: Kevin Su --- go/tasks/pluginmachinery/core/exec_context.go | 4 + .../core/mocks/task_execution_context.go | 68 +++++ go/tasks/pluginmachinery/core/phase_enumer.go | 17 +- .../plugins/cluster_resource/ray/ray_test.go | 260 ++++++++++++++++++ .../kfoperators/tensorflow/tensorflow_test.go | 21 +- 5 files changed, 351 insertions(+), 19 deletions(-) create mode 100644 go/tasks/plugins/cluster_resource/ray/ray_test.go diff --git a/go/tasks/pluginmachinery/core/exec_context.go b/go/tasks/pluginmachinery/core/exec_context.go index e724c601a..28394f355 100644 --- a/go/tasks/pluginmachinery/core/exec_context.go +++ b/go/tasks/pluginmachinery/core/exec_context.go @@ -43,6 +43,8 @@ type TaskExecutionContext interface { // Returns a reader that retrieves previously stored plugin internal state. the state itself is immutable PluginStateReader() PluginStateReader + ResourcePluginStateReader() PluginStateReader + // Returns a TaskReader, to retrieve task details TaskReader() TaskReader @@ -59,6 +61,8 @@ type TaskExecutionContext interface { // These mutation will be visible in the next round PluginStateWriter() PluginStateWriter + ResourcePluginStateWriter() PluginStateWriter + // Get a handle to catalog client Catalog() catalog.AsyncClient diff --git a/go/tasks/pluginmachinery/core/mocks/task_execution_context.go b/go/tasks/pluginmachinery/core/mocks/task_execution_context.go index c7ff4961c..08dddfb8c 100644 --- a/go/tasks/pluginmachinery/core/mocks/task_execution_context.go +++ b/go/tasks/pluginmachinery/core/mocks/task_execution_context.go @@ -322,6 +322,74 @@ func (_m *TaskExecutionContext) ResourceManager() core.ResourceManager { return r0 } +type TaskExecutionContext_ResourcePluginStateReader struct { + *mock.Call +} + +func (_m TaskExecutionContext_ResourcePluginStateReader) Return(_a0 core.PluginStateReader) *TaskExecutionContext_ResourcePluginStateReader { + return &TaskExecutionContext_ResourcePluginStateReader{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionContext) OnResourcePluginStateReader() *TaskExecutionContext_ResourcePluginStateReader { + c_call := _m.On("ResourcePluginStateReader") + return &TaskExecutionContext_ResourcePluginStateReader{Call: c_call} +} + +func (_m *TaskExecutionContext) OnResourcePluginStateReaderMatch(matchers ...interface{}) *TaskExecutionContext_ResourcePluginStateReader { + c_call := _m.On("ResourcePluginStateReader", matchers...) + return &TaskExecutionContext_ResourcePluginStateReader{Call: c_call} +} + +// ResourcePluginStateReader provides a mock function with given fields: +func (_m *TaskExecutionContext) ResourcePluginStateReader() core.PluginStateReader { + ret := _m.Called() + + var r0 core.PluginStateReader + if rf, ok := ret.Get(0).(func() core.PluginStateReader); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(core.PluginStateReader) + } + } + + return r0 +} + +type TaskExecutionContext_ResourcePluginStateWriter struct { + *mock.Call +} + +func (_m TaskExecutionContext_ResourcePluginStateWriter) Return(_a0 core.PluginStateWriter) *TaskExecutionContext_ResourcePluginStateWriter { + return &TaskExecutionContext_ResourcePluginStateWriter{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionContext) OnResourcePluginStateWriter() *TaskExecutionContext_ResourcePluginStateWriter { + c_call := _m.On("ResourcePluginStateWriter") + return &TaskExecutionContext_ResourcePluginStateWriter{Call: c_call} +} + +func (_m *TaskExecutionContext) OnResourcePluginStateWriterMatch(matchers ...interface{}) *TaskExecutionContext_ResourcePluginStateWriter { + c_call := _m.On("ResourcePluginStateWriter", matchers...) + return &TaskExecutionContext_ResourcePluginStateWriter{Call: c_call} +} + +// ResourcePluginStateWriter provides a mock function with given fields: +func (_m *TaskExecutionContext) ResourcePluginStateWriter() core.PluginStateWriter { + ret := _m.Called() + + var r0 core.PluginStateWriter + if rf, ok := ret.Get(0).(func() core.PluginStateWriter); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(core.PluginStateWriter) + } + } + + return r0 +} + type TaskExecutionContext_SecretManager struct { *mock.Call } diff --git a/go/tasks/pluginmachinery/core/phase_enumer.go b/go/tasks/pluginmachinery/core/phase_enumer.go index 8101f2f3d..8eedb325e 100644 --- a/go/tasks/pluginmachinery/core/phase_enumer.go +++ b/go/tasks/pluginmachinery/core/phase_enumer.go @@ -7,9 +7,9 @@ import ( "fmt" ) -const _PhaseName = "PhaseUndefinedPhaseNotReadyPhaseWaitingForResourcesPhaseQueuedPhaseInitializingPhaseRunningPhaseSuccessPhaseRetryableFailurePhasePermanentFailurePhaseWaitingForCache" +const _PhaseName = "PhaseUndefinedPhaseNotReadyPhaseWaitingForResourcesPhaseQueuedPhaseInitializingPhaseClusterRunningPhaseRunningPhaseSuccessPhaseRetryableFailurePhasePermanentFailurePhaseWaitingForCache" -var _PhaseIndex = [...]uint8{0, 14, 27, 51, 62, 79, 91, 103, 124, 145, 165} +var _PhaseIndex = [...]uint8{0, 14, 27, 51, 62, 79, 98, 110, 122, 143, 164, 184} func (i Phase) String() string { if i < 0 || i >= Phase(len(_PhaseIndex)-1) { @@ -18,7 +18,7 @@ func (i Phase) String() string { return _PhaseName[_PhaseIndex[i]:_PhaseIndex[i+1]] } -var _PhaseValues = []Phase{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +var _PhaseValues = []Phase{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} var _PhaseNameToValueMap = map[string]Phase{ _PhaseName[0:14]: 0, @@ -26,11 +26,12 @@ var _PhaseNameToValueMap = map[string]Phase{ _PhaseName[27:51]: 2, _PhaseName[51:62]: 3, _PhaseName[62:79]: 4, - _PhaseName[79:91]: 5, - _PhaseName[91:103]: 6, - _PhaseName[103:124]: 7, - _PhaseName[124:145]: 8, - _PhaseName[145:165]: 9, + _PhaseName[79:98]: 5, + _PhaseName[98:110]: 6, + _PhaseName[110:122]: 7, + _PhaseName[122:143]: 8, + _PhaseName[143:164]: 9, + _PhaseName[164:184]: 10, } // PhaseString retrieves an enum value from the enum constants string name. diff --git a/go/tasks/plugins/cluster_resource/ray/ray_test.go b/go/tasks/plugins/cluster_resource/ray/ray_test.go new file mode 100644 index 000000000..75be68a24 --- /dev/null +++ b/go/tasks/plugins/cluster_resource/ray/ray_test.go @@ -0,0 +1,260 @@ +package ray + +import ( + "context" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/golang/protobuf/jsonpb" + structpb "github.com/golang/protobuf/ptypes/struct" + commonOp "github.com/kubeflow/common/pkg/apis/common/v1" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + rayv1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1" + "github.com/stretchr/testify/mock" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "time" +) + +const testImage = "image://" +const serviceAccount = "ray_sa" + +var ( + dummyEnvVars = []*core.KeyValuePair{ + {Key: "Env_Var", Value: "Env_Val"}, + } + + testArgs = []string{ + "test-args", + } + + resourceRequirements = &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } +) + +func dummyRayCustomObj() *core.RayCluster { + return &core.RayCluster{ + Name: "testRayCluster", + ClusterSpec: &core.ClusterSpec{ + HeadGroupSpec: &core.HeadGroupSpec{Image: "rayproject/ray:1.8.0", ServiceType: "NodePort"}, + WorkerGroupSpec: []*core.WorkerGroupSpec{{GroupName: "test-group", Replicas: 3}}, + }, + } +} + +func dummyTaskTemplate(id string, rayCustomObj *core.RayCluster) *core.TaskTemplate { + + rayObjJSON, err := utils.MarshalToString(rayCustomObj) + if err != nil { + panic(err) + } + + structObj := structpb.Struct{} + + err = jsonpb.UnmarshalString(rayObjJSON, &structObj) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + Env: dummyEnvVars, + }, + }, + Custom: &structObj, + } +} + +func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecutionContext { + taskCtx := &mocks.TaskExecutionContext{} + inputReader := &pluginIOMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return("/input/prefix") + inputReader.OnGetInputPath().Return("/input") + inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) + taskCtx.OnInputReader().Return(inputReader) + + outputReader := &pluginIOMocks.OutputWriter{} + outputReader.OnGetOutputPath().Return("/data/outputs.pb") + outputReader.OnGetOutputPrefixPath().Return("/data/") + outputReader.OnGetRawOutputPrefix().Return("") + outputReader.OnGetCheckpointPrefix().Return("/checkpoint") + outputReader.OnGetPreviousCheckpointsPrefix().Return("/prev") + taskCtx.OnOutputWriter().Return(outputReader) + + taskReader := &mocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + taskCtx.OnTaskReader().Return(taskReader) + + tID := &mocks.TaskExecutionID{} + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + tID.OnGetGeneratedName().Return("some-acceptable-name") + + resources := &mocks.TaskOverrides{} + resources.OnGetResources().Return(resourceRequirements) + + taskExecutionMetadata := &mocks.TaskExecutionMetadata{} + taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) + taskExecutionMetadata.OnGetNamespace().Return("test-namespace") + taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) + taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskExecutionMetadata.OnIsInterruptible().Return(true) + taskExecutionMetadata.OnGetOverrides().Return(resources) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) + taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + return taskCtx +} + +func dummyRayJobResource(rayResourceHandler rayClusterResourceHandler, + workers int32, psReplicas int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *rayv1alpha1.RayCluster { + var jobConditions []commonOp.JobCondition + + now := time.Now() + + jobCreated := commonOp.JobCondition{ + Type: commonOp.JobCreated, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobCreated", + Message: "TensorFlowJob the-job is created.", + LastUpdateTime: v1.Time{ + Time: now, + }, + LastTransitionTime: v1.Time{ + Time: now, + }, + } + jobRunningActive := commonOp.JobCondition{ + Type: commonOp.JobRunning, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobRunning", + Message: "TensorFlowJob the-job is running.", + LastUpdateTime: v1.Time{ + Time: now.Add(time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(time.Minute), + }, + } + jobRunningInactive := *jobRunningActive.DeepCopy() + jobRunningInactive.Status = corev1.ConditionFalse + jobSucceeded := commonOp.JobCondition{ + Type: commonOp.JobSucceeded, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobSucceeded", + Message: "TensorFlowJob the-job is successfully completed.", + LastUpdateTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + } + jobFailed := commonOp.JobCondition{ + Type: commonOp.JobFailed, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobFailed", + Message: "TensorFlowJob the-job is failed.", + LastUpdateTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + } + jobRestarting := commonOp.JobCondition{ + Type: commonOp.JobRestarting, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobRestarting", + Message: "TensorFlowJob the-job is restarting because some replica(s) failed.", + LastUpdateTime: v1.Time{ + Time: now.Add(3 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(3 * time.Minute), + }, + } + + switch conditionType { + case commonOp.JobCreated: + jobConditions = []commonOp.JobCondition{ + jobCreated, + } + case commonOp.JobRunning: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningActive, + } + case commonOp.JobSucceeded: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobSucceeded, + } + case commonOp.JobFailed: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobFailed, + } + case commonOp.JobRestarting: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobFailed, + jobRestarting, + } + } + + rayObj := dummyRayCustomObj() + taskTemplate := dummyTaskTemplate("the job", rayObj) + rayResource, err := rayResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate)) + if err != nil { + panic(err) + } + + return &kubeflowv1.TFJob{ + ObjectMeta: v1.ObjectMeta{ + Name: jobName, + Namespace: jobNamespace, + }, + Spec: resource.(*kubeflowv1.TFJob).Spec, + Status: commonOp.JobStatus{ + Conditions: jobConditions, + ReplicaStatuses: nil, + StartTime: nil, + CompletionTime: nil, + LastReconcileTime: nil, + }, + } +} diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 0ac145128..2785459f4 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -28,10 +28,9 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" + kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/stretchr/testify/assert" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - tfOp "github.com/kubeflow/training-operator/pkg/controller.v1/tensorflow" ) const testImage = "image://" @@ -152,7 +151,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.Tas } func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorResourceHandler, - workers int32, psReplicas int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *tfOp.TFJob { + workers int32, psReplicas int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { var jobConditions []commonOp.JobCondition now := time.Now() @@ -258,12 +257,12 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso panic(err) } - return &tfOp.TFJob{ + return &kubeflowv1.TFJob{ ObjectMeta: v1.ObjectMeta{ Name: jobName, Namespace: jobNamespace, }, - Spec: resource.(*tfOp.TFJob).Spec, + Spec: resource.(*kubeflowv1.TFJob).Spec, Status: commonOp.JobStatus{ Conditions: jobConditions, ReplicaStatuses: nil, @@ -284,17 +283,17 @@ func TestBuildResourceTensorFlow(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, resource) - tensorflowJob, ok := resource.(*tfOp.TFJob) + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) assert.True(t, ok) - assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[tfOp.TFReplicaTypeWorker].Replicas) - assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[tfOp.TFReplicaTypePS].Replicas) - assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[tfOp.TFReplicaTypeChief].Replicas) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { var hasContainerWithDefaultTensorFlowName = false for _, container := range replicaSpec.Template.Spec.Containers { - if container.Name == tfOp.DefaultContainerName { + if container.Name == kubeflowv1.TFJobDefaultContainerName { hasContainerWithDefaultTensorFlowName = true } @@ -310,7 +309,7 @@ func TestGetTaskPhase(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} ctx := context.TODO() - dummyTensorFlowJobResourceCreator := func(conditionType commonOp.JobConditionType) *tfOp.TFJob { + dummyTensorFlowJobResourceCreator := func(conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, conditionType) }