Skip to content

Commit

Permalink
Remove race in coordinator_test so it passes on tsan. Reduce sleep intervals
Browse files Browse the repository at this point in the history

so it runs faster (11s on tsan instead of ~30s).
Change: 147893428
  • Loading branch information
tensorflower-gardener committed Feb 18, 2017
1 parent 79c3b47 commit eb96240
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions tensorflow/cc/training/coordinator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,28 @@ namespace {

using error::Code;

void WaitForStopThread(Coordinator* coord, bool* stopped, Notification* done) {
void WaitForStopThread(Coordinator* coord, Notification* about_to_wait,
Notification* done) {
about_to_wait->Notify();
coord->WaitForStop();
*stopped = true;
done->Notify();
}

TEST(CoordinatorTest, TestStopAndWaitOnStop) {
Coordinator coord;
EXPECT_EQ(coord.ShouldStop(), false);

bool stopped = false;
Notification about_to_wait;
Notification done;
Env::Default()->SchedClosure(
std::bind(&WaitForStopThread, &coord, &stopped, &done));
Env::Default()->SleepForMicroseconds(10000000);
EXPECT_EQ(stopped, false);
std::bind(&WaitForStopThread, &coord, &about_to_wait, &done));
about_to_wait.WaitForNotification();
Env::Default()->SleepForMicroseconds(1000 * 1000);
EXPECT_FALSE(done.HasBeenNotified());

TF_EXPECT_OK(coord.RequestStop());
done.WaitForNotification();
EXPECT_EQ(stopped, true);
EXPECT_EQ(coord.ShouldStop(), true);
EXPECT_TRUE(coord.ShouldStop());
}

class MockQueueRunner : public RunnerInterface {
Expand All @@ -66,14 +67,16 @@ class MockQueueRunner : public RunnerInterface {
join_counter_ = join_counter;
}

void StartCounting(std::atomic<int>* counter, int until) {
void StartCounting(std::atomic<int>* counter, int until,
Notification* start = nullptr) {
thread_pool_->Schedule(
std::bind(&MockQueueRunner::CountThread, this, counter, until));
std::bind(&MockQueueRunner::CountThread, this, counter, until, start));
}

void StartSettingStatus(const Status& status, BlockingCounter* counter) {
thread_pool_->Schedule(
std::bind(&MockQueueRunner::SetStatusThread, this, status, counter));
void StartSettingStatus(const Status& status, BlockingCounter* counter,
Notification* start) {
thread_pool_->Schedule(std::bind(&MockQueueRunner::SetStatusThread, this,
status, counter, start));
}

Status Join() {
Expand All @@ -93,15 +96,17 @@ class MockQueueRunner : public RunnerInterface {
void Stop() { stopped_ = true; }

private:
void CountThread(std::atomic<int>* counter, int until) {
void CountThread(std::atomic<int>* counter, int until, Notification* start) {
if (start != nullptr) start->WaitForNotification();
while (!coord_->ShouldStop() && counter->load() < until) {
(*counter)++;
Env::Default()->SleepForMicroseconds(100000);
Env::Default()->SleepForMicroseconds(10 * 1000);
}
coord_->RequestStop().IgnoreError();
}
void SetStatusThread(const Status& status, BlockingCounter* counter) {
Env::Default()->SleepForMicroseconds(100000);
void SetStatusThread(const Status& status, BlockingCounter* counter,
Notification* start) {
start->WaitForNotification();
SetStatus(status);
counter->DecrementCount();
}
Expand Down Expand Up @@ -130,20 +135,22 @@ TEST(CoordinatorTest, TestRealStop) {
TF_EXPECT_OK(coord.RequestStop());

int temp_counter = counter.load();
Env::Default()->SleepForMicroseconds(10000000);
Env::Default()->SleepForMicroseconds(1000 * 1000);
EXPECT_EQ(temp_counter, counter.load());
TF_EXPECT_OK(coord.Join());
}

TEST(CoordinatorTest, TestRequestStop) {
Coordinator coord;
std::atomic<int> counter(0);
Notification start;
std::unique_ptr<MockQueueRunner> qr;
for (int i = 0; i < 10; i++) {
qr.reset(new MockQueueRunner(&coord));
qr->StartCounting(&counter, 10);
qr->StartCounting(&counter, 10, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr)));
}
start.Notify();

coord.WaitForStop();
EXPECT_EQ(coord.ShouldStop(), true);
Expand All @@ -168,20 +175,22 @@ TEST(CoordinatorTest, TestJoin) {

TEST(CoordinatorTest, StatusReporting) {
Coordinator coord({Code::CANCELLED, Code::OUT_OF_RANGE});
Notification start;
BlockingCounter counter(3);

std::unique_ptr<MockQueueRunner> qr1(new MockQueueRunner(&coord));
qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter);
qr1->StartSettingStatus(Status(Code::CANCELLED, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr1)));

std::unique_ptr<MockQueueRunner> qr2(new MockQueueRunner(&coord));
qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter);
qr2->StartSettingStatus(Status(Code::INVALID_ARGUMENT, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr2)));

std::unique_ptr<MockQueueRunner> qr3(new MockQueueRunner(&coord));
qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter);
qr3->StartSettingStatus(Status(Code::OUT_OF_RANGE, ""), &counter, &start);
TF_ASSERT_OK(coord.RegisterRunner(std::move(qr3)));

start.Notify();
counter.Wait();
TF_EXPECT_OK(coord.RequestStop());
EXPECT_EQ(coord.Join().code(), Code::INVALID_ARGUMENT);
Expand Down

0 comments on commit eb96240

Please sign in to comment.