Skip to content

Commit

Permalink
Generalize to include pre/post tests and combinations of both
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 17, 2024
1 parent 7873501 commit 3088fce
Showing 1 changed file with 91 additions and 42 deletions.
133 changes: 91 additions & 42 deletions cpp/tests/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ namespace {
using ::testing::Combine;
using ::testing::Values;

enum class GenericCallbackType {
None = 0,
Pre,
Post,
PrePost,
PostPre,
};

struct ExtraParams {
GenericCallbackType genericCallbackType{GenericCallbackType::None};
};

class WorkerTest : public ::testing::Test {
protected:
std::shared_ptr<ucxx::Context> _context{
Expand All @@ -44,16 +56,18 @@ class WorkerCapabilityTest : public ::testing::Test,
}
};

class WorkerProgressTest : public WorkerTest,
public ::testing::WithParamInterface<std::tuple<bool, ProgressMode>> {
class WorkerProgressTest
: public WorkerTest,
public ::testing::WithParamInterface<std::tuple<bool, ProgressMode, ExtraParams>> {
protected:
std::function<void()> _progressWorker;
bool _enableDelayedSubmission;
ProgressMode _progressMode;
ExtraParams _extraParams;

void SetUp()
{
std::tie(_enableDelayedSubmission, _progressMode) = GetParam();
std::tie(_enableDelayedSubmission, _progressMode, _extraParams) = GetParam();

_worker = _context->createWorker(_enableDelayedSubmission);

Expand All @@ -70,6 +84,8 @@ class WorkerProgressTest : public WorkerTest,

class WorkerGenericCallbackTest : public WorkerProgressTest {};

class WorkerGenericCallbackSingleTest : public WorkerProgressTest {};

TEST_F(WorkerTest, HandleIsValid) { ASSERT_TRUE(_worker->getHandle() != nullptr); }

TEST_P(WorkerCapabilityTest, CheckCapability)
Expand Down Expand Up @@ -323,38 +339,33 @@ TEST_P(WorkerProgressTest, ProgressTagMulti)
}
}

TEST_P(WorkerGenericCallbackTest, RegisterGenericPre)
TEST_P(WorkerGenericCallbackTest, RegisterGeneric)
{
bool done = false;
auto callback = [&done]() { done = true; };

ASSERT_TRUE(_worker->registerGenericPre(callback));
ASSERT_TRUE(done);
}

TEST_P(WorkerGenericCallbackTest, RegisterGenericPost)
{
bool done = false;
auto callback = [&done]() { done = true; };

ASSERT_TRUE(_worker->registerGenericPost(callback));
ASSERT_TRUE(done);
}

TEST_P(WorkerGenericCallbackTest, RegisterGenericPrePost)
{
bool donePre = false;
bool donePost = false;
auto callbackPre = [&donePre]() { donePre = true; };
auto callbackPost = [&donePost]() { donePost = true; };

ASSERT_TRUE(_worker->registerGenericPre(callbackPre));
ASSERT_TRUE(_worker->registerGenericPost(callbackPost));
ASSERT_TRUE(donePre);
ASSERT_TRUE(donePost);
bool done1 = false;
bool done2 = false;
auto callback1 = [&done1]() { done1 = true; };
auto callback2 = [&done2]() { done2 = true; };

if (_extraParams.genericCallbackType == GenericCallbackType::Pre) {
ASSERT_TRUE(_worker->registerGenericPre(callback1));
ASSERT_TRUE(done1);
} else if (_extraParams.genericCallbackType == GenericCallbackType::Post) {
ASSERT_TRUE(_worker->registerGenericPre(callback1));
ASSERT_TRUE(done1);
} else if (_extraParams.genericCallbackType == GenericCallbackType::PrePost) {
ASSERT_TRUE(_worker->registerGenericPre(callback1));
ASSERT_TRUE(_worker->registerGenericPost(callback2));
ASSERT_TRUE(done1);
ASSERT_TRUE(done2);
} else if (_extraParams.genericCallbackType == GenericCallbackType::PostPre) {
ASSERT_TRUE(_worker->registerGenericPost(callback1));
ASSERT_TRUE(_worker->registerGenericPre(callback2));
ASSERT_TRUE(done1);
ASSERT_TRUE(done2);
}
}

TEST_P(WorkerGenericCallbackTest, RegisterGenericPreCancel)
TEST_P(WorkerGenericCallbackTest, RegisterGenericCancel)
{
bool threadStarted = false;
bool terminateThread = false;
Expand All @@ -378,7 +389,13 @@ TEST_P(WorkerGenericCallbackTest, RegisterGenericPreCancel)
}
};

ASSERT_TRUE(_worker->registerGenericPre(threadCallback));
if (_extraParams.genericCallbackType == GenericCallbackType::Pre ||
_extraParams.genericCallbackType == GenericCallbackType::PrePost) {
ASSERT_TRUE(_worker->registerGenericPre(threadCallback));
} else if (_extraParams.genericCallbackType == GenericCallbackType::Post ||
_extraParams.genericCallbackType == GenericCallbackType::PostPre) {
ASSERT_TRUE(_worker->registerGenericPost(threadCallback));
}
});

{
Expand All @@ -388,7 +405,14 @@ TEST_P(WorkerGenericCallbackTest, RegisterGenericPreCancel)
}

// The thread should be running, therefore the callback will be canceled before running.
ASSERT_FALSE(_worker->registerGenericPre(callback, 1));
// Note here `PrePost`/`PostPre` order is the opposite as from `thread`.
if (_extraParams.genericCallbackType == GenericCallbackType::Pre ||
_extraParams.genericCallbackType == GenericCallbackType::PostPre) {
ASSERT_FALSE(_worker->registerGenericPre(callback, 1));
} else if (_extraParams.genericCallbackType == GenericCallbackType::Post ||
_extraParams.genericCallbackType == GenericCallbackType::PrePost) {
ASSERT_FALSE(_worker->registerGenericPost(callback, 1));
}
ASSERT_FALSE(done);

// Unblock thread to terminate.
Expand All @@ -397,11 +421,18 @@ TEST_P(WorkerGenericCallbackTest, RegisterGenericPreCancel)
thread.join();

// Nothing should be blocking the progress thread now, the callback should succeed.
ASSERT_TRUE(_worker->registerGenericPre(callback));
// Note here `PrePost`/`PostPre` order is the opposite as from `thread`.
if (_extraParams.genericCallbackType == GenericCallbackType::Pre ||
_extraParams.genericCallbackType == GenericCallbackType::PostPre) {
ASSERT_TRUE(_worker->registerGenericPre(callback));
} else if (_extraParams.genericCallbackType == GenericCallbackType::Post ||
_extraParams.genericCallbackType == GenericCallbackType::PrePost) {
ASSERT_TRUE(_worker->registerGenericPost(callback));
}
ASSERT_TRUE(done);
}

TEST_P(WorkerGenericCallbackTest, RegisterGenericPreUncancelable)
TEST_P(WorkerGenericCallbackSingleTest, RegisterGenericPreUncancelable)
{
bool terminateThread = false;
bool match = false;
Expand All @@ -419,7 +450,10 @@ TEST_P(WorkerGenericCallbackTest, RegisterGenericPreUncancelable)

// This will submit the callback and attempt to cancel once every 1ms,
// a warning is logged when multiples of 10 attempts to cancel are made.
ASSERT_TRUE(_worker->registerGenericPre(threadCallback, 1000000 /* 1ms */));
if (_extraParams.genericCallbackType == GenericCallbackType::Pre)
ASSERT_TRUE(_worker->registerGenericPre(threadCallback, 1000000 /* 1ms */));
else if (_extraParams.genericCallbackType == GenericCallbackType::Post)
ASSERT_TRUE(_worker->registerGenericPost(threadCallback, 1000000 /* 1ms */));
});

loopWithTimeout(std::chrono::milliseconds(5000), [&match] {
Expand Down Expand Up @@ -450,16 +484,31 @@ INSTANTIATE_TEST_SUITE_P(ProgressModes,
ProgressMode::Blocking,
ProgressMode::Wait,
ProgressMode::ThreadPolling,
ProgressMode::ThreadBlocking)));
ProgressMode::ThreadBlocking),
Values(ExtraParams{})));

INSTANTIATE_TEST_SUITE_P(DelayedSubmission,
WorkerProgressTest,
Combine(Values(true),
Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking),
Values(ExtraParams{})));

INSTANTIATE_TEST_SUITE_P(
GenericCallbacks,
WorkerGenericCallbackTest,
Combine(Values(false), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking)));
Combine(Values(false, true),
Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking),
Values(ExtraParams{.genericCallbackType = GenericCallbackType::Pre},
ExtraParams{.genericCallbackType = GenericCallbackType::Post},
ExtraParams{.genericCallbackType = GenericCallbackType::PrePost},
ExtraParams{.genericCallbackType = GenericCallbackType::PostPre})));

INSTANTIATE_TEST_SUITE_P(
DelayedSubmission,
WorkerProgressTest,
Combine(Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking)));
GenericCallbacksSingle,
WorkerGenericCallbackSingleTest,
Combine(Values(false, true),
Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking),
Values(ExtraParams{.genericCallbackType = GenericCallbackType::Pre},
ExtraParams{.genericCallbackType = GenericCallbackType::Post})));

} // namespace

0 comments on commit 3088fce

Please sign in to comment.