Skip to content

Commit

Permalink
Fix the error regarding MPI usage before initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
JaeseungYeom committed Sep 19, 2024
1 parent e519e62 commit 1fad494
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 25 deletions.
14 changes: 12 additions & 2 deletions src/AMSlib/wf/workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class AMSWorkflow
std::shared_ptr<ams::db::BaseDB> DB;

/** @brief The process id. For MPI runs this is the rank */
const int rId;
int rId;

/** @brief The total number of processes participating in the simulation
* (world_size for MPI) */
Expand Down Expand Up @@ -247,7 +247,17 @@ class AMSWorkflow
/** @brief Store the application physics callback in @c AppCall. */
void set_physics(AMSPhysicFn _AppCall) { AppCall = _AppCall; }

#ifdef __ENABLE_MPI__
/**
 * @brief Set the MPI communicator and refresh @c rId / @c wSize from it.
 *
 * When the communicator is MPI_COMM_NULL the workflow falls back to
 * single-process bookkeeping (rank 0, world size 1); otherwise rank and
 * size are queried from the communicator itself.
 */
void set_communicator(MPI_Comm communicator)
{
  comm = communicator;

  if (comm != MPI_COMM_NULL) {
    // Valid communicator: derive process id and world size from MPI.
    MPI_Comm_rank(comm, &rId);
    MPI_Comm_size(comm, &wSize);
    return;
  }

  // No communicator provided: behave as a serial (single-process) run.
  rId = 0;
  wSize = 1;
}
#endif

/** @brief Store the execution policy in @c ePolicy. */
void set_exec_policy(AMSExecPolicy policy) { ePolicy = policy; }
Expand Down
59 changes: 36 additions & 23 deletions tests/AMSlib/ams_ete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,36 +177,49 @@ int main(int argc, char **argv)
AMSCAbstrModel model_descr = AMSRegisterAbstractModel(
"test", uq_policy, threshold, model_path, nullptr, "test", -1);

int process_id = 0;
int world_size = 1;

#ifdef __ENABLE_MPI__
if (use_mpi) {
MPI_Comm_rank(MPI_COMM_WORLD, &process_id);
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
}
#endif // __ENABLE_MPI__


if (data_type == AMSDType::AMS_SINGLE) {
Problem<float> prob(num_inputs, num_outputs);

AMSExecutor wf = AMSCreateExecutor(model_descr,
AMSDType::AMS_SINGLE,
resource,
(AMSPhysicFn)callBackSingle,
process_id,
world_size);
// TODO: I do not think we should pass the process id and world size here.
// It should be obtained from the communicator passed.
AMSExecutor wf =
#ifdef __ENABLE_MPI__
use_mpi?
AMSCreateDistributedExecutor(model_descr,
AMSDType::AMS_SINGLE,
resource,
(AMSPhysicFn)callBackSingle,
MPI_COMM_WORLD,
0,
1) :
#endif // __ENABLE_MPI__
AMSCreateExecutor(model_descr,
AMSDType::AMS_SINGLE,
resource,
(AMSPhysicFn)callBackSingle,
0,
1);

prob.ams_run(wf, resource, num_iterations, avg_elements);
} else {
Problem<double> prob(num_inputs, num_outputs);
AMSExecutor wf = AMSCreateExecutor(model_descr,
AMSDType::AMS_DOUBLE,
resource,
(AMSPhysicFn)callBackDouble,
process_id,
world_size);
AMSExecutor wf =
#ifdef __ENABLE_MPI__
use_mpi?
AMSCreateDistributedExecutor(model_descr,
AMSDType::AMS_DOUBLE,
resource,
(AMSPhysicFn)callBackDouble,
MPI_COMM_WORLD,
0,
1) :
#endif // __ENABLE_MPI__
AMSCreateExecutor(model_descr,
AMSDType::AMS_DOUBLE,
resource,
(AMSPhysicFn)callBackDouble,
0,
1);
prob.ams_run(wf, resource, num_iterations, avg_elements);
}

Expand Down

0 comments on commit 1fad494

Please sign in to comment.