Skip to content

Commit

Permalink
Fix waiting for releasables. Add more tests with semaphores.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Apr 24, 2024
1 parent 803b2cd commit bb84b2e
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 2 deletions.
2 changes: 2 additions & 0 deletions dali/core/exec/tasking/scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ void Scheduler::Notify(Waitable *w) {
waiting.emplace_back(w->waiting_[i]);

for (auto &task : waiting) {
if (!is_completion_event && !w->IsAcquirable())
break;
if (task->Ready())
continue;

Expand Down
100 changes: 98 additions & 2 deletions dali/core/exec/tasking_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,19 @@ TEST(TaskingTest, DependentTasksAreSequential) {

std::atomic_int parallel = 0;
std::atomic_int max_parallel = 0;
int last_task_id = -1;
SharedTask last_task;
for (int i = 0; i < num_tasks; i++) {
auto task = Task::Create([&]() {
auto task = Task::Create([&, i]() {
int p = ++parallel;
int expected = max_parallel.load();
while (!max_parallel.compare_exchange_strong(expected, std::max(p, expected))) {}
std::this_thread::sleep_for(std::chrono::milliseconds(1));

--parallel;

if (last_task_id != i - 1)
throw std::runtime_error("The task order is incorrect.");
last_task_id = i;
});
if (last_task)
task->Succeed(last_task);
Expand All @@ -91,6 +95,68 @@ TEST(TaskingTest, DependentTasksAreSequential) {
EXPECT_EQ(parallel, 0) << "The tasks didn't finish cleanly";
}

TEST(TaskingTest, GuardedTasksAreNonParallel) {
int num_threads = 4;
Executor ex(num_threads);
ex.Start();

int num_tasks = 10;

std::atomic_int parallel = 0;
std::atomic_int max_parallel = 0;
std::vector<SharedTask> tasks;
auto sem = std::make_shared<Semaphore>(1);
for (int i = 0; i < num_tasks; i++) {
auto task = Task::Create([&]() {
int p = ++parallel;
int expected = max_parallel.load();
while (!max_parallel.compare_exchange_strong(expected, std::max(p, expected))) {}
std::this_thread::sleep_for(std::chrono::milliseconds(1));
--parallel;
});
task->GuardWith(sem);
ex.AddSilentTask(task);
tasks.push_back(std::move(task));
}
for (auto &t : tasks)
t->Wait();
EXPECT_EQ(max_parallel, 1)
<< "The parallelism counter should not exceed 1 for a group of tasks guarded by a semaphore.";
EXPECT_EQ(parallel, 0) << "The tasks didn't finish cleanly";
}


TEST(TaskingTest, SemaphoreConcurencyLimit) {
int num_threads = 8;
int max_count = 3;
Executor ex(num_threads);
ex.Start();

int num_tasks = 15;

std::atomic_int parallel = 0;
std::atomic_int max_parallel = 0;
std::vector<SharedTask> tasks;
auto sem = std::make_shared<Semaphore>(max_count);
for (int i = 0; i < num_tasks; i++) {
auto task = Task::Create([&]() {
int p = ++parallel;
int expected = max_parallel.load();
while (!max_parallel.compare_exchange_strong(expected, std::max(p, expected))) {}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
--parallel;
});
task->GuardWith(sem);
ex.AddSilentTask(task);
tasks.push_back(std::move(task));
}
for (auto &t : tasks)
t->Wait();
EXPECT_EQ(max_parallel, max_count)
<< "The parallelism counter should not exceed the max. count of a guarding semaphore.";
EXPECT_EQ(parallel, 0) << "The tasks didn't finish cleanly";
}

namespace {

struct NoCopyAtRunTime {
Expand Down Expand Up @@ -305,6 +371,36 @@ TEST(TaskingTest, MultiOutputLifespan) {
EXPECT_EQ(InstanceCounter<int>::num_instances, 0);
}

TEST(TaskingTest, ReleaseAfterRun) {
Scheduler s;
auto sem = std::make_shared<Semaphore>(1);
auto t1 = Task::Create([]() {});
t1->Succeed(sem);
auto t2 = Task::Create([]() {});
t2->Succeed(t1);
t2->ReleaseAfterRun(sem);
auto t3 = Task::Create([]() {});
auto t4 = Task::Create([]() {});
t3->GuardWith(sem); // should work the same as Succeed -> Release
t4->Succeed(sem)->ReleaseAfterRun(sem);
s.AddSilentTask(t1);
s.AddSilentTask(t3);
s.AddSilentTask(t4);
s.AddSilentTask(t2);
auto t = s.Pop();
EXPECT_EQ(t, t1);
t->Run();
t = s.Pop();
EXPECT_EQ(t, t2);
t->Run();
t = s.Pop();
EXPECT_EQ(t, t3);
t->Run();
t = s.Pop();
EXPECT_EQ(t, t4);
t->Run();
}

namespace {

template <typename RNG>
Expand Down

0 comments on commit bb84b2e

Please sign in to comment.