Skip to content

Commit

Permalink
Removed MPI communicator from DB store() based on PR discussions
Browse files Browse the repository at this point in the history
added flags to updateModel to make it conditional

Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier committed Aug 7, 2024
1 parent c5005c5 commit 1a30e15
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 83 deletions.
88 changes: 22 additions & 66 deletions src/AMSlib/wf/basedb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,12 @@ class BaseDB
virtual void store(size_t num_elements,
std::vector<double*>& inputs,
std::vector<double*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) = 0;


virtual void store(size_t num_elements,
std::vector<float*>& inputs,
std::vector<float*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) = 0;

/** @brief Accessor for the identifier held by this database object. */
uint64_t getId() const
{
  return id;
}
Expand Down Expand Up @@ -287,9 +281,6 @@ class csvDB final : public FileDB
virtual void store(size_t num_elements,
std::vector<float*>& inputs,
std::vector<float*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) override
{
CFATAL(CSV,
Expand All @@ -302,9 +293,6 @@ class csvDB final : public FileDB
virtual void store(size_t num_elements,
std::vector<double*>& inputs,
std::vector<double*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) override
{

Expand Down Expand Up @@ -467,9 +455,6 @@ class hdf5DB final : public FileDB
void store(size_t num_elements,
std::vector<float*>& inputs,
std::vector<float*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) override;


Expand All @@ -487,9 +472,6 @@ class hdf5DB final : public FileDB
void store(size_t num_elements,
std::vector<double*>& inputs,
std::vector<double*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) override;

/**
Expand Down Expand Up @@ -614,9 +596,6 @@ class RedisDB : public BaseDB<TypeValue>
void store(size_t num_elements,
std::vector<TypeValue*>& inputs,
std::vector<TypeValue*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool predicate = nullptr) override
{

Expand Down Expand Up @@ -1605,11 +1584,9 @@ class RMQInterface
std::thread _consumer_thread;
/** @brief True if connected to RabbitMQ */
bool connected;
/** @brief True if _rId is set to the correct MPI rank */
bool set_rank;

public:
RMQInterface() : connected(false), set_rank(false), _rId(0) {}
RMQInterface() : connected(false), _rId(0) {}

/**
* @brief Connect to a RabbitMQ server
Expand Down Expand Up @@ -1643,28 +1620,29 @@ class RMQInterface
*/
bool isConnected() const
{
  return connected;
}

/**
 * @brief Assign the internal identifier of this interface.
 * @details The identifier is usually the MPI rank.
 * @param[in] id Identifier to store
 */
void setId(uint64_t id)
{
  _rId = id;
}

/**
* @brief Try to restart the RabbitMQ publisher (restart the thread managing messages publishing)
*/
void restartPublisher();

/**
* @brief Publish the given inputs and outputs, packed into an AMSMessage
* @param[in] comm MPI communicator which can be used to determine the rank of the sender
* @param[in] domain_name The name of the domain
* @param[in] num_elements The number of elements for inputs/outputs
* @param[in] inputs A vector containing arrays of inputs, each array has num_elements elements
* @param[in] outputs A vector containing arrays of outputs, each array has num_elements elements
*/
template <typename TypeValue>
void publish(
#ifdef __ENABLE_MPI__
MPI_Comm comm,
#endif
std::string& domain_name,
size_t num_elements,
std::vector<TypeValue*>& inputs,
std::vector<TypeValue*>& outputs)
void publish(std::string& domain_name,
size_t num_elements,
std::vector<TypeValue*>& inputs,
std::vector<TypeValue*>& outputs)
{
DBG(RMQInterface,
"[tag=%d] stores %ld elements of input/output "
Expand All @@ -1674,24 +1652,6 @@ class RMQInterface
inputs.size(),
outputs.size())

#ifdef __ENABLE_MPI__
if (set_rank) {
int flag = 0;
int rank = 0;
MPI_Initialized(&flag);
if (flag && comm != MPI_COMM_NULL) {
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
_rId = rank;
set_rank = true;
} else {
WARNING(RMQInterface,
"MPI is not initialized while __ENABLE_MPI__ is defined. "
"Message "
"will not have MPI rank.")
}
}
#endif

AMSMessage msg(_msg_tag, _rId, domain_name, num_elements, inputs, outputs);

if (!_publisher->connectionValid()) {
Expand Down Expand Up @@ -1770,6 +1730,17 @@ class RabbitMQDB final : public BaseDB
/**
 * @brief Build a RabbitMQ database bound to an existing RMQInterface.
 * @param[in] interface RMQInterface used by this database; its ID is set to @p id
 * @param[in] domain Domain name stored in appDomain
 * @param[in] id Identifier forwarded to BaseDB and to the interface
 */
RabbitMQDB(RMQInterface& interface, std::string& domain, uint64_t id)
: BaseDB(id), appDomain(domain), interface(interface)
{
/* Manually set the interface ID (usually the MPI rank) here: when
 * RMQInterface was statically initialized, MPI was not necessarily
 * initialized and ready, so the distributed ID is provided afterward.
 *
 * Note: this ID is encoded into AMSMessage, but logging uses a
 * randomly generated ID to stay consistent over time (some logging
 * could happen before setId is called).
 */
interface.setId(id);
}

/**
Expand All @@ -1780,44 +1751,29 @@ class RabbitMQDB final : public BaseDB
* @param[in] inputs Vector of 1-D vectors containing the inputs to be sent
* @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains
* 'num_elements' values to be sent
* @param[in] comm MPI communicator which can be used to determine the rank of the sender
* @param[in] predicate (NOT SUPPORTED YET) Series of predicate
*/
PERFFASPECT()
void store(size_t num_elements,
std::vector<double*>& inputs,
std::vector<double*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) override
{
// Predicates are not supported by this backend: a non-null predicate is
// reported through CFATAL (presumably fatal -- confirm against the macro).
CFATAL(RMQDB,
predicate != nullptr,
"RMQ database does not support storing uq-predicates")
// publish() receives the communicator only when MPI support is compiled in.
#ifdef __ENABLE_MPI__
interface.publish(comm, appDomain, num_elements, inputs, outputs);
#else
interface.publish(appDomain, num_elements, inputs, outputs);
#endif
}

/**
 * @brief Single-precision overload of store(): forwards the data to
 * interface.publish() under the domain held in appDomain.
 * @param[in] num_elements Number of elements, passed through to publish()
 * @param[in] inputs Input arrays, passed through to publish()
 * @param[in] outputs Output arrays, passed through to publish()
 * @param[in] comm (MPI builds only) Communicator passed through to publish()
 * @param[in] predicate Must be nullptr; predicates are not supported here
 */
void store(size_t num_elements,
std::vector<float*>& inputs,
std::vector<float*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm = MPI_COMM_NULL,
#endif
bool* predicate = nullptr) override
{
// A non-null predicate is reported through CFATAL (presumably fatal --
// confirm against the macro definition).
CFATAL(RMQDB,
predicate != nullptr,
"RMQ database does not support storing uq-predicates")
// publish() receives the communicator only when MPI support is compiled in.
#ifdef __ENABLE_MPI__
interface.publish(comm, appDomain, num_elements, inputs, outputs);
#else
interface.publish(appDomain, num_elements, inputs, outputs);
#endif
}

/**
Expand Down
6 changes: 0 additions & 6 deletions src/AMSlib/wf/hdf5db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,6 @@ hdf5DB::~hdf5DB()
void hdf5DB::store(size_t num_elements,
std::vector<float*>& inputs,
std::vector<float*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm,
#endif
bool* predicate)
{
if (HDType == -1) {
Expand All @@ -234,9 +231,6 @@ void hdf5DB::store(size_t num_elements,
void hdf5DB::store(size_t num_elements,
std::vector<double*>& inputs,
std::vector<double*>& outputs,
#ifdef __ENABLE_MPI__
MPI_Comm comm,
#endif
bool* predicate)
{
if (HDType == -1) {
Expand Down
18 changes: 8 additions & 10 deletions src/AMSlib/wf/workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,7 @@ class AMSWorkflow
bool *hPredicate = nullptr;

if (appDataLoc == AMSResourceType::AMS_HOST) {
#ifdef __ENABLE_MPI__
return DB->store(num_elements, inputs, outputs, comm, predicate);
#else
return DB->store(num_elements, inputs, outputs, predicate);
#endif
}

for (int i = 0; i < inputs.size(); i++) {
Expand All @@ -148,11 +144,7 @@ class AMSWorkflow
}

// Store to database
#ifdef __ENABLE_MPI__
DB->store(num_elements, hInputs, hOutputs, comm, hPredicate);
#else
DB->store(num_elements, hInputs, hOutputs, hPredicate);
#endif
rm.deallocate(hInputs, AMSResourceType::AMS_HOST);
rm.deallocate(hOutputs, AMSResourceType::AMS_HOST);
if (predicate) rm.deallocate(hPredicate, AMSResourceType::AMS_HOST);
Expand All @@ -173,9 +165,15 @@ class AMSWorkflow
store(num_elements, mInputs, outputs, predicate);
}

bool updateModel()
/** \brief Check if we can perform a surrogate model update.
* AMS can update surrogate model only when all MPI ranks have received
* the latest model from RabbitMQ.
* @param[in] performUpdate Perform the model update if True
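* @return False when there is no database or performUpdate is false; otherwise the result of the model update check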
*/
bool updateModel(bool performUpdate = false)
{
if (!DB) return false;
if (!DB || !performUpdate) return false;
bool local = DB->updateModel();
#ifdef __ENABLE_MPI__
bool global = false;
Expand Down
10 changes: 10 additions & 0 deletions tools/rmq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ RabbitMQ. They are useful to test and interact with AMSlib. Each script
is completely standalone and does not require the AMS Python package,
however they require `pika` and `numpy`.

## Generate TLS certificate

To use most of the tools related to RabbitMQ you might need to provide a TLS certificate.
To obtain one, you can download the certificate presented by the server with OpenSSL, for example:

```bash
openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > rmq-tls.crt
```
where `REMOTE_HOST` is the hostname of the RabbitMQ server and `REMOTE_PORT` is the port.

## Consume messages from AMSlib

To receive, or consume, messages emitted by AMSlib you can use `recv_binary.py`:
Expand Down
3 changes: 2 additions & 1 deletion tools/rmq/recv_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,13 @@ def callback(ch, method, properties, body, args = None):
break
domain_name, data_input, data_output = parse_data(stream, header_info)
num_element = header_info["num_element"]
mpirank = header_info["mpirank"]
# total size of byte we read for that message
chunk_size = header_info["header_size"] + header_info["domain_size"] + header_info["data_size"]

print(
f" [{nbmsg}/{i}] Received from exchange=\"{method.exchange}\" routing_key=\"{method.routing_key}\"\n"
f" > data ({domain_name}) : {len(stream)/(1024*1024)} MB / {num_element} elements\n")
f" > [r{mpirank}] ({domain_name}) : {len(stream)/(1024*1024)} MB / {num_element} elements\n")

if data_input.size > 0:
all_messages.append(data_input)
Expand Down

0 comments on commit 1a30e15

Please sign in to comment.