diff --git a/ray-operator/controllers/ray/expectations/scale_expectation.go b/ray-operator/controllers/ray/expectations/scale_expectation.go
new file mode 100644
index 0000000000..bfb5f13f0e
--- /dev/null
+++ b/ray-operator/controllers/ray/expectations/scale_expectation.go
@@ -0,0 +1,152 @@
+package expectations
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/errors"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/client-go/tools/cache"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+const DefaultHead = ""
+
+// ScaleAction is the kind of scale operation, such as create or delete.
+type ScaleAction string
+
+const (
+	// Create action
+	Create ScaleAction = "create"
+	// Delete action
+	Delete ScaleAction = "delete"
+)
+
+const (
+	// GroupIndex indexes pods within the specified group in RayCluster
+	GroupIndex = "group"
+	// RayClusterIndex indexes pods within the RayCluster
+	RayClusterIndex = "raycluster"
+)
+
+// RayClusterScaleExpectation is an interface for setting and checking scale expectations of RayCluster groups.
+type RayClusterScaleExpectation interface {
+	ExpectScalePod(rayClusterName, group, namespace, name string, action ScaleAction)
+	IsSatisfied(ctx context.Context, rayClusterName, group, namespace string) bool
+	Delete(rayClusterName, namespace string)
+}
+
+func newRayPodIndexer() cache.Indexer {
+	return cache.NewIndexer(rayPodKey, cache.Indexers{GroupIndex: groupIndexFunc, RayClusterIndex: rayClusterIndexFunc})
+}
+
+func NewRayClusterScaleExpectation(client client.Client) RayClusterScaleExpectation {
+	return &realRayClusterScaleExpectation{
+		Client:     client,
+		itemsCache: newRayPodIndexer(),
+	}
+}
+
+type realRayClusterScaleExpectation struct {
+	client.Client
+	itemsCache cache.Indexer
+}
+
+func (r *realRayClusterScaleExpectation) ExpectScalePod(rayClusterName, group, namespace, name string, action ScaleAction) {
+	_ = r.itemsCache.Add(&rayPod{
+		name:            name,
+		namespace:       namespace,
+		group:           group,
+		rayCluster:      rayClusterName,
+		action:          action,
+		recordTimestamp: time.Now(),
+	})
+}
+
+func (r *realRayClusterScaleExpectation) IsSatisfied(ctx context.Context, rayClusterName, group, namespace string) (isSatisfied bool) {
+	items, _ := r.itemsCache.ByIndex(GroupIndex, fmt.Sprintf("%s/%s/%s", namespace, rayClusterName, group))
+	isSatisfied = true
+	for i := range items {
+		rp := items[i].(*rayPod)
+		pod := &corev1.Pod{}
+		isPodSatisfied := false
+		switch rp.action {
+		case Create:
+			if err := r.Get(ctx, types.NamespacedName{Name: rp.name, Namespace: namespace}, pod); err == nil {
+				isPodSatisfied = true
+			} else {
+				isPodSatisfied = errors.IsNotFound(err) && rp.recordTimestamp.Add(30*time.Second).Before(time.Now())
+			}
+		case Delete:
+			if err := r.Get(ctx, types.NamespacedName{Name: rp.name, Namespace: namespace}, pod); err != nil {
+				isPodSatisfied = errors.IsNotFound(err)
+			} else {
+				isPodSatisfied = pod.DeletionTimestamp != nil
+			}
+		}
+		// delete satisfied item in cache
+		if isPodSatisfied {
+			_ = r.itemsCache.Delete(items[i])
+		} else {
+			isSatisfied = false
+		}
+	}
+	return isSatisfied
+}
+
+func (r *realRayClusterScaleExpectation) Delete(rayClusterName, namespace string) {
+	items, _ := r.itemsCache.ByIndex(RayClusterIndex, fmt.Sprintf("%s/%s", namespace, rayClusterName))
+	for _, item := range items {
+		_ = r.itemsCache.Delete(item)
+	}
+}
+
+type rayPod struct {
+	recordTimestamp time.Time
+	action          ScaleAction
+	name            string
+	namespace       string
+	rayCluster      string
+	group           string
+}
+
+func (p *rayPod) Key() string {
+	return fmt.Sprintf("%s/%s", p.namespace, p.name)
+}
+
+func (p *rayPod) GroupKey() string {
+	return fmt.Sprintf("%s/%s/%s", p.namespace, p.rayCluster, p.group)
+}
+
+func (p *rayPod) ClusterKey() string {
+	return fmt.Sprintf("%s/%s", p.namespace, p.rayCluster)
+}
+
+func rayPodKey(obj interface{}) (string, error) {
+	return obj.(*rayPod).Key(), nil
+}
+
+func groupIndexFunc(obj interface{}) ([]string, error) {
+	return []string{obj.(*rayPod).GroupKey()}, nil
+}
+
+func rayClusterIndexFunc(obj interface{}) ([]string, error) {
+	return []string{obj.(*rayPod).ClusterKey()}, nil
+}
+
+func NewFakeRayClusterScaleExpectation() RayClusterScaleExpectation {
+	return &fakeRayClusterScaleExpectation{}
+}
+
+type fakeRayClusterScaleExpectation struct{}
+
+func (r *fakeRayClusterScaleExpectation) ExpectScalePod(_, _, _, _ string, _ ScaleAction) {
+}
+
+func (r *fakeRayClusterScaleExpectation) IsSatisfied(_ context.Context, _, _, _ string) bool {
+	return true
+}
+
+func (r *fakeRayClusterScaleExpectation) Delete(_, _ string) {}
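Taken together, the package is meant to be driven from a reconciler roughly as follows. This is a minimal sketch of the call pattern that the controller changes further down follow; cluster, r.Client, and buildHeadPod are placeholder names, not part of this patch.

exp := expectations.NewRayClusterScaleExpectation(r.Client)

// Skip the group while an earlier create/delete has not yet been observed
// through the (possibly stale) informer-backed client.
if exp.IsSatisfied(ctx, cluster.Name, expectations.DefaultHead, cluster.Namespace) {
	pod := buildHeadPod(cluster) // hypothetical helper
	if err := r.Create(ctx, &pod); err == nil {
		// Record the expectation only after the API call succeeds.
		exp.ExpectScalePod(cluster.Name, expectations.DefaultHead, pod.Namespace, pod.Name, expectations.Create)
	}
}

// When the RayCluster itself is gone, drop every expectation recorded for it.
exp.Delete(cluster.Name, cluster.Namespace)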
diff --git a/ray-operator/controllers/ray/expectations/scale_expectation_test.go b/ray-operator/controllers/ray/expectations/scale_expectation_test.go
new file mode 100644
index 0000000000..dcdd38d8ba
--- /dev/null
+++ b/ray-operator/controllers/ray/expectations/scale_expectation_test.go
@@ -0,0 +1,87 @@
+package expectations
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/runtime"
+	clientFake "sigs.k8s.io/controller-runtime/pkg/client/fake"
+)
+
+func TestRayClusterExpectations(t *testing.T) {
+	setupTest()
+	ctx := context.TODO()
+	fakeClient := clientFake.NewClientBuilder().WithRuntimeObjects().Build()
+	exp := NewRayClusterScaleExpectation(fakeClient)
+	namespace := "default"
+	rayClusterName := "raycluster-test-pod"
+
+	// Test expect create head
+	exp.ExpectScalePod(rayClusterName, DefaultHead, namespace, testPods[0].(*corev1.Pod).Name, Create)
+	assert.Equal(t, len(exp.(*realRayClusterScaleExpectation).itemsCache.List()), 1)
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, DefaultHead, namespace), false)
+	err := fakeClient.Create(context.TODO(), testPods[0].(*corev1.Pod))
+	assert.Nil(t, err, "Fail to create head pod")
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, DefaultHead, namespace), true)
+	// delete satisfied item in cache
+	assert.Equal(t, len(exp.(*realRayClusterScaleExpectation).itemsCache.List()), 0)
+
+	// Test expect delete head
+	exp.ExpectScalePod(rayClusterName, DefaultHead, namespace, testPods[0].(*corev1.Pod).Name, Delete)
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, DefaultHead, namespace), false)
+	// delete pod
+	err = fakeClient.Delete(context.TODO(), testPods[0].(*corev1.Pod))
+	assert.Nil(t, err, "Fail to delete head pod")
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, DefaultHead, namespace), true)
+
+	// Test expect create worker
+	group := "test-group"
+	exp.ExpectScalePod(rayClusterName, group, namespace, testPods[1].(*corev1.Pod).Name, Create)
+	exp.ExpectScalePod(rayClusterName, group, namespace, testPods[2].(*corev1.Pod).Name, Create)
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, group, namespace), false)
+	assert.Nil(t, fakeClient.Create(context.TODO(), testPods[1].(*corev1.Pod)), "Fail to create worker pod1")
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, group, namespace), false)
+	assert.Nil(t, fakeClient.Create(context.TODO(), testPods[2].(*corev1.Pod)), "Fail to create worker pod2")
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, group, namespace), true)
+
+	// Test delete all
+	// reset pods
+	setupTest()
+	exp.ExpectScalePod(rayClusterName, DefaultHead, namespace, testPods[0].(*corev1.Pod).Name, Create)
+	exp.ExpectScalePod(rayClusterName, group, namespace, testPods[1].(*corev1.Pod).Name, Delete)
+	exp.ExpectScalePod(rayClusterName, group, namespace, testPods[2].(*corev1.Pod).Name, Delete)
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, DefaultHead, namespace), false)
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, group, namespace), false)
+	exp.Delete(rayClusterName, namespace)
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, DefaultHead, namespace), true)
+	assert.Equal(t, exp.IsSatisfied(ctx, rayClusterName, group, namespace), true)
+	assert.Equal(t, len(exp.(*realRayClusterScaleExpectation).itemsCache.List()), 0)
+}
+
+var testPods []runtime.Object
+
+func setupTest() {
+	testPods = []runtime.Object{
+		&corev1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "pod1",
+				Namespace: "default",
+			},
+		},
+		&corev1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "pod2",
+				Namespace: "default",
+			},
+		},
+		&corev1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "pod3",
+				Namespace: "default",
+			},
+		},
+	}
+}
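The test above drives the create, delete, and cleanup paths against a fake client. It does not cover the 30-second grace period after which a Create expectation whose Pod never appears is treated as satisfied; a sketch of how that path could be exercised from the same package is below (backdating recordTimestamp and the extra time import are assumptions made for illustration, not part of this patch).

func TestCreateExpectationGracePeriod(t *testing.T) {
	ctx := context.TODO()
	fakeClient := clientFake.NewClientBuilder().Build()
	exp := NewRayClusterScaleExpectation(fakeClient).(*realRayClusterScaleExpectation)

	// Expect a head Pod that is never actually created.
	exp.ExpectScalePod("raycluster-test", DefaultHead, "default", "head-pod", Create)
	assert.False(t, exp.IsSatisfied(ctx, "raycluster-test", DefaultHead, "default"))

	// Backdate the record so it falls outside the 30-second window.
	exp.itemsCache.List()[0].(*rayPod).recordTimestamp = time.Now().Add(-time.Minute)

	// A stale, never-observed Create expectation is dropped and reported as satisfied.
	assert.True(t, exp.IsSatisfied(ctx, "raycluster-test", DefaultHead, "default"))
	assert.Equal(t, 0, len(exp.itemsCache.List()))
}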
"github.com/ray-project/kuberay/ray-operator/controllers/ray/common" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/expectations" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" + "github.com/ray-project/kuberay/ray-operator/pkg/features" ) type reconcileFunc func(context.Context, *rayv1.RayCluster) error @@ -55,6 +51,8 @@ var ( // Definition of a index field for pod name podUIDIndexField = "metadata.uid" + + rayClusterScaleExpectation expectations.RayClusterScaleExpectation ) // getDiscoveryClient returns a discovery client for the current reconciler @@ -105,7 +103,8 @@ func NewReconciler(ctx context.Context, mgr manager.Manager, options RayClusterR panic(err) } isOpenShift := getClusterType(ctx) - + // init ray cluster expectations + rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(mgr.GetClient()) // init the batch scheduler manager schedulerMgr, err := batchscheduler.NewSchedulerManager(rayConfigs, mgr.GetConfig()) if err != nil { @@ -116,7 +115,6 @@ func NewReconciler(ctx context.Context, mgr manager.Manager, options RayClusterR // add schema to runtime schedulerMgr.AddToScheme(mgr.GetScheme()) - return &RayClusterReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), @@ -184,6 +182,8 @@ func (r *RayClusterReconciler) Reconcile(ctx context.Context, request ctrl.Reque // No match found if errors.IsNotFound(err) { + // Clear all related expectations + rayClusterScaleExpectation.Delete(instance.Name, instance.Namespace) logger.Info("Read request instance not found error!") } else { logger.Error(err, "Read request instance error!") @@ -678,9 +678,10 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv return err } } - // Reconcile head Pod - if len(headPods.Items) == 1 { + if !rayClusterScaleExpectation.IsSatisfied(ctx, instance.Name, instance.Namespace, expectations.DefaultHead) { + logger.Info("reconcilePods", "Expectation", "NotSatisfiedHeadExpectations, reconcile head later") + } else if len(headPods.Items) == 1 { headPod := headPods.Items[0] logger.Info("reconcilePods", "Found 1 head Pod", headPod.Name, "Pod status", headPod.Status.Phase, "Pod status reason", headPod.Status.Reason, @@ -696,6 +697,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv headPod.Namespace, headPod.Name, headPod.Status.Phase, headPod.Spec.RestartPolicy, getRayContainerStateTerminated(headPod), err) return errstd.Join(utils.ErrFailedDeleteHeadPod, err) } + rayClusterScaleExpectation.ExpectScalePod(instance.Name, expectations.DefaultHead, headPod.Namespace, headPod.Name, expectations.Delete) r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(utils.DeletedHeadPod), "Deleted head Pod %s/%s; Pod status: %s; Pod restart policy: %s; Ray container terminated status: %v", headPod.Namespace, headPod.Name, headPod.Status.Phase, headPod.Spec.RestartPolicy, getRayContainerStateTerminated(headPod)) @@ -727,11 +729,16 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv if err := r.Delete(ctx, &extraHeadPodToDelete); err != nil { return errstd.Join(utils.ErrFailedDeleteHeadPod, err) } + rayClusterScaleExpectation.ExpectScalePod(instance.Name, expectations.DefaultHead, extraHeadPodToDelete.Namespace, extraHeadPodToDelete.Name, expectations.Delete) } } // Reconcile worker pods now for _, worker := range instance.Spec.WorkerGroupSpecs { + if !rayClusterScaleExpectation.IsSatisfied(ctx, instance.Name, worker.GroupName, instance.Namespace) { + 
logger.Info("reconcilePods", "Expectation", fmt.Sprintf("NotSatisfiedGroupExpectations, reconcile group %s later", worker.GroupName)) + continue + } // workerReplicas will store the target number of pods for this worker group. var workerReplicas int32 = utils.GetWorkerGroupDesiredReplicas(ctx, worker) logger.Info("reconcilePods", "desired workerReplicas (always adhering to minReplicas/maxReplica)", workerReplicas, "worker group", worker.GroupName, "maxReplicas", worker.MaxReplicas, "minReplicas", worker.MinReplicas, "replicas", worker.Replicas) @@ -757,6 +764,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv workerPod.Namespace, workerPod.Name, workerPod.Status.Phase, workerPod.Spec.RestartPolicy, getRayContainerStateTerminated(workerPod), err) return errstd.Join(utils.ErrFailedDeleteWorkerPod, err) } + rayClusterScaleExpectation.ExpectScalePod(instance.Name, worker.GroupName, workerPod.Namespace, workerPod.Name, expectations.Delete) r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(utils.DeletedWorkerPod), "Deleted worker Pod %s/%s; Pod status: %s; Pod restart policy: %s; Ray container terminated status: %v", workerPod.Namespace, workerPod.Name, workerPod.Status.Phase, workerPod.Spec.RestartPolicy, getRayContainerStateTerminated(workerPod)) @@ -784,6 +792,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv } logger.Info("reconcilePods", "The worker Pod has already been deleted", pod.Name) } else { + rayClusterScaleExpectation.ExpectScalePod(instance.Name, worker.GroupName, pod.Namespace, pod.Name, expectations.Delete) deletedWorkers[pod.Name] = deleted r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(utils.DeletedWorkerPod), "Deleted pod %s/%s", pod.Namespace, pod.Name) } @@ -852,6 +861,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv } logger.Info("reconcilePods", "The worker Pod has already been deleted", randomPodToDelete.Name) } + rayClusterScaleExpectation.ExpectScalePod(instance.Name, worker.GroupName, randomPodToDelete.Namespace, randomPodToDelete.Name, expectations.Delete) r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(utils.DeletedWorkerPod), "Deleted Pod %s/%s", randomPodToDelete.Namespace, randomPodToDelete.Name) } } else { @@ -1017,6 +1027,7 @@ func (r *RayClusterReconciler) createHeadPod(ctx context.Context, instance rayv1 r.Recorder.Eventf(&instance, corev1.EventTypeWarning, string(utils.FailedToCreateHeadPod), "Failed to create head Pod %s/%s, %v", pod.Namespace, pod.Name, err) return err } + rayClusterScaleExpectation.ExpectScalePod(instance.Name, expectations.DefaultHead, pod.Namespace, pod.Name, expectations.Create) logger.Info("Created head Pod for RayCluster", "name", pod.Name) r.Recorder.Eventf(&instance, corev1.EventTypeNormal, string(utils.CreatedHeadPod), "Created head Pod %s/%s", pod.Namespace, pod.Name) return nil @@ -1035,10 +1046,12 @@ func (r *RayClusterReconciler) createWorkerPod(ctx context.Context, instance ray } } - if err := r.Create(ctx, &pod); err != nil { + replica := pod + if err := r.Create(ctx, &replica); err != nil { r.Recorder.Eventf(&instance, corev1.EventTypeWarning, string(utils.FailedToCreateWorkerPod), "Failed to create worker Pod %s/%s, %v", pod.Namespace, pod.Name, err) return err } + rayClusterScaleExpectation.ExpectScalePod(instance.Name, worker.GroupName, replica.Namespace, replica.Name, expectations.Create) logger.Info("Created worker Pod for RayCluster", "name", pod.Name) r.Recorder.Eventf(&instance, 
corev1.EventTypeNormal, string(utils.CreatedWorkerPod), "Created worker Pod %s/%s", pod.Namespace, pod.Name) return nil diff --git a/ray-operator/controllers/ray/raycluster_controller_test.go b/ray-operator/controllers/ray/raycluster_controller_test.go index 02b9bbab47..a114839702 100644 --- a/ray-operator/controllers/ray/raycluster_controller_test.go +++ b/ray-operator/controllers/ray/raycluster_controller_test.go @@ -21,27 +21,23 @@ import ( "fmt" "time" - "k8s.io/apimachinery/pkg/api/meta" - - "github.com/ray-project/kuberay/ray-operator/pkg/features" - - "github.com/ray-project/kuberay/ray-operator/controllers/ray/common" - "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" - "k8s.io/utils/ptr" - + "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/util/retry" + "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/common" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" + "github.com/ray-project/kuberay/ray-operator/pkg/features" // +kubebuilder:scaffold:imports ) diff --git a/ray-operator/controllers/ray/raycluster_controller_unit_test.go b/ray-operator/controllers/ray/raycluster_controller_unit_test.go index 4897f48290..bb669aacb3 100644 --- a/ray-operator/controllers/ray/raycluster_controller_unit_test.go +++ b/ray-operator/controllers/ray/raycluster_controller_unit_test.go @@ -25,6 +25,7 @@ import ( rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" "github.com/ray-project/kuberay/ray-operator/controllers/ray/common" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/expectations" "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" "github.com/ray-project/kuberay/ray-operator/pkg/client/clientset/versioned/scheme" "github.com/ray-project/kuberay/ray-operator/pkg/features" @@ -493,7 +494,7 @@ func TestReconcile_RemoveWorkersToDelete_RandomDelete(t *testing.T) { // Simulate the Ray Autoscaler attempting to scale down. assert.Equal(t, expectedNumWorkersToDelete, len(testRayCluster.Spec.WorkerGroupSpecs[0].ScaleStrategy.WorkersToDelete)) - + rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient) testRayClusterReconciler := &RayClusterReconciler{ Client: fakeClient, Recorder: &record.FakeRecorder{}, @@ -586,7 +587,7 @@ func TestReconcile_RemoveWorkersToDelete_NoRandomDelete(t *testing.T) { // Simulate the Ray Autoscaler attempting to scale down. 
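Because the expectation store is kept in the package-level rayClusterScaleExpectation variable rather than on RayClusterReconciler, every unit test below that constructs a reconciler first re-initializes the store against its own fake client, and Test_TerminatedWorkers_NoAutoscaler swaps in the no-op fake instead. The two forms, repeated throughout the next file:

// Rebuild the shared store around this test's fake client ...
rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
// ... or install the no-op implementation, which reports every expectation as satisfied.
rayClusterScaleExpectation = expectations.NewFakeRayClusterScaleExpectation()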
diff --git a/ray-operator/controllers/ray/raycluster_controller_unit_test.go b/ray-operator/controllers/ray/raycluster_controller_unit_test.go
index 4897f48290..bb669aacb3 100644
--- a/ray-operator/controllers/ray/raycluster_controller_unit_test.go
+++ b/ray-operator/controllers/ray/raycluster_controller_unit_test.go
@@ -25,6 +25,7 @@ import (
 
 	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
 	"github.com/ray-project/kuberay/ray-operator/controllers/ray/common"
+	"github.com/ray-project/kuberay/ray-operator/controllers/ray/expectations"
 	"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
 	"github.com/ray-project/kuberay/ray-operator/pkg/client/clientset/versioned/scheme"
 	"github.com/ray-project/kuberay/ray-operator/pkg/features"
@@ -493,7 +494,7 @@ func TestReconcile_RemoveWorkersToDelete_RandomDelete(t *testing.T) {
 
 	// Simulate the Ray Autoscaler attempting to scale down.
 	assert.Equal(t, expectedNumWorkersToDelete, len(testRayCluster.Spec.WorkerGroupSpecs[0].ScaleStrategy.WorkersToDelete))
-
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
 		Recorder: &record.FakeRecorder{},
@@ -586,7 +587,7 @@ func TestReconcile_RemoveWorkersToDelete_NoRandomDelete(t *testing.T) {
 
 			// Simulate the Ray Autoscaler attempting to scale down.
 			assert.Equal(t, expectedNumWorkersToDelete, len(testRayCluster.Spec.WorkerGroupSpecs[0].ScaleStrategy.WorkersToDelete)-tc.numNonExistPods)
-
+			rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 			testRayClusterReconciler := &RayClusterReconciler{
 				Client:   fakeClient,
 				Recorder: &record.FakeRecorder{},
@@ -634,7 +635,7 @@ func TestReconcile_RandomDelete_OK(t *testing.T) {
 
 	assert.Nil(t, err, "Fail to get pod list")
 	assert.Equal(t, len(testPods), len(podList.Items), "Init pod list len is wrong")
-
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
 		Recorder: &record.FakeRecorder{},
@@ -697,6 +698,7 @@ func TestReconcile_PodDeleted_Diff0_OK(t *testing.T) {
 	err = fakeClient.Delete(ctx, &podList.Items[4])
 	assert.Nil(t, err, "Fail to delete pod")
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize a new RayClusterReconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -754,6 +756,7 @@ func TestReconcile_PodDeleted_DiffLess0_OK(t *testing.T) {
 	err = fakeClient.Delete(ctx, &podList.Items[3])
 	assert.Nil(t, err, "Fail to delete pod")
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize a new RayClusterReconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -809,6 +812,7 @@ func TestReconcile_Diff0_WorkersToDelete_OK(t *testing.T) {
 	assert.Nil(t, err, "Fail to get pod list")
 	assert.Equal(t, oldNumWorkerPods+numHeadPods, len(podList.Items), "Init pod list len is wrong")
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize a new RayClusterReconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -881,6 +885,7 @@ func TestReconcile_PodCrash_DiffLess0_OK(t *testing.T) {
 	assert.Nil(t, err, "Fail to get pod list")
 	assert.Equal(t, oldNumWorkerPods+numHeadPods, len(podList.Items), "Init pod list len is wrong")
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize a new RayClusterReconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -943,7 +948,7 @@ func TestReconcile_PodEvicted_DiffLess0_OK(t *testing.T) {
 		WithRuntimeObjects(testPods...).
 		Build()
 	ctx := context.Background()
-
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	podList := corev1.PodList{}
 	err := fakeClient.List(ctx, &podList, client.InNamespace(namespaceStr))
 
@@ -1015,6 +1020,7 @@ func TestReconcileHeadService(t *testing.T) {
 		utils.RayNodeTypeLabelKey: string(rayv1.HeadNode),
 	})
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize RayCluster reconciler.
 	r := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -1084,6 +1090,7 @@ func TestReconcileHeadlessService(t *testing.T) {
 	fakeClient := clientFake.NewClientBuilder().WithScheme(newScheme).WithRuntimeObjects(runtimeObjects...).Build()
 	ctx := context.TODO()
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize RayCluster reconciler.
 	r := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -1160,6 +1167,7 @@ func TestReconcile_AutoscalerServiceAccount(t *testing.T) {
 
 	assert.True(t, k8serrors.IsNotFound(err), "Head group service account should not exist yet")
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
 		Recorder: &record.FakeRecorder{},
@@ -1192,6 +1200,7 @@ func TestReconcile_Autoscaler_ServiceAccountName(t *testing.T) {
 	fakeClient := clientFake.NewClientBuilder().WithRuntimeObjects(runtimeObjects...).Build()
 	ctx := context.Background()
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize the reconciler
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -1238,6 +1247,7 @@ func TestReconcile_AutoscalerRoleBinding(t *testing.T) {
 
 	assert.True(t, k8serrors.IsNotFound(err), "autoscaler RoleBinding should not exist yet")
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
 		Recorder: &record.FakeRecorder{},
@@ -1274,6 +1284,7 @@ func TestReconcile_UpdateClusterReason(t *testing.T) {
 	assert.Nil(t, err, "Fail to get RayCluster")
 	assert.Empty(t, cluster.Status.Reason, "Cluster reason should be empty")
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
 		Recorder: &record.FakeRecorder{},
@@ -1435,7 +1446,7 @@ func TestGetHeadServiceIPAndName(t *testing.T) {
 	for name, tc := range tests {
 		t.Run(name, func(t *testing.T) {
 			fakeClient := clientFake.NewClientBuilder().WithRuntimeObjects(tc.services...).Build()
-
+			rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 			testRayClusterReconciler := &RayClusterReconciler{
 				Client:   fakeClient,
 				Recorder: &record.FakeRecorder{},
@@ -1518,6 +1529,7 @@ func TestUpdateStatusObservedGeneration(t *testing.T) {
 	assert.Equal(t, int64(-1), cluster.Status.ObservedGeneration)
 	assert.Equal(t, int64(0), cluster.ObjectMeta.Generation)
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize RayCluster reconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -1555,6 +1567,7 @@ func TestReconcile_UpdateClusterState(t *testing.T) {
 	assert.Nil(t, err, "Fail to get RayCluster")
 	assert.Empty(t, cluster.Status.State, "Cluster state should be empty") //nolint:staticcheck // https://github.com/ray-project/kuberay/pull/2288
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
 		Recorder: &record.FakeRecorder{},
@@ -1955,6 +1968,7 @@ func Test_TerminatedWorkers_NoAutoscaler(t *testing.T) {
 		assert.Nil(t, err, "Fail to update pod status")
 	}
 
+	rayClusterScaleExpectation = expectations.NewFakeRayClusterScaleExpectation()
 	// Initialize a new RayClusterReconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -2078,6 +2092,7 @@ func Test_TerminatedHead_RestartPolicy(t *testing.T) {
 	err = fakeClient.Status().Update(ctx, &podList.Items[0])
 	assert.Nil(t, err)
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize a new RayClusterReconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -2175,6 +2190,7 @@ func Test_RunningPods_RayContainerTerminated(t *testing.T) {
 	err = fakeClient.Status().Update(ctx, &podList.Items[0])
 	assert.Nil(t, err)
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Initialize a new RayClusterReconciler.
 	testRayClusterReconciler := &RayClusterReconciler{
 		Client:   fakeClient,
@@ -2377,6 +2393,7 @@ func Test_RedisCleanupFeatureFlag(t *testing.T) {
 				WithStatusSubresource(cluster).
 				Build()
 
+			rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 			// Initialize the reconciler
 			testRayClusterReconciler := &RayClusterReconciler{
 				Client:   fakeClient,
@@ -2744,6 +2761,7 @@ func TestReconcile_Replicas_Optional(t *testing.T) {
 			assert.Nil(t, err, "Fail to get pod list")
 			assert.Equal(t, oldNumWorkerPods+numHeadPods, len(podList.Items), "Init pod list len is wrong")
 
+			rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 			// Initialize a new RayClusterReconciler.
 			testRayClusterReconciler := &RayClusterReconciler{
 				Client:   fakeClient,
@@ -2836,6 +2854,7 @@ func TestReconcile_Multihost_Replicas(t *testing.T) {
 			assert.Nil(t, err, "Fail to get pod list")
 			assert.Equal(t, oldNumWorkerPods+numHeadPods, len(podList.Items), "Init pod list len is wrong")
 
+			rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 			// Initialize a new RayClusterReconciler.
 			testRayClusterReconciler := &RayClusterReconciler{
 				Client:   fakeClient,
@@ -2904,6 +2923,7 @@ func TestReconcile_NumOfHosts(t *testing.T) {
 			assert.Nil(t, err, "Fail to get pod list")
 			assert.Equal(t, 1, len(podList.Items), "Init pod list len is wrong")
 
+			rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 			// Initialize a new RayClusterReconciler.
 			testRayClusterReconciler := &RayClusterReconciler{
 				Client:   fakeClient,
@@ -3100,6 +3120,7 @@ func TestEvents_FailedPodCreation(t *testing.T) {
 	}).WithRuntimeObjects(testPods...).Build()
 	ctx := context.Background()
 
+	rayClusterScaleExpectation = expectations.NewRayClusterScaleExpectation(fakeClient)
 	// Get the pod list from the fake client.
 	podList := corev1.PodList{}
 	err := fakeClient.List(ctx, &podList, client.InNamespace(namespaceStr))