Skip to content

Commit

Permalink
#2387: use gather instead of allreduce
Browse files (browse the repository at this point in the history)
  • Loading branch information
cwschilly committed Jan 13, 2025
1 parent 779e46d commit d25f3f2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
42 changes: 35 additions & 7 deletions src/vt/trace/trace_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#include <sys/stat.h>
#include <zlib.h>
#include <map>
#include <numeric>

namespace vt {
#if vt_check_enabled(trace_only)
Expand Down Expand Up @@ -542,15 +543,42 @@ void TraceLite::flushTracesFile(bool useGlobalSync) {

void TraceLite::writeTracesFile(int flush, bool is_incremental_flush) {
auto const node = theContext()->getNode();
auto const comm = theContext()->getComm();
auto const comm_size = theContext()->getNumNodes();

// Allreduce the hashed events to rank 0 before writing sts file
// Gather all hashed events to rank 0 before writing sts file
using events_t = std::vector<UserEventIDType>;
auto const root = 0;
std::vector<UserEventIDType> all_hashed_events;
auto msg = makeMessage<ReduceVecMsg<UserEventIDType>>(
theTrace()->user_hashed_events_);
theCollective()->global()->reduce<
PlusOp<std::vector<UserEventIDType>>, Verify<ReduceOP::Plus>
>(root, msg.get());
events_t local_hashed_events = theTrace()->user_hashed_events_;
int local_size = local_hashed_events.size();
std::vector<int> all_sizes(comm_size);
MPI_Gather(&local_size, 1, MPI_INT, all_sizes.data(), 1, MPI_INT, 0, comm);

// Compute displacements
std::vector<int> displs(comm_size, 0);
if (node == 0) {
std::partial_sum(all_sizes.begin(), all_sizes.end() - 1, displs.begin() + 1);
}

// Create vector in which to store all events
events_t all_hashed_events;
if (node == 0) {
int total_size = std::accumulate(all_sizes.begin(), all_sizes.end(), 0);
all_hashed_events.resize(total_size);
}

// Gather events
MPI_Gatherv(
local_hashed_events.data(), // Send buffer
local_size, // Number of elements to send
MPI_UINT32_T, // Data type (adjust to match UserEventIDType)
all_hashed_events.data(), // Receive buffer (on root)
all_sizes.data(), // Number of elements to receive from each rank
displs.data(), // Displacements for each rank
MPI_UINT32_T, // Data type (adjust to match UserEventIDType)
root, // Root node
comm // Communicator
);

size_t to_write = traces_.size();

Expand Down
2 changes: 1 addition & 1 deletion src/vt/trace/trace_user_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ UserEventIDType UserEventRegistry::hash(std::string const& in_event_name) {
auto id = std::get<0>(ret);
auto inserted = std::get<1>(ret);
if (inserted) {
theTrace->addHashedEvent(id);
vt::theTrace()->addHashedEvent(id);
}
return id;
}
Expand Down

0 comments on commit d25f3f2

Please sign in to comment.