diff --git a/pkg/apb/watch_pod_test.go b/pkg/apb/watch_pod_test.go index d906991c01..985f9a3814 100644 --- a/pkg/apb/watch_pod_test.go +++ b/pkg/apb/watch_pod_test.go @@ -3,8 +3,6 @@ package apb import ( "testing" - "time" - "fmt" core1 "k8s.io/api/core/v1" @@ -28,7 +26,7 @@ func TestWatchPod(t *testing.T) { PodClient func() (v1.PodInterface, *watch.FakeWatcher) UpdatePodStates func(watcher *watch.FakeWatcher) ExpectError bool - Validate func(status []JobState) error + Validate func(status []string) error }{ { Name: "should get error and state update when pod fails", @@ -63,16 +61,13 @@ func TestWatchPod(t *testing.T) { }} podStateUpdater(watcher, podStates) }, - Validate: func(status []JobState) error { + Validate: func(status []string) error { if len(status) != 2 { return fmt.Errorf("expected 2 status updates") } for i, s := range status { - if s.Description != fmt.Sprintf("lastop%v", i) { - return fmt.Errorf("expected description to be lastop%v but got %v", i, s.Description) - } - if s.State != StateInProgress { - return fmt.Errorf("expected state to be %v but was %v", StateInProgress, s.State) + if s != fmt.Sprintf("lastop%v", i) { + return fmt.Errorf("expected description to be lastop%v but got %v", i, s) } } return nil @@ -111,16 +106,13 @@ func TestWatchPod(t *testing.T) { podStateUpdater(watcher, podStates) }, - Validate: func(status []JobState) error { + Validate: func(status []string) error { if len(status) != 2 { return fmt.Errorf("expected 2 status updates") } for i, s := range status { - if s.Description != fmt.Sprintf("lastop%v", i) { - return fmt.Errorf("expected description to be lastop%v but got %v", i, s.Description) - } - if s.State != StateInProgress { - return fmt.Errorf("expected state to be %v but was %v", StateInProgress, s.State) + if s != fmt.Sprintf("lastop%v", i) { + return fmt.Errorf("expected description to be lastop%v but got %v", i, s) } } return nil @@ -150,16 +142,13 @@ func TestWatchPod(t *testing.T) { watcher.Delete(podStates[0]) }, ExpectError: true, - Validate: func(status []JobState) error { + Validate: func(status []string) error { if len(status) != 2 { return fmt.Errorf("expected 2 status updates") } for _, s := range status { - if s.Description != "lastop0" { - return fmt.Errorf("expected description to be lastop0 but got %v", s.Description) - } - if s.State != StateInProgress { - return fmt.Errorf("expected state to be %v but was %v", StateInProgress, s.State) + if s != "lastop0" { + return fmt.Errorf("expected description to be lastop0 but got %v", s) } } return nil @@ -169,22 +158,25 @@ func TestWatchPod(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { - statusReceiver := make(chan JobState) - podClient, podWatch := tc.PodClient() - time.AfterFunc(100*time.Millisecond, func() { - close(statusReceiver) - }) var watchErr error + podClient, podWatch := tc.PodClient() + descriptions := []string{} + done := make(chan bool) + go func() { - watchErr = watchPod("test", "test", podClient, statusReceiver) + watchErr = watchPod("test", "test", podClient, func(d string) { + fmt.Printf("NSK: got d -> %v\n", d) + descriptions = append(descriptions, d) + }) + done <- true }() go tc.UpdatePodStates(podWatch) - var state []JobState - for s := range statusReceiver { - state = append(state, s) - } + + <-done + if nil != tc.Validate { - if err := tc.Validate(state); err != nil { + fmt.Printf("NSK: Now trying to validate the descriptions: %v", descriptions) + if err := tc.Validate(descriptions); err != nil { t.Fatal("unexpected errror validating job state", err) } }