Skip to content

Commit

Permalink
#1265: replay: allow custom stats callback for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
nlslatt committed May 9, 2022
1 parent c4454cf commit b4b96f6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
13 changes: 8 additions & 5 deletions src/vt/vrt/collection/balance/workload_replay.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,18 @@ void replayWorkloads(
auto const filename = theConfig()->getLBDataFileIn();
auto workloads = readInWorkloads(filename);

replayWorkloads(initial_phase, phases_to_run, workloads);
// use the default stats handler
auto stats_cb = vt::theCB()->makeBcast<
LBManager, balance::NodeStatsMsg, &LBManager::statsHandler
>(theLBManager()->getProxy());

replayWorkloads(initial_phase, phases_to_run, workloads, stats_cb);
}

void replayWorkloads(
PhaseType initial_phase, PhaseType phases_to_run,
std::shared_ptr<LBDataHolder> workloads
std::shared_ptr<LBDataHolder> workloads,
Callback<balance::NodeStatsMsg> stats_cb
) {
using ObjIDType = elm::ElementIDStruct;

Expand Down Expand Up @@ -202,9 +208,6 @@ void replayWorkloads(
"Number of objects after LB: {}\n", migratable_objects_here.size()
);
runInEpochCollective("postLBWorkForReplay -> computeStats", [=] {
auto stats_cb = vt::theCB()->makeBcast<
LBManager, balance::NodeStatsMsg, &LBManager::statsHandler
>(theLBManager()->getProxy());
theLBManager()->computeStatistics(
proposed_model, false, phase, stats_cb
);
Expand Down
4 changes: 3 additions & 1 deletion src/vt/vrt/collection/balance/workload_replay.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

#include "vt/config.h"
#include "vt/elm/elm_id.h"
#include "vt/vrt/collection/balance/stats_msg.h"
#include "vt/vrt/collection/balance/lb_data_holder.h"
#include "vt/vrt/collection/balance/baselb/baselb.h"
#include "vt/vrt/collection/balance/model/load_model.h"
Expand Down Expand Up @@ -91,7 +92,8 @@ void replayWorkloads(
*/
void replayWorkloads(
PhaseType initial_phase, PhaseType phases_to_run,
std::shared_ptr<LBDataHolder> workloads
std::shared_ptr<LBDataHolder> workloads,
Callback<balance::NodeStatsMsg> stats_cb
);

/**
Expand Down
22 changes: 19 additions & 3 deletions tests/unit/collection/test_workload_data_migrator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

#include "vt/elm/elm_id.h"
#include "vt/elm/elm_id_bits.h"
#include "vt/vrt/collection/balance/stats_msg.h"
#include "vt/vrt/collection/balance/lb_common.h"
#include "vt/vrt/collection/balance/lb_data_holder.h"
#include "vt/vrt/collection/balance/lb_invoke/lb_manager.h"
Expand All @@ -72,6 +73,10 @@ std::shared_ptr<LBDataHolder>
setupWorkloads(PhaseType phase, size_t numElements) {
auto const& this_node = vt::theContext()->getNode();

if (this_node == 0) {
vt_print(replay, "Generating workloads to replay...\n");
}

using vt::vrt::collection::balance::ElementIDStruct;

std::vector<ElementIDStruct> myElemList(numElements);
Expand Down Expand Up @@ -564,6 +569,10 @@ setupManyWorkloads(
) {
auto const& this_node = vt::theContext()->getNode();

if (this_node == 0) {
vt_print(replay, "Generating workloads to replay...\n");
}

using vt::vrt::collection::balance::ElementIDStruct;

std::vector<ElementIDStruct> myElemList(numElements);
Expand Down Expand Up @@ -634,22 +643,29 @@ struct TestWorkloadReplay : TestParallelHarness {
void addAdditionalArgs() override {
static char vt_lb[]{"--vt_lb"};
static char vt_lb_name[]{"--vt_lb_name=RandomLB"};
addArgs(vt_lb, vt_lb_name);
static char vt_lb_interval[]{"--vt_lb_interval=2"};
addArgs(vt_lb, vt_lb_name, vt_lb_interval);
}
#endif
};

TEST_F(TestWorkloadReplay, test_run_replay_no_verify) {
PhaseType initial_phase = 1;
PhaseType num_phases = 3;
PhaseType num_phases = 5;
const size_t numElements = 5;

// first set up the workloads to replay, moving them around by phase
auto lbdh = setupManyWorkloads(initial_phase, num_phases, numElements);

using LBManager = vt::vrt::collection::balance::LBManager;
using NodeStatsMsg = vt::vrt::collection::balance::NodeStatsMsg;
auto stats_cb = vt::theCB()->makeBcast<
LBManager, NodeStatsMsg, &LBManager::statsHandler
>(vt::theLBManager()->getProxy());

// then replay them but allow the lb to place objects differently
vt::vrt::collection::balance::replay::replayWorkloads(
initial_phase, num_phases, lbdh
initial_phase, num_phases, lbdh, stats_cb
);
}

Expand Down

0 comments on commit b4b96f6

Please sign in to comment.