Skip to content

Commit

Permalink
Percept: Remove use of hardcoded MPI_COMMP_WORLD
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable committed Sep 8, 2023
1 parent 004eee5 commit 28c9b96
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 22 deletions.
3 changes: 2 additions & 1 deletion packages/percept/src/adapt/HangingNodeAdapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <adapt/PredicateBasedElementAdapter.hpp>
#include <percept/SetOfEntities.hpp>
#include <stk_mesh/base/HashEntityAndEntityKey.hpp>
#include <percept/Percept_GlobalComm.hpp>

#include <percept/pooled_alloc.h>

Expand Down Expand Up @@ -367,7 +368,7 @@

if (0 && m_debug_print)
{
MPI_Barrier( MPI_COMM_WORLD );
MPI_Barrier( percept::get_global_comm() );
std::ostringstream ostr;
stk::mesh::EntityId id[]={12,13};
unsigned nid=1;
Expand Down
9 changes: 5 additions & 4 deletions packages/percept/src/adapt/main/MeshAdapt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <adapt/ExcludeWedgesNotConnectedToPyramids.hpp>

#include <percept/PerceptUtils.hpp>
#include <percept/Percept_GlobalComm.hpp>

#include <stk_util/registry/ProductRegistry.hpp>
#include <stk_util/util/string_case_compare.hpp>
Expand Down Expand Up @@ -61,8 +62,8 @@ namespace percept {
timer.stop();
stk::diag::printTimersTable(str, timer,
stk::diag::METRICS_ALL, false,
MPI_COMM_WORLD);
if (0 == stk::parallel_machine_rank(MPI_COMM_WORLD))
percept::get_global_comm());
if (0 == stk::parallel_machine_rank(percept::get_global_comm()))
{
std::cout << str.str() << std::endl;
}
Expand Down Expand Up @@ -565,7 +566,7 @@ namespace percept {
{

#if defined( STK_HAS_MPI )
MPI_Barrier( MPI_COMM_WORLD );
MPI_Barrier( percept::get_global_comm() );
#endif

// 3dm == opennurbs
Expand Down Expand Up @@ -2244,7 +2245,7 @@ void MeshAdapt::initialize_m2g_geometry(std::string input_geometry)

std::string m2gFile = input_geometry.substr(0,input_geometry.length()-3) + "m2g";

int THIS_PROC_NUM = stk::parallel_machine_rank( MPI_COMM_WORLD);
int THIS_PROC_NUM = stk::parallel_machine_rank( percept::get_global_comm());

stk::mesh::MetaData* md = eMeshP->get_fem_meta_data();
stk::mesh::BulkData* bd = eMeshP->get_bulk_data();
Expand Down
7 changes: 4 additions & 3 deletions packages/percept/src/percept/HistogramV.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <sstream>
#include <iomanip>
#include <stk_util/diag/PrintTable.hpp>
#include <percept/Percept_GlobalComm.hpp>

namespace percept
{
Expand Down Expand Up @@ -173,8 +174,8 @@ class HistogramV
#if defined( STK_HAS_MPI )
T minv=m_ranges[0];
T maxv=m_ranges[1];
my_all_reduce_min( MPI_COMM_WORLD , &minv , &m_ranges[0] , 1);
my_all_reduce_max( MPI_COMM_WORLD , &maxv , &m_ranges[1] , 1);
my_all_reduce_min( percept::get_global_comm() , &minv , &m_ranges[0] , 1);
my_all_reduce_max( percept::get_global_comm() , &maxv , &m_ranges[1] , 1);
#endif
}

Expand All @@ -196,7 +197,7 @@ class HistogramV
}
#if defined( STK_HAS_MPI )
std::vector<unsigned> cc(m_counts);
my_all_reduce_sum( MPI_COMM_WORLD , &cc[0] , &m_counts[0] , m_counts.size());
my_all_reduce_sum( percept::get_global_comm() , &cc[0] , &m_counts[0] , m_counts.size());
#endif

unsigned max_count = m_counts[0];
Expand Down
6 changes: 3 additions & 3 deletions packages/percept/src/percept/PerceptMesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <stdio.h>

#include <percept/Percept.hpp>

#include <percept/Percept_GlobalComm.hpp>
#include <percept/PerceptMesh.hpp>
#include <percept/PEnums.hpp>
#include <percept/Util.hpp>
Expand Down Expand Up @@ -468,7 +468,7 @@
checkStateSpec("print_info", m_isOpen, m_isInitialized);
PerceptMesh& eMesh = *this;

const unsigned p_rank = stk::parallel_machine_rank( MPI_COMM_WORLD );
const unsigned p_rank = stk::parallel_machine_rank( percept::get_global_comm() );

stream
<< ""<<NL<<""<<NL<< "P[" << p_rank << "] ======================================================== "<<NL
Expand Down Expand Up @@ -3506,7 +3506,7 @@

bool diff = false;

const unsigned p_rank = stk::parallel_machine_rank( MPI_COMM_WORLD );
const unsigned p_rank = stk::parallel_machine_rank( percept::get_global_comm() );

if (print)
{
Expand Down
7 changes: 4 additions & 3 deletions packages/percept/src/percept/PerceptMesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@

// ctor constructor
/// Create a Mesh object that owns its constituent MetaData and BulkData (which are created by this object)
//PerceptMesh( stk::ParallelMachine comm = MPI_COMM_WORLD );
PerceptMesh(size_t spatialDimension = 3u, stk::ParallelMachine comm = MPI_COMM_WORLD);
//PerceptMesh( stk::ParallelMachine comm = MPI_COMM_WORLD ); // CHECK: ALLOW MPI_COMM_WORLD
PerceptMesh(size_t spatialDimension = 3u, stk::ParallelMachine comm = MPI_COMM_WORLD); // CHECK: ALLOW MPI_COMM_WORLD

/// Create a Mesh object that doesn't own its constituent MetaData and BulkData, pointers to which are adopted
/// by this constructor.
Expand Down Expand Up @@ -710,7 +710,8 @@


~PerceptMesh() ;
void init( stk::ParallelMachine comm = MPI_COMM_WORLD, bool no_alloc=false ); // FIXME - make private
// FIXME - make private
void init( stk::ParallelMachine comm = MPI_COMM_WORLD, bool no_alloc=false ); // CHECK: ALLOW MPI_COMM_WORLD
void destroy(); // FIXME - make private

const stk::mesh::Part*
Expand Down
3 changes: 2 additions & 1 deletion packages/percept/src/percept/PerceptUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <percept/PerceptUtils.hpp>
#include <percept/mesh/geometry/volume/VolumeUtil.hpp>
#include <percept/Percept_GlobalComm.hpp>

#include <stk_mesh/base/Field.hpp>
#include <stk_mesh/base/MetaData.hpp>
Expand Down Expand Up @@ -75,7 +76,7 @@ void printTimersTableStructured() {
rootTimerStructured().stop();
stk::diag::printTimersTable(str, rootTimerStructured(), stk::diag::METRICS_ALL, false);

if (0==stk::parallel_machine_rank(MPI_COMM_WORLD))
if (0==stk::parallel_machine_rank(percept::get_global_comm()))
std::cout << str.str() << std::endl;
}

Expand Down
33 changes: 33 additions & 0 deletions packages/percept/src/percept/Percept_GlobalComm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2002 - 2008, 2010, 2011 National Technology Engineering
// Solutions of Sandia, LLC (NTESS). Under the terms of Contract
// DE-NA0003525 with NTESS, the U.S. Government retains certain rights
// in this software.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.


#ifndef percept_global_comm_hpp
#define percept_global_comm_hpp

#include <mpi.h>
#include <mutex>

namespace percept {

static std::mutex mpi_mutex;
static MPI_Comm Global_MPI_Comm = MPI_COMM_WORLD; // CHECK: ALLOW MPI_COMM_WORLD

inline void initialize_global_comm(MPI_Comm comm) {
std::lock_guard<std::mutex> guard(mpi_mutex);
Global_MPI_Comm = comm;
}

inline MPI_Comm get_global_comm() {
std::lock_guard<std::mutex> guard(mpi_mutex);
return Global_MPI_Comm;
}

}

#endif /* percept_global_comm_hpp */
3 changes: 2 additions & 1 deletion packages/percept/src/percept/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <sys/malloc.h>
#endif

#include <percept/Percept_GlobalComm.hpp>
#include <percept/Util.hpp>
#include <percept/PerceptMesh.hpp>
#include <sys/resource.h>
Expand Down Expand Up @@ -244,7 +245,7 @@ namespace shards {
void Util::debug_stop()
{
int pRank = 0;
MPI_Comm_rank(MPI_COMM_WORLD, &pRank);
MPI_Comm_rank(percept::get_global_comm(), &pRank);

std::cout << "P[" << pRank << "] in debug_stop\n" << PerceptMesh::demangled_stacktrace(30) << std::endl;
std::cerr << "P[" << pRank << "] in debug_stop\n" << PerceptMesh::demangled_stacktrace(30) << std::endl;
Expand Down
11 changes: 6 additions & 5 deletions packages/percept/src/percept/mesh/mod/smoother/MeshSmoother.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "MeshSmoother.hpp"
#include <percept/mesh/mod/smoother/SmootherMetric.hpp>
#include <percept/mesh/geometry/kernel/GeometryKernel.hpp>
#include <percept/Percept_GlobalComm.hpp>
#include <array>

#define DEBUG_PRINT 0
Expand Down Expand Up @@ -170,14 +171,14 @@ namespace std {
GenericAlgorithm_parallel_count_invalid_elements<MeshType> ga(eMesh);
ga.run();

stk::all_reduce( MPI_COMM_WORLD, stk::ReduceSum<1>( &ga.num_invalid ) );
stk::all_reduce( percept::get_global_comm(), stk::ReduceSum<1>( &ga.num_invalid ) );

if (ga.get_mesh_diagnostics)
{
stk::all_reduce( MPI_COMM_WORLD, stk::ReduceMin<1>( &ga.detA_min ) );
stk::all_reduce( MPI_COMM_WORLD, stk::ReduceMin<1>( &ga.detW_min ) );
stk::all_reduce( MPI_COMM_WORLD, stk::ReduceMax<1>( &ga.shapeA_max ) );
stk::all_reduce( MPI_COMM_WORLD, stk::ReduceMax<1>( &ga.shapeW_max ) );
stk::all_reduce( percept::get_global_comm(), stk::ReduceMin<1>( &ga.detA_min ) );
stk::all_reduce( percept::get_global_comm(), stk::ReduceMin<1>( &ga.detW_min ) );
stk::all_reduce( percept::get_global_comm(), stk::ReduceMax<1>( &ga.shapeA_max ) );
stk::all_reduce( percept::get_global_comm(), stk::ReduceMax<1>( &ga.shapeW_max ) );
if (eMesh->get_rank() == 0)
{
std::cout << "P[0] detA_min= " << ga.detA_min << " detW_min= " << ga.detW_min
Expand Down
4 changes: 3 additions & 1 deletion packages/percept/src/percept/rfgen/RFGen_KLSolver.C
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@

#include "Epetra_CrsMatrix.h"

#include <percept/Percept_GlobalComm.hpp>

namespace RFGen
{

KLSolver::KLSolver(const unsigned spatialDim)
:
Epetra_Operator(),
mpiComm_(MPI_COMM_WORLD),
mpiComm_(percept::get_global_comm()),
epetraComm_(Epetra_MpiComm(mpiComm_)),
localNumElem_(0),
globalNumElem_(0),
Expand Down

0 comments on commit 28c9b96

Please sign in to comment.