Skip to content

Commit

Permalink
refactoer (#1145)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhujl1991 authored Mar 18, 2020
1 parent f6433c5 commit 63baf43
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
18 changes: 9 additions & 9 deletions pkg/common/util/v1/testutil/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var (
controllerKind = tfv1.SchemeGroupVersionKind
)

func NewBasePod(name string, tfJob *tfv1.TFJob, t *testing.T) *v1.Pod {
func NewBasePod(name string, tfJob *tfv1.TFJob) *v1.Pod {
return &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Expand All @@ -46,18 +46,18 @@ func NewBasePod(name string, tfJob *tfv1.TFJob, t *testing.T) *v1.Pod {
}
}

func NewPod(tfJob *tfv1.TFJob, typ string, index int, t *testing.T) *v1.Pod {
pod := NewBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob, t)
func NewPod(tfJob *tfv1.TFJob, typ string, index int) *v1.Pod {
pod := NewBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob)
pod.Labels[tfReplicaTypeLabel] = typ
pod.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index)
return pod
}

// create count pods with the given phase for the given tfJob
func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1.TFJob, typ string, start int32, t *testing.T) []*v1.Pod {
func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1.TFJob, typ string, start int32) []*v1.Pod {
pods := []*v1.Pod{}
for i := int32(0); i < count; i++ {
newPod := NewPod(tfJob, typ, int(start+i), t)
newPod := NewPod(tfJob, typ, int(start+i))
newPod.Status = v1.PodStatus{Phase: status}
pods = append(pods, newPod)
}
Expand All @@ -66,13 +66,13 @@ func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1.TFJob, typ string,

func SetPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, t *testing.T) {
var index int32
for _, pod := range NewPodList(pendingPods, v1.PodPending, tfJob, typ, index, t) {
for _, pod := range NewPodList(pendingPods, v1.PodPending, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
}
index += pendingPods
for i, pod := range NewPodList(activePods, v1.PodRunning, tfJob, typ, index, t) {
for i, pod := range NewPodList(activePods, v1.PodRunning, tfJob, typ, index) {
if restartCounts != nil {
pod.Status.ContainerStatuses = []v1.ContainerStatus{{RestartCount: restartCounts[i]}}
}
Expand All @@ -81,13 +81,13 @@ func SetPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, pe
}
}
index += activePods
for _, pod := range NewPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index, t) {
for _, pod := range NewPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
}
index += succeededPods
for _, pod := range NewPodList(failedPods, v1.PodFailed, tfJob, typ, index, t) {
for _, pod := range NewPodList(failedPods, v1.PodFailed, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller.v1/tensorflow/pod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
"k8s.io/client-go/rest"
"k8s.io/kubernetes/pkg/controller"

"github.com/kubeflow/tf-operator/cmd/tf-operator.v1/app/options"
common "github.com/kubeflow/common/job_controller/api/v1"
"github.com/kubeflow/tf-operator/cmd/tf-operator.v1/app/options"
tfv1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned"
"github.com/kubeflow/tf-operator/pkg/common/util/v1/testutil"
Expand Down Expand Up @@ -89,7 +89,7 @@ func TestAddPod(t *testing.T) {
if err := tfJobIndexer.Add(unstructured); err != nil {
t.Errorf("Failed to add tfjob to tfJobIndexer: %v", err)
}
pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0, t)
pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0)
ctr.AddPod(pod)

syncChan <- "sync"
Expand Down Expand Up @@ -317,7 +317,7 @@ func TestExitCode(t *testing.T) {
if err := tfJobIndexer.Add(unstructured); err != nil {
t.Errorf("Failed to add tfjob to tfJobIndexer: %v", err)
}
pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0, t)
pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0)
pod.Status.Phase = v1.PodFailed
pod.Spec.Containers = append(pod.Spec.Containers, v1.Container{})
pod.Status.ContainerStatuses = append(pod.Status.ContainerStatuses, v1.ContainerStatus{
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller.v1/tensorflow/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
"k8s.io/client-go/tools/record"
"k8s.io/kubernetes/pkg/controller"

"github.com/kubeflow/tf-operator/cmd/tf-operator.v1/app/options"
common "github.com/kubeflow/common/job_controller/api/v1"
"github.com/kubeflow/tf-operator/cmd/tf-operator.v1/app/options"
tfv1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned"
"github.com/kubeflow/tf-operator/pkg/common/util/v1/testutil"
Expand Down Expand Up @@ -65,7 +65,7 @@ func TestFailed(t *testing.T) {

tfJob := testutil.NewTFJob(3, 0)
initializeTFReplicaStatuses(tfJob, tfv1.TFReplicaTypeWorker)
pod := testutil.NewBasePod("pod", tfJob, t)
pod := testutil.NewBasePod("pod", tfJob)
pod.Status.Phase = v1.PodFailed
updateTFJobReplicaStatuses(tfJob, tfv1.TFReplicaTypeWorker, pod)
if tfJob.Status.ReplicaStatuses[common.ReplicaType(tfv1.TFReplicaTypeWorker)].Failed != 1 {
Expand Down Expand Up @@ -465,7 +465,7 @@ func TestStatus(t *testing.T) {
}

func setStatusForTest(tfJob *tfv1.TFJob, typ tfv1.TFReplicaType, failed, succeeded, active int32, t *testing.T) {
pod := testutil.NewBasePod("pod", tfJob, t)
pod := testutil.NewBasePod("pod", tfJob)
var i int32
for i = 0; i < failed; i++ {
pod.Status.Phase = v1.PodFailed
Expand Down

0 comments on commit 63baf43

Please sign in to comment.