Skip to content

Commit

Permalink
Addressed comments from @koparasy in #45
Browse files Browse the repository at this point in the history
Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier committed Feb 14, 2024
1 parent 25156c9 commit 3166b0a
Showing 1 changed file with 26 additions and 30 deletions.
56 changes: 26 additions & 30 deletions src/AMSlib/wf/basedb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,8 +835,7 @@ struct AMSMsgHeader {
uint8_t new_dtype = data_blob[current_offset];
current_offset += sizeof(uint8_t);
// MPI rank (should be 2 bytes)
uint16_t new_mpirank;
std::memcpy(&new_mpirank, data_blob + current_offset, sizeof(uint16_t));
uint16_t new_mpirank = (reinterpret_cast<uint16_t*>(data_blob + current_offset))[0];
current_offset += sizeof(uint16_t);
// Num elem (should be 4 bytes)
uint32_t new_num_elem;
Expand Down Expand Up @@ -882,6 +881,7 @@ class AMSMessage
*/
AMSMessage()
: _id(0),
_rank(0),
_num_elements(0),
_input_dim(0),
_output_dim(0),
Expand All @@ -897,6 +897,7 @@ class AMSMessage
void swap(const AMSMessage& other)
{
_id = other._id;
_rank = other._rank;
_num_elements = other._num_elements;
_input_dim = other._input_dim;
_output_dim = other._output_dim;
Expand All @@ -918,6 +919,7 @@ class AMSMessage
const std::vector<TypeValue*>& inputs,
const std::vector<TypeValue*>& outputs)
: _id(id),
_rank(0),
_num_elements(num_elements),
_input_dim(inputs.size()),
_output_dim(outputs.size()),
Expand Down Expand Up @@ -1480,10 +1482,11 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
std::promise<RMQConnectionStatus> _error_connection;
std::future<RMQConnectionStatus> _ftr_error;

public:
std::mutex ptr_mutex;
std::vector<AMSMessage> data_ptrs;
std::mutex _mutex;
/** @brief Messages that have not been successfully acknowledged */
std::vector<AMSMessage> _messages;

public:
/**
* @brief Constructor
* @param[in] loop Event Loop
Expand Down Expand Up @@ -1519,7 +1522,10 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
*/
void publish(AMSMessage&& msg)
{
data_ptrs.push_back(msg);
{
const std::lock_guard<std::mutex> lock(_mutex);
_messages.push_back(msg);
}
if (_rchannel) {
// publish a message via the reliable-channel
// onAck : message has been explicitly ack'ed by RabbitMQ
Expand All @@ -1532,7 +1538,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
&_nb_msg_ack = _nb_msg_ack,
id = msg.id(),
data = msg.data(),
&data_ptrs = this->data_ptrs]() mutable {
&_messages = this->_messages]() mutable {
DBG(RMQPublisherHandler,
"[rank=%d] message #%d (Addr:%p) got acknowledged successfully "
"by "
Expand All @@ -1541,7 +1547,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
_rank,
id,
data)
this->free_ams_message(id, data_ptrs);
this->free_ams_message(id, _messages);
_nb_msg_ack++;
})
.onNack([this, id = msg.id(), data = msg.data()]() mutable {
Expand Down Expand Up @@ -1616,12 +1622,12 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
* @brief Return the messages that have NOT been acknowledged by the RabbitMQ server.
* @return A vector of AMSMessage
*/
std::vector<AMSMessage>& internal_msg_buffer() { return data_ptrs; }
std::vector<AMSMessage>& internal_msg_buffer() { return _messages; }

/**
* @brief Free AMSMessages held by the handler
*/
void cleanup() { free_all_messages(data_ptrs); }
void cleanup() { free_all_messages(_messages); }

/**
* @brief Total number of messages sent
Expand All @@ -1648,7 +1654,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
if (++tries > 10) break;
std::this_thread::sleep_for(std::chrono::milliseconds(50 * tries));
}
free_all_messages(data_ptrs);
free_all_messages(_messages);
}

private:
Expand Down Expand Up @@ -1833,7 +1839,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
*/
void free_ams_message(int msg_id, std::vector<AMSMessage>& buf)
{
const std::lock_guard<std::mutex> lock(ptr_mutex);
const std::lock_guard<std::mutex> lock(_mutex);
auto it =
std::find_if(buf.begin(), buf.end(), [&msg_id](const AMSMessage& obj) {
return obj.id() == msg_id;
Expand All @@ -1855,16 +1861,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
msg.data());
}
DBG(RMQPublisherHandler, "Deallocated msg #%d (%p)", msg.id(), msg.data())
it = std::remove_if(buf.begin(),
buf.end(),
[&msg_id](const AMSMessage& obj) {
return obj.id() == msg_id;
});
CWARNING(RMQPublisherHandler,
it == buf.end(),
"Failed to erase %p from buffer",
msg.data());
buf.erase(it, buf.end());
buf.erase(it);
}

/**
Expand All @@ -1873,8 +1870,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
*/
void free_all_messages(std::vector<AMSMessage>& buffer)
{
const std::lock_guard<std::mutex> lock(ptr_mutex);
// auto& urm = umpire::ResourceManager::getInstance();
const std::lock_guard<std::mutex> lock(_mutex);
auto& rm = ams::ResourceManager::getInstance();
for (auto& dp : buffer) {
DBG(RMQPublisherHandler, "deallocate msg #%d (%p)", dp.id(), dp.data())
Expand All @@ -1887,7 +1883,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
dp.data());
}
}
buffer.erase(buffer.begin(), buffer.end());
buffer.clear();
}

}; // class RMQPublisherHandler
Expand Down Expand Up @@ -2006,14 +2002,14 @@ class RMQPublisher
bool connection_valid() { return _handler->connection_valid(); }

/**
* @brief Return the messages that have not been acknolwged.
* @brief Return the messages that have not been acknowledged.
* It does not mean they have not been delivered but the
* acknowledgements have not arrived yet.
* @return A map of messages (ID, data)
* @return A vector of AMSMessage
*/
std::vector<AMSMessage> get_buffer_msgs()
std::vector<AMSMessage>& get_buffer_msgs()
{
return std::move(_handler->internal_msg_buffer());
return _handler->internal_msg_buffer();
}

/**
Expand Down Expand Up @@ -2314,7 +2310,7 @@ class RabbitMQDB final : public BaseDB<TypeValue>
return a.id() < b.id();
}));

WARNING(RMQPublisher,
DBG(RMQPublisher,
"[rank=%d] we have %d buffered messages that will get re-send "
"(starting from msg #%d).",
_rank,
Expand Down

0 comments on commit 3166b0a

Please sign in to comment.