Skip to content

Commit

Permalink
#1867: add new test for broadcast on insertable collections
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Jul 26, 2022
1 parent 41a66fb commit 9935b35
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 3 deletions.
11 changes: 11 additions & 0 deletions tests/unit/collection/test_broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ TYPED_TEST_P(TestBroadcast, test_broadcast_basic_1) {
test_broadcast_1<TypeParam>("test_broadcast_basic_1");
}

TYPED_TEST_P(TestBroadcastDynamic, test_broadcast_dynamic_basic_1) {
test_broadcast_dynamic_1<TypeParam>("test_broadcast_dynamic_basic_1");
}

std::unordered_map<TestIndex, bool> DynamicCountFun::index_map{};

REGISTER_TYPED_TEST_SUITE_P(TestBroadcast, test_broadcast_basic_1);
REGISTER_TYPED_TEST_SUITE_P(TestBroadcastDynamic, test_broadcast_dynamic_basic_1);

using CollectionTestTypesBasic = testing::Types<
bcast_col_ ::TestCol<int32_t>
Expand All @@ -65,4 +72,8 @@ INSTANTIATE_TYPED_TEST_SUITE_P(
test_bcast_basic, TestBroadcast, CollectionTestTypesBasic, DEFAULT_NAME_GEN
);

INSTANTIATE_TYPED_TEST_SUITE_P(
test_bcast_basic, TestBroadcastDynamic, CollectionTestTypesBasic, DEFAULT_NAME_GEN
);

}}}} // end namespace vt::tests::unit::bcast
84 changes: 84 additions & 0 deletions tests/unit/collection/test_broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

#include <cstdint>

#define PRINT_CONSTRUCTOR_VALUES 1

namespace vt { namespace tests { namespace unit { namespace bcast {

using namespace vt;
Expand Down Expand Up @@ -115,8 +117,11 @@ struct BroadcastHandlers {

template <typename CollectionT>
struct TestBroadcast : TestParallelHarness {};
template <typename CollectionT>
struct TestBroadcastDynamic : TestParallelHarness {};

TYPED_TEST_SUITE_P(TestBroadcast);
TYPED_TEST_SUITE_P(TestBroadcastDynamic);

template<typename ColType>
void test_broadcast_1(std::string const& label) {
Expand All @@ -142,6 +147,85 @@ void test_broadcast_1(std::string const& label) {
}
}

struct DynamicCountMsg : vt::Message {
TestIndex idx;
};

struct DynamicCountFun {
static std::unordered_map<TestIndex, bool> index_map;
void operator()(DynamicCountMsg *msg) const {
index_map.at(msg->idx) = true;
}
};

template <
typename CollectionT,
typename MessageT = typename CollectionT::MsgType,
typename TupleT = typename MessageT::TupleType
>
struct DynamicBroadcastHandlers : BroadcastHandlers<CollectionT, MessageT, TupleT> {
static void track_handler(MessageT* msg, CollectionT* col) {
BroadcastHandlers<CollectionT, MessageT, TupleT>::handler(msg, col);
fmt::print("{}: setting index at {} to true\n", ::vt::theContext()->getNode(), col->getIndex());
DynamicCountFun::index_map.at(col->getIndex()) = true;
}
};

template<typename ColType>
void test_broadcast_dynamic_1(std::string const& label) {
using MsgType = typename ColType::MsgType;
using TestParamType = typename ColType::ParamType;

auto const& this_node = theContext()->getNode();
typename ColType::CollectionProxyType proxy = {};

auto const& col_size = 32;
auto range = TestIndex(col_size);
proxy = makeCollection<ColType>(label).collective(true).dynamicMembership(true).bounds(range).wait();

DynamicCountFun::index_map = std::unordered_map<TestIndex, bool>{
{ TestIndex{0}, false },
{ TestIndex{7}, false },
{ TestIndex{23}, false },
{ TestIndex{31}, false }
};

auto modify_token = proxy.beginModification(fmt::format("{}.beginModification", label));
if (this_node == 0) {
int count = 0;
for ( auto &&entry : DynamicCountFun::index_map )
{
proxy[entry.first].insertAt(modify_token, count++ % theContext()->getNumNodes());
}
}
proxy.finishModification(std::move(modify_token));

runInEpochCollective([this_node, proxy](){
if (this_node == 0) {
TestParamType args = ConstructTuple<TestParamType>::construct();
proxy.template broadcast<
MsgType,
BroadcastHandlers<ColType>::handler
>(args);

auto msg = makeMessage<MsgType>(args);
theCollection()->broadcastMsg<
MsgType,DynamicBroadcastHandlers<ColType>::track_handler
>(proxy, msg.get());
}
});

// Check to make sure we received the broadcast
int count = 0;
for ( auto &&entry : DynamicCountFun::index_map )
{
if ((count++ % theContext()->getNumNodes()) == this_node) {
fmt::print("{}: index at {} is {}\n", ::vt::theContext()->getNode(), entry.first, entry.second);
EXPECT_TRUE(entry.second);
}
}
}

}}}} // end namespace vt::tests::unit::bcast

#endif /*INCLUDED_UNIT_COLLECTION_TEST_BROADCAST_H*/
2 changes: 1 addition & 1 deletion tests/unit/collection/test_send.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TYPED_TEST_P(TestCollectionSend, test_collection_send_basic_1) {
}

TYPED_TEST_P(TestCollectionSendSz, test_collection_send_sz_basic_1) {
test_collection_send_sz_1<TypeParam>();
test_collection_send_sz_1<TypeParam>("test_collection_send_sz_basic_1");
}

TYPED_TEST_P(TestCollectionSendMem, test_collection_send_ptm_basic_1) {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/collection/test_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void test_collection_send_1(std::string const& label) {
}

template <typename ColType>
void test_collection_send_sz_1() {
void test_collection_send_sz_1(std::string const& label) {
using PayloadType = typename ColType::MsgType::TupleType;
using MsgType = typename ColType::template MsgSzType<PayloadType>;
using TestParamType = typename ColType::ParamType;
Expand All @@ -216,7 +216,7 @@ void test_collection_send_sz_1() {
auto const& col_size = 32;
auto range = TestIndex(col_size);
TestParamType args = ConstructTuple<TestParamType>::construct();
auto proxy = theCollection()->construct<ColType>(range);
auto proxy = theCollection()->construct<ColType>(range, label);
for (int i = 0; i < col_size; i++) {
auto msg = makeMessageSz<MsgType>(sizeof(PayloadType));
EXPECT_EQ(msg.size(), sizeof(MsgType) + sizeof(PayloadType));
Expand Down

0 comments on commit 9935b35

Please sign in to comment.