diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp index 8310896..99d33ec 100644 --- a/src/AMSlib/wf/workflow.hpp +++ b/src/AMSlib/wf/workflow.hpp @@ -73,7 +73,7 @@ class AMSWorkflow std::shared_ptr DB; /** @brief The process id. For MPI runs this is the rank */ - const int rId; + int rId; /** @brief The total number of processes participating in the simulation * (world_size for MPI) */ @@ -247,7 +247,17 @@ class AMSWorkflow void set_physics(AMSPhysicFn _AppCall) { AppCall = _AppCall; } #ifdef __ENABLE_MPI__ - void set_communicator(MPI_Comm communicator) { comm = communicator; } + void set_communicator(MPI_Comm communicator) { + comm = communicator; + + if (comm == MPI_COMM_NULL) { + rId = 0; + wSize = 1; + } else { + MPI_Comm_rank(comm, &rId); + MPI_Comm_size(comm, &wSize); + } + } #endif void set_exec_policy(AMSExecPolicy policy) { ePolicy = policy; } diff --git a/tests/AMSlib/ams_ete.cpp b/tests/AMSlib/ams_ete.cpp index 7c2578d..807730d 100644 --- a/tests/AMSlib/ams_ete.cpp +++ b/tests/AMSlib/ams_ete.cpp @@ -177,36 +177,49 @@ int main(int argc, char **argv) AMSCAbstrModel model_descr = AMSRegisterAbstractModel( "test", uq_policy, threshold, model_path, nullptr, "test", -1); - int process_id = 0; - int world_size = 1; - -#ifdef __ENABLE_MPI__ - if (use_mpi) { - MPI_Comm_rank(MPI_COMM_WORLD, &process_id); - MPI_Comm_size(MPI_COMM_WORLD, &world_size); - } -#endif // __ENABLE_MPI__ - - if (data_type == AMSDType::AMS_SINGLE) { Problem prob(num_inputs, num_outputs); - AMSExecutor wf = AMSCreateExecutor(model_descr, - AMSDType::AMS_SINGLE, - resource, - (AMSPhysicFn)callBackSingle, - process_id, - world_size); +// TODO: I do not think we should pass the process id and world size here. +// It should be obtained from the communicator passed. + AMSExecutor wf = +#ifdef __ENABLE_MPI__ + use_mpi? + AMSCreateDistributedExecutor(model_descr, + AMSDType::AMS_SINGLE, + resource, + (AMSPhysicFn)callBackSingle, + MPI_COMM_WORLD, + 0, + 1) : +#endif // __ENABLE_MPI__ + AMSCreateExecutor(model_descr, + AMSDType::AMS_SINGLE, + resource, + (AMSPhysicFn)callBackSingle, + 0, + 1); prob.ams_run(wf, resource, num_iterations, avg_elements); } else { Problem prob(num_inputs, num_outputs); - AMSExecutor wf = AMSCreateExecutor(model_descr, - AMSDType::AMS_DOUBLE, - resource, - (AMSPhysicFn)callBackDouble, - process_id, - world_size); + AMSExecutor wf = +#ifdef __ENABLE_MPI__ + use_mpi? + AMSCreateDistributedExecutor(model_descr, + AMSDType::AMS_DOUBLE, + resource, + (AMSPhysicFn)callBackDouble, + MPI_COMM_WORLD, + 0, + 1) : +#endif // __ENABLE_MPI__ + AMSCreateExecutor(model_descr, + AMSDType::AMS_DOUBLE, + resource, + (AMSPhysicFn)callBackDouble, + 0, + 1); prob.ams_run(wf, resource, num_iterations, avg_elements); }