From 68cfa91d5ce1d0c62d404afc370c8855c3192554 Mon Sep 17 00:00:00 2001 From: Loic Pottier Date: Fri, 22 Dec 2023 17:51:36 -0800 Subject: [PATCH] Added fault-tolerance features to RMQ broker (#34). - Broker will restart if underlying connection is faulty - Broker will send unacknowledged from previous connection (this could result in messages being received twice or more if errors happen with the wrong timing) - Fixed a Static Initialization Order Fiasco (actually destruction fiasco) in ResourceManager and AMS DB Signed-off-by: Loic Pottier --- examples/app/eos.hpp | 2 + src/AMSlib/AMS.cpp | 39 +- src/AMSlib/ml/hdcache.hpp | 19 +- src/AMSlib/ml/uq.hpp | 5 +- src/AMSlib/wf/basedb.hpp | 818 ++++++++++++++++++++--------- src/AMSlib/wf/cuda/utilities.cuh | 44 +- src/AMSlib/wf/data_handler.hpp | 6 +- src/AMSlib/wf/debug.h | 9 +- src/AMSlib/wf/redist_load.hpp | 31 +- src/AMSlib/wf/resource_manager.cpp | 4 - src/AMSlib/wf/resource_manager.hpp | 39 +- src/AMSlib/wf/workflow.hpp | 38 +- tests/AMSlib/ams_allocate.cpp | 24 +- tests/AMSlib/cpu_packing_test.cpp | 63 +-- tests/AMSlib/gpu_packing_test.cpp | 37 +- tests/AMSlib/lb.cpp | 3 +- tests/AMSlib/test_hdcache.cpp | 25 +- tests/AMSlib/torch_model.cpp | 12 +- 18 files changed, 796 insertions(+), 422 deletions(-) diff --git a/examples/app/eos.hpp b/examples/app/eos.hpp index ad0bf12e..7872c003 100644 --- a/examples/app/eos.hpp +++ b/examples/app/eos.hpp @@ -29,6 +29,8 @@ class EOS outputs[3]); } + virtual ~EOS() = default; + virtual void Eval(const int length, const FPType *density, const FPType *energy, diff --git a/src/AMSlib/AMS.cpp b/src/AMSlib/AMS.cpp index 8b8308cf..c7f6bf60 100644 --- a/src/AMSlib/AMS.cpp +++ b/src/AMSlib/AMS.cpp @@ -10,7 +10,9 @@ #include #include "wf/resource_manager.hpp" +#include "wf/basedb.hpp" #include "wf/workflow.hpp" +#include "wf/debug.h" struct AMSWrap { std::vector> executors; @@ -39,15 +41,11 @@ void _AMSExecute(AMSExecutor executor, int outputDim, MPI_Comm Comm = 0) { - static std::once_flag flag; - std::call_once(flag, [&]() { ams::ResourceManager::init(); }); - uint64_t index = reinterpret_cast(executor); - if (index >= _amsWrap.executors.size()) throw std::runtime_error("AMS Executor identifier does not exist\n"); - auto currExec = _amsWrap.executors[index]; + if (currExec.first == AMSDType::Double) { ams::AMSWorkflow *dWF = reinterpret_cast *>(currExec.second); @@ -80,6 +78,12 @@ extern "C" { AMSExecutor AMSCreateExecutor(const AMSConfig config) { + static std::once_flag flag; + std::call_once(flag, [&]() { + auto& rm = ams::ResourceManager::getInstance(); + rm.init(); + }); + if (config.dType == Double) { ams::AMSWorkflow *dWF = new ams::AMSWorkflow(config.cBack, @@ -94,7 +98,6 @@ AMSExecutor AMSCreateExecutor(const AMSConfig config) config.pId, config.wSize, config.ePolicy); - _amsWrap.executors.push_back( std::make_pair(config.dType, static_cast(dWF))); return reinterpret_cast(_amsWrap.executors.size() - 1L); @@ -114,7 +117,6 @@ AMSExecutor AMSCreateExecutor(const AMSConfig config) config.ePolicy); _amsWrap.executors.push_back( std::make_pair(config.dType, static_cast(sWF))); - return reinterpret_cast(_amsWrap.executors.size() - 1L); } else { throw std::invalid_argument("Data type is not supported by AMSLib!"); @@ -139,6 +141,22 @@ void AMSExecute(AMSExecutor executor, outputDim); } +void AMSDestroyExecutor(AMSExecutor executor) { + uint64_t index = reinterpret_cast(executor); + if (index >= _amsWrap.executors.size()) + throw std::runtime_error("AMS Executor identifier does not exist\n"); + auto currExec = _amsWrap.executors[index]; + + if (currExec.first == AMSDType::Double) { + delete reinterpret_cast *>(currExec.second); + } else if (currExec.first == AMSDType::Single) { + delete reinterpret_cast *>(currExec.second); + } else { + throw std::invalid_argument("Data type is not supported by AMSLib!"); + return; + } +} + #ifdef __ENABLE_MPI__ void AMSDistributedExecute(AMSExecutor executor, MPI_Comm Comm, @@ -160,15 +178,16 @@ void AMSDistributedExecute(AMSExecutor executor, } #endif - const char *AMSGetAllocatorName(AMSResourceType device) { - return std::move(ams::ResourceManager::getAllocatorName(device)).c_str(); + auto& rm = ams::ResourceManager::getInstance(); + return std::move(rm.getAllocatorName(device)).c_str(); } void AMSSetAllocator(AMSResourceType resource, const char *alloc_name) { - ams::ResourceManager::setAllocator(std::string(alloc_name), resource); + auto& rm = ams::ResourceManager::getInstance(); + rm.setAllocator(std::string(alloc_name), resource); } #ifdef __cplusplus diff --git a/src/AMSlib/ml/hdcache.hpp b/src/AMSlib/ml/hdcache.hpp index eb6f2264..4cd81e0c 100644 --- a/src/AMSlib/ml/hdcache.hpp +++ b/src/AMSlib/ml/hdcache.hpp @@ -309,7 +309,8 @@ class HDCache TypeValue *lin_data = data_handler::linearize_features(cache_location, ndata, inputs); _add(ndata, lin_data); - ams::ResourceManager::deallocate(lin_data, cache_location); + auto& rm = ams::ResourceManager::getInstance(); + rm.deallocate(lin_data, cache_location); } //! ----------------------------------------------------------------------- @@ -336,7 +337,8 @@ class HDCache TypeValue *lin_data = data_handler::linearize_features(cache_location, ndata, inputs); _train(ndata, lin_data); - ams::ResourceManager::deallocate(lin_data, cache_location); + auto& rm = ams::ResourceManager::getInstance(); + rm.deallocate(lin_data, cache_location); } //! ------------------------------------------------------------------------ @@ -395,7 +397,8 @@ class HDCache TypeValue *lin_data = data_handler::linearize_features(cache_location, ndata, inputs); _evaluate(ndata, lin_data, is_acceptable); - ams::ResourceManager::deallocate(lin_data, cache_location); + auto& rm = ams::ResourceManager::getInstance(); + rm.deallocate(lin_data, cache_location); DBG(UQModule, "Done with evalution of uq"); } @@ -471,12 +474,12 @@ class HDCache const size_t knbrs = static_cast(m_knbrs); static const TypeValue ook = 1.0 / TypeValue(knbrs); - + auto& rm = ams::ResourceManager::getInstance(); TypeValue *kdists = - ams::ResourceManager::allocate(ndata * knbrs, + rm.allocate(ndata * knbrs, cache_location); TypeIndex *kidxs = - ams::ResourceManager::allocate(ndata * knbrs, + rm.allocate(ndata * knbrs, cache_location); // query faiss @@ -523,8 +526,8 @@ class HDCache kdists, is_acceptable, ndata, knbrs, acceptable_error); } - ams::ResourceManager::deallocate(kdists, cache_location); - ams::ResourceManager::deallocate(kidxs, cache_location); + rm.deallocate(kdists, cache_location); + rm.deallocate(kidxs, cache_location); } //! evaluate cache uncertainty when (data type != TypeValue) diff --git a/src/AMSlib/ml/uq.hpp b/src/AMSlib/ml/uq.hpp index dcd1e98a..cf393229 100644 --- a/src/AMSlib/ml/uq.hpp +++ b/src/AMSlib/ml/uq.hpp @@ -74,9 +74,10 @@ class UQ const size_t ndims = outputs.size(); std::vector outputs_stdev(ndims); // TODO: Enable device-side allocation and predicate calculation. + auto& rm = ams::ResourceManager::getInstance(); for (int dim = 0; dim < ndims; ++dim) outputs_stdev[dim] = - ams::ResourceManager::allocate(totalElements, + rm.allocate(totalElements, AMSResourceType::HOST); CALIPER(CALI_MARK_BEGIN("SURROGATE");) @@ -110,7 +111,7 @@ class UQ } for (int dim = 0; dim < ndims; ++dim) - ams::ResourceManager::deallocate(outputs_stdev[dim], + rm.deallocate(outputs_stdev[dim], AMSResourceType::HOST); CALIPER(CALI_MARK_END("DELTAUQ");) } else if (uqPolicy == AMSUQPolicy::FAISS_Mean || diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp index 7f72107a..5540dd8a 100644 --- a/src/AMSlib/wf/basedb.hpp +++ b/src/AMSlib/wf/basedb.hpp @@ -34,10 +34,7 @@ namespace fs = std::experimental::filesystem; #include #include -// TODO: We should comment out "using" in header files as -// it propagates to every other file including this file #warning Redis is currently not supported/tested -using namespace sw::redis; #endif @@ -70,7 +67,6 @@ using namespace sw::redis; #include #include #include -#include #endif // __ENABLE_RMQ__ @@ -89,6 +85,9 @@ class BaseDB BaseDB& operator=(const BaseDB&) = delete; BaseDB(uint64_t id) : id(id) {} + + virtual void close() {} + virtual ~BaseDB() {} /** @@ -539,7 +538,7 @@ class RedisDB : public BaseDB { const std::string _fn; // path to the file storing the DB access config uint64_t _dbid; - Redis* _redis; + sw::redis::Redis* _redis; uint64_t keyId; public: @@ -558,8 +557,8 @@ class RedisDB : public BaseDB _dbid = reinterpret_cast(this); auto connection_info = read_json(fn); - ConnectionOptions connection_options; - connection_options.type = ConnectionType::TCP; + sw::redis::ConnectionOptions connection_options; + connection_options.type = sw::redis::ConnectionType::TCP; connection_options.host = connection_info["host"]; connection_options.port = std::stoi(connection_info["service-port"]); connection_options.password = connection_info["database-password"]; @@ -568,10 +567,10 @@ class RedisDB : public BaseDB true; // Required to connect to PDS within LC connection_options.tls.cacert = connection_info["cert"]; - ConnectionPoolOptions pool_options; + sw::redis::ConnectionPoolOptions pool_options; pool_options.size = 100; // Pool size, i.e. max number of connections. - _redis = new Redis(connection_options, pool_options); + _redis = new sw::redis::Redis(connection_options, pool_options); } ~RedisDB() @@ -693,28 +692,29 @@ class RedisDB : public BaseDB /** * @brief AMS represents the header as follows: - * The header is 12 bytes long: - * - 1 byte is the size of the header (here 12). Limit max: 255 + * The header is 16 bytes long: + * - 1 byte is the size of the header (here 16). Limit max: 255 * - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 * - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 * - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 * - 2 bytes are the input dimension. Limit max: 65535 * - 2 bytes are the output dimension. Limit max: 65535 - * - * |__Header__|__Datatype__|___Rank___|__#elems__|___InDim___|___OutDim___|...real data...| - * ^ ^ ^ ^ ^ ^ ^ ^ - * | Byte 1 | Byte 2 | Byte 3-4 | Byte 4-8 | Byte 8-10 | Byte 10-12 | Byte 12-X | + * - 4 bytes for padding. Limit max: 2^32 - 1 + * + * |_Header_|_Datatype_|___Rank___|__#elems__|___InDim___|___OutDim___|_Pad_|.real data.| + * ^ ^ ^ ^ ^ ^ ^ ^ ^ + * | Byte 1 | Byte 2 | Byte 3-4 | Byte 4-8 | Byte 8-10 | Byte 10-12 |-----| Byte 16-k | * - * where X = datatype * num_element * (InDim + OutDim). Total message size is 12+X. + * where X = datatype * num_element * (InDim + OutDim). Total message size is 16+k. * - * The data starts at byte 12, ends at byte X. + * The data starts at byte 16, ends at byte k. * The data is structured as pairs of input/outputs. Let K be the total number of * elements, then we have K pairs of inputs/outputs (either float or double): * - * |__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| + * |__Header_(16B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| */ struct AMSMsgHeader { - /** @brief Heaader size (bytes) */ + /** @brief Header size (bytes) */ uint8_t hsize; /** @brief Data type size (bytes) */ uint8_t dtype; @@ -748,6 +748,27 @@ struct AMSMsgHeader { { } + /** + * @brief Constructor for AMSMsgHeader + * @param[in] mpi_rank MPI rank + * @param[in] num_elem Number of elements (input/outputs) + * @param[in] in_dim Inputs dimension + * @param[in] out_dim Outputs dimension + */ + AMSMsgHeader(uint16_t mpi_rank, + uint32_t num_elem, + uint16_t in_dim, + uint16_t out_dim, + uint8_t type_size) + : hsize(static_cast(AMSMsgHeader::size())), + dtype(type_size), + mpi_rank(mpi_rank), + num_elem(num_elem), + in_dim(in_dim), + out_dim(out_dim) + { + } + /** * @brief Return the size of a header in the AMS protocol. * @return The size of a message header in AMS (in byte) @@ -792,6 +813,46 @@ struct AMSMsgHeader { return AMSMsgHeader::size(); } + + /** + * @brief Return a valid header based on a pre-existing data buffer + * @param[in] data_blob The buffer to fill + * @return An AMSMsgHeader with the correct attributes + */ + static AMSMsgHeader decode(uint8_t* data_blob) + { + size_t current_offset = 0; + // Header size (should be 1 bytes) + uint8_t new_hsize = data_blob[current_offset]; + CWARNING(AMSMsgHeader, + new_hsize != AMSMsgHeader::size(), + "buffer is likely not a valid AMSMessage (%d / %d)", + new_hsize, + current_offset) + + current_offset += sizeof(uint8_t); + // Data type (should be 1 bytes) + uint8_t new_dtype = data_blob[current_offset]; + current_offset += sizeof(uint8_t); + // MPI rank (should be 2 bytes) + uint16_t new_mpirank; + std::memcpy(&new_mpirank, data_blob + current_offset, sizeof(uint16_t)); + current_offset += sizeof(uint16_t); + // Num elem (should be 4 bytes) + uint32_t new_num_elem; + std::memcpy(&new_num_elem, data_blob + current_offset, sizeof(uint32_t)); + current_offset += sizeof(uint32_t); + // Input dim (should be 2 bytes) + uint16_t new_in_dim; + std::memcpy(&new_in_dim, data_blob + current_offset, sizeof(uint16_t)); + current_offset += sizeof(uint16_t); + // Output dim (should be 2 bytes) + uint16_t new_out_dim; + std::memcpy(&new_out_dim, data_blob + current_offset, sizeof(uint16_t)); + + return AMSMsgHeader( + new_mpirank, new_num_elem, new_in_dim, new_out_dim, new_dtype); + } }; @@ -816,9 +877,37 @@ class AMSMessage /** @brief The dimensions of outputs */ size_t _output_dim; + /** + * @brief Empty constructor + */ + AMSMessage() + : _id(0), + _num_elements(0), + _input_dim(0), + _output_dim(0), + _data(nullptr), + _total_size(0) + { + } + + /** + * @brief Internal Method swapping for AMSMessage + * @param[in] other Message to swap + */ + void swap(const AMSMessage& other) + { + _id = other._id; + _num_elements = other._num_elements; + _input_dim = other._input_dim; + _output_dim = other._output_dim; + _total_size = other._total_size; + _data = other._data; + } + public: /** * @brief Constructor + * @param[in] id ID of the message * @param[in] num_elements Number of elements * @param[in] inputs Inputs * @param[in] outputs Outputs @@ -842,32 +931,68 @@ class AMSMessage _rank, _num_elements, _input_dim, _output_dim, sizeof(TypeValue)); _total_size = AMSMsgHeader::size() + getTotalElements() * sizeof(TypeValue); - _data = ams::ResourceManager::allocate(_total_size, - AMSResourceType::HOST); + auto& rm = ams::ResourceManager::getInstance(); + _data = rm.allocate(_total_size, AMSResourceType::HOST); size_t current_offset = header.encode(_data); current_offset += encode_data(reinterpret_cast(_data + current_offset), inputs, outputs); - DBG(AMSMessage, "Allocated message: %p", _data); + DBG(AMSMessage, "Allocated message %d: %p", _id, _data); } - AMSMessage(const AMSMessage&) = delete; + /** + * @brief Constructor + * @param[in] id ID of the message + * @param[in] data Pointer containing data + */ + AMSMessage(int id, uint8_t* data) + : _id(id), + _num_elements(0), + _input_dim(0), + _output_dim(0), + _data(data), + _total_size(0) + { + auto header = AMSMsgHeader::decode(data); + + int current_rank = 0; +#ifdef __ENABLE_MPI__ + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, ¤t_rank)); +#endif + _rank = header.mpi_rank; + CWARNING(AMSMessage, + _rank != current_rank, + "MPI rank are not matching (using %d)", + _rank) + + _num_elements = header.num_elem; + _input_dim = header.in_dim; + _output_dim = header.out_dim; + _data = data; + auto type_value = header.dtype; + + _total_size = AMSMsgHeader::size() + getTotalElements() * type_value; + + DBG(AMSMessage, "Allocated message %d: %p", _id, _data); + } + + AMSMessage(const AMSMessage& other) + { + DBG(AMSMessage, "Copy AMSMessage : %p -- %d", other._data, other._id); + swap(other); + }; + AMSMessage& operator=(const AMSMessage&) = delete; AMSMessage(AMSMessage&& other) noexcept { *this = std::move(other); } AMSMessage& operator=(AMSMessage&& other) noexcept { - DBG(AMSMessage, "Move AMSMessage : %p -- %d", other._data, other._id); + // DBG(AMSMessage, "Move AMSMessage : %p -- %d", other._data, other._id); if (this != &other) { - _id = other._id; - _num_elements = other._num_elements; - _input_dim = other._input_dim; - _output_dim = other._output_dim; - _total_size = other._total_size; - _data = other._data; + swap(other); other._data = nullptr; } return *this; @@ -886,7 +1011,6 @@ class AMSMessage const std::vector& inputs, const std::vector& outputs) { - size_t offset = 0; size_t x_dim = _input_dim + _output_dim; if (!data_blob) return 0; // Creating the body part of the messages @@ -926,6 +1050,10 @@ class AMSMessage */ int id() const { return _id; } + /** + * @brief Return MPI rank + * @return MPI rank + */ int rank() const { return _rank; } /** @@ -936,7 +1064,10 @@ class AMSMessage ~AMSMessage() { - DBG(AMSMessage, "Destroying message with address %p %d", _data, _id) + DBG(AMSMessage, + "Destroying message %d: %p (underlying memory NOT freed)", + _id, + _data) } }; // class AMSMessage @@ -1017,8 +1148,8 @@ class RMQConsumerHandler : public AMQP::LibEventHandler #else int ret = SSL_use_certificate_chain_file(ssl, _cacert.c_str()); #endif - // TODO: with openssl 3.0 - // SSL_set_options(ssl, SSL_OP_IGNORE_UNEXPECTED_EOF); + // FIXME: with openssl 3.0 + // Set => SSL_set_options(ssl, SSL_OP_IGNORE_UNEXPECTED_EOF); if (ret != 1) { std::string error("openssl: error loading ca-chain (" + _cacert + @@ -1064,7 +1195,7 @@ class RMQConsumerHandler : public AMQP::LibEventHandler virtual void onReady(AMQP::TcpConnection* connection) override { DBG(RMQConsumerHandler, - "[rank=%d] Sucessfuly logged in. Connection ready to use.\n", + "[rank=%d] Sucessfuly logged in. Connection ready to use.", _rank) _channel = std::make_shared(connection); @@ -1315,13 +1446,14 @@ class RMQConsumer } }; // class RMQConsumer +enum RMQConnectionStatus { FAILED, CONNECTED, CLOSED, ERROR }; + /** * @brief Specific handler for RabbitMQ connections based on libevent. */ class RMQPublisherHandler : public AMQP::LibEventHandler { private: - enum ConnectionStatus { FAILED, CONNECTED, CLOSED }; /** @brief Path to TLS certificate */ std::string _cacert; /** @brief The MPI rank (0 if MPI is not used) */ @@ -1339,15 +1471,18 @@ class RMQPublisherHandler : public AMQP::LibEventHandler /** @brief Number of messages successfully acknowledged */ int _nb_msg_ack; - std::promise establish_connection; - std::future established; + std::promise establish_connection; + std::future established; + + std::promise close_connection; + std::future closed; - std::promise close_connection; - std::future closed; + std::promise _error_connection; + std::future _ftr_error; public: std::mutex ptr_mutex; - std::vector data_ptrs; + std::vector data_ptrs; /** * @brief Constructor @@ -1373,91 +1508,76 @@ class RMQPublisherHandler : public AMQP::LibEventHandler #endif established = establish_connection.get_future(); closed = close_connection.get_future(); + _ftr_error = _error_connection.get_future(); } + ~RMQPublisherHandler() = default; + /** * @brief Publish data on RMQ queue. - * @param[in] data The data pointer - * @param[in] data_size The number of bytes in the data pointer + * @param[in] msg The AMSMessage to publish */ void publish(AMSMessage&& msg) { + data_ptrs.push_back(msg); if (_rchannel) { // publish a message via the reliable-channel + // onAck : message has been explicitly ack'ed by RabbitMQ + // onNack : message has been explicitly nack'ed by RabbitMQ + // onError : error occurred before any ack or nack was received + // onLost : messages that have either been nack'ed, or lost _rchannel ->publish("", _queue, reinterpret_cast(msg.data()), msg.size()) - .onAck([_msg_ptr = msg.data(), + .onAck([this, &_nb_msg_ack = _nb_msg_ack, - rank = msg.rank(), id = msg.id(), - &ptr_mutex = ptr_mutex, + data = msg.data(), &data_ptrs = this->data_ptrs]() mutable { - const std::lock_guard lock(ptr_mutex); DBG(RMQPublisherHandler, "[rank=%d] message #%d (Addr:%p) got acknowledged successfully " "by " "RMQ " "server", - rank, + _rank, id, - _msg_ptr) + data) + this->free_ams_message(id, data_ptrs); _nb_msg_ack++; - data_ptrs.push_back(_msg_ptr); }) - .onNack([_msg_ptr = msg.data(), - &_nb_msg_ack = _nb_msg_ack, - rank = msg.rank(), - id = msg.id(), - &ptr_mutex = ptr_mutex, - &data_ptrs = this->data_ptrs]() mutable { - const std::lock_guard lock(ptr_mutex); + .onNack([this, id = msg.id(), data = msg.data()]() mutable { WARNING(RMQPublisherHandler, - "[rank=%d] message #%d received negative acknowledged by " + "[rank=%d] message #%d (%p) received negative acknowledged " + "by " "RMQ " "server", - rank, - id) - data_ptrs.push_back(_msg_ptr); - }) - .onLost([_msg_ptr = msg.data(), - &_nb_msg_ack = _nb_msg_ack, - rank = msg.rank(), - id = msg.id(), - &ptr_mutex = ptr_mutex, - &data_ptrs = this->data_ptrs]() mutable { - const std::lock_guard lock(ptr_mutex); - CFATAL(RMQPublisherHandler, - false, - "[rank=%d] message #%d likely got lost by RMQ server", - rank, - id) - data_ptrs.push_back(_msg_ptr); + _rank, + id, + data) }) - .onError( - [_msg_ptr = msg.data(), - &_nb_msg_ack = _nb_msg_ack, - rank = msg.rank(), - id = msg.id(), - &ptr_mutex = ptr_mutex, - &data_ptrs = this->data_ptrs](const char* err_message) mutable { - const std::lock_guard lock(ptr_mutex); - CFATAL(RMQPublisherHandler, - false, - "[rank=%d] message #%d did not get send: %s", - rank, - id, - err_message) - data_ptrs.push_back(_msg_ptr); - }); + .onError([this, id = msg.id(), data = msg.data()]( + const char* err_message) mutable { + WARNING(RMQPublisherHandler, + "[rank=%d] message #%d (%p) did not get send: %s", + _rank, + id, + data, + err_message) + }); } else { WARNING(RMQPublisherHandler, "[rank=%d] The reliable channel was not ready for message #%d.", _rank, - _nb_msg) + msg.id()) } _nb_msg++; } + /** + * @brief Wait (blocking call) until connection has been established or that ms * repeat is over. + * @param[in] ms Number of milliseconds the function will wait on the future + * @param[in] repeat Number of times the function will wait + * @return True if connection has been established + */ bool waitToEstablish(unsigned ms, int repeat = 1) { if (waitFuture(established, ms, repeat)) { @@ -1468,6 +1588,12 @@ class RMQPublisherHandler : public AMQP::LibEventHandler return false; } + /** + * @brief Wait (blocking call) until connection has been closed or that ms * repeat is over. + * @param[in] ms Number of milliseconds the function will wait on the future + * @param[in] repeat Number of times the function will wait + * @return True if connection has been closed + */ bool waitToClose(unsigned ms, int repeat = 1) { if (waitFuture(closed, ms, repeat)) { @@ -1476,24 +1602,45 @@ class RMQPublisherHandler : public AMQP::LibEventHandler return false; } - ~RMQPublisherHandler() = default; - - void release_message_buffers() + /** + * @brief Check if the connection can be used to send messages. + * @return True if connection is valid (i.e., can send messages) + */ + bool connection_valid() { - const std::lock_guard lock(ptr_mutex); - for (auto& dp : data_ptrs) { - DBG(RMQPublisherHandler, "deallocate address %p", dp) - ams::ResourceManager::deallocate(dp, AMSResourceType::HOST); - } - data_ptrs.erase(data_ptrs.begin(), data_ptrs.end()); + std::chrono::milliseconds span(1); + return _ftr_error.wait_for(span) != std::future_status::ready; } + /** + * @brief Return the messages that have NOT been acknowledged by the RabbitMQ server. + * @return A vector of AMSMessage + */ + std::vector& internal_msg_buffer() { return data_ptrs; } + + /** + * @brief Free AMSMessages held by the handler + */ + void cleanup() { free_all_messages(data_ptrs); } + + /** + * @brief Total number of messages sent + * @return Number of messages + */ + int msg_sent() const { return _nb_msg; } + + /** + * @brief Total number of messages successfully acknowledged + * @return Number of messages + */ + int msg_acknowledged() const { return _nb_msg_ack; } + unsigned unacknowledged() const { return _rchannel->unacknowledged(); } void flush() { uint32_t tries = 0; - while (auto unAck = _rchannel->unacknowledged()) { + while (auto unAck = unacknowledged()) { DBG(RMQPublisherHandler, "Waiting for %lu messages to be acknowledged", unAck); @@ -1501,41 +1648,9 @@ class RMQPublisherHandler : public AMQP::LibEventHandler if (++tries > 10) break; std::this_thread::sleep_for(std::chrono::milliseconds(50 * tries)); } + free_all_messages(data_ptrs); } - // void purge() - // { - // std::promise purge_queue; - // std::future purged; - // purged = purge_queue.get_future(); - // - // _channel->purgeQueue(_queue) - // .onSuccess([&](uint32_t messageCount) { - // DBG(RMQPublisherHandler, - // "Sucessfuly purged queue with (%u) remaining messages", - // messageCount); - // purge_queue.set_value(true); - // }) - // .onError([&](const char* message) { - // DBG(RMQPublisherHandler, - // "Error '%s' when purging queue %s", - // message, - // _queue.c_str()); - // purge_queue.set_value(false); - // }) - // .onFinalize([&]() { - // DBG(RMQPublisherHandler, "Finalizing queue %s", _queue.c_str()) - // }); - // - // if (purged.get()) { - // DBG(RMQPublisherHandler, "Successfull destruction of RMQ queue"); - // return; - // } - // - // DBG(RMQPublisherHandler, "Non-successfull destruction of RMQ queue"); - // } - - private: /** * @brief Method that is called after a TCP connection has been set up, and @@ -1602,8 +1717,10 @@ class RMQPublisherHandler : public AMQP::LibEventHandler virtual void onReady(AMQP::TcpConnection* connection) override { DBG(RMQPublisherHandler, - "[rank=%d] Sucessfuly logged in. Connection ready to use.\n", - _rank) + "[rank=%d] Sucessfuly logged in (connection %p). Connection ready to " + "use.", + _rank, + connection) _channel = std::make_shared(connection); _channel->onError([&](const char* message) { @@ -1658,7 +1775,7 @@ class RMQPublisherHandler : public AMQP::LibEventHandler */ virtual void onClosed(AMQP::TcpConnection* connection) override { - DBG(RMQPublisherHandler, "[rank=%d] Connection is closed.\n", _rank) + DBG(RMQPublisherHandler, "[rank=%d] Connection is closed.", _rank) } /** @@ -1673,25 +1790,30 @@ class RMQPublisherHandler : public AMQP::LibEventHandler virtual void onError(AMQP::TcpConnection* connection, const char* message) override { - FATAL(RMQPublisherHandler, - "[rank=%d] fatal error on TCP connection: %s\n", - _rank, - message) + WARNING(RMQPublisherHandler, + "[rank=%d] fatal error on TCP connection: %s", + _rank, + message) + try { + _error_connection.set_value(ERROR); + } catch (const std::future_error& e) { + DBG(RMQPublisherHandler, "[rank=%d] future already set.", _rank) + } } /** - * Final method that is called. This signals that no further calls to your + * @brief Final method that is called. This signals that no further calls to your * handler will be made about the connection. * @param connection The connection that can be destructed */ virtual void onDetached(AMQP::TcpConnection* connection) override { // add your own implementation, like cleanup resources or exit the application - DBG(RMQPublisherHandler, "[rank=%d] Connection is detached.\n", _rank) + DBG(RMQPublisherHandler, "[rank=%d] Connection is detached.", _rank) close_connection.set_value(CLOSED); } - bool waitFuture(std::future& future, + bool waitFuture(std::future& future, unsigned ms, int repeat) { @@ -1700,9 +1822,74 @@ class RMQPublisherHandler : public AMQP::LibEventHandler std::future_status status; while ((status = future.wait_for(span)) == std::future_status::timeout && (iters++ < repeat)) - std::future established; + std::future established; return status == std::future_status::ready; } + + /** + * @brief Free the data pointed pointer in a vector and update vector. + * @param[in] addr Address of memory to free. + * @param[in] buffer The vector containing memory buffers + */ + void free_ams_message(int msg_id, std::vector& buf) + { + const std::lock_guard lock(ptr_mutex); + auto it = + std::find_if(buf.begin(), buf.end(), [&msg_id](const AMSMessage& obj) { + return obj.id() == msg_id; + }); + if (it == buf.end()) { + WARNING(RMQPublisherHandler, + "Failed to deallocate msg #%d: not found", + msg_id) + return; + } + auto& msg = *it; + auto& rm = ams::ResourceManager::getInstance(); + try { + rm.deallocate(msg.data(), AMSResourceType::HOST); + } catch (const umpire::util::Exception& e) { + WARNING(RMQPublisherHandler, + "Failed to deallocate #%d (%p)", + msg.id(), + msg.data()); + } + DBG(RMQPublisherHandler, "Deallocated msg #%d (%p)", msg.id(), msg.data()) + it = std::remove_if(buf.begin(), + buf.end(), + [&msg_id](const AMSMessage& obj) { + return obj.id() == msg_id; + }); + CWARNING(RMQPublisherHandler, + it == buf.end(), + "Failed to erase %p from buffer", + msg.data()); + buf.erase(it, buf.end()); + } + + /** + * @brief Free the data pointed by each pointer in a vector. + * @param[in] buffer The vector containing memory buffers + */ + void free_all_messages(std::vector& buffer) + { + const std::lock_guard lock(ptr_mutex); + // auto& urm = umpire::ResourceManager::getInstance(); + auto& rm = ams::ResourceManager::getInstance(); + for (auto& dp : buffer) { + DBG(RMQPublisherHandler, "deallocate msg #%d (%p)", dp.id(), dp.data()) + try { + rm.deallocate(dp.data(), AMSResourceType::HOST); + } catch (const umpire::util::Exception& e) { + WARNING(RMQPublisherHandler, + "Failed to deallocate msg #%d (%p)", + dp.id(), + dp.data()); + } + } + buffer.erase(buffer.begin(), buffer.end()); + } + }; // class RMQPublisherHandler @@ -1725,15 +1912,23 @@ class RMQPublisher std::shared_ptr _loop; /** @brief The handler which contains various callbacks for the sender */ std::shared_ptr _handler; + /** @brief Buffer holding unacknowledged messages in case of crash */ + std::vector _buffer_msg; public: RMQPublisher(const RMQPublisher&) = delete; RMQPublisher& operator=(const RMQPublisher&) = delete; - RMQPublisher(const AMQP::Address& address, - std::string cacert, - std::string queue) - : _rank(0), _queue(queue), _cacert(cacert), _handler(nullptr) + RMQPublisher( + const AMQP::Address& address, + std::string cacert, + std::string queue, + std::vector&& msgs_to_send = std::vector()) + : _rank(0), + _queue(queue), + _cacert(cacert), + _handler(nullptr), + _buffer_msg(std::move(msgs_to_send)) { #ifdef __ENABLE_MPI__ MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); @@ -1788,9 +1983,12 @@ class RMQPublisher return _handler->waitToEstablish(ms, repeat); } + /** + * @brief Return the number of unacknowledged messages + * @return Number of unacknowledged messages + */ unsigned unacknowledged() const { return _handler->unacknowledged(); } - /** * @brief Start the underlying I/O loop (blocking call) */ @@ -1801,10 +1999,63 @@ class RMQPublisher */ void stop() { event_base_loopexit(_loop.get(), NULL); } - void release_messages() { _handler->release_message_buffers(); } + /** + * @brief Check if the underlying connection has no errors + * @return True if no errors + */ + bool connection_valid() { return _handler->connection_valid(); } - void publish(AMSMessage&& message) { _handler->publish(std::move(message)); } + /** + * @brief Return the messages that have not been acknolwged. + * It does not mean they have not been delivered but the + * acknowledgements have not arrived yet. + * @return A map of messages (ID, data) + */ + std::vector get_buffer_msgs() + { + return std::move(_handler->internal_msg_buffer()); + } + + /** + * @brief Total number of messages successfully acknowledged + * @return Number of messages + */ + void cleanup() { _handler->cleanup(); } + void publish(AMSMessage&& message) + { + // We have some messages to send first (from a potential restart) + if (_buffer_msg.size() > 0) { + for (auto& msg : _buffer_msg) { + DBG(RMQPublisher, + "Publishing backed up message %d: %p", + msg.id(), + msg.data()) + _handler->publish(std::move(msg)); + } + _buffer_msg.clear(); + } + + DBG(RMQPublisher, "Publishing message %d: %p", message.id(), message.data()) + _handler->publish(std::move(message)); + } + + /** + * @brief Total number of messages sent + * @return Number of messages + */ + int msg_sent() const { return _handler->msg_sent(); } + + /** + * @brief Total number of messages successfully acknowledged + * @return Number of messages + */ + int msg_acknowledged() const { return _handler->msg_acknowledged(); } + + /** + * @brief Total number of messages successfully acknowledged + * @return Number of messages + */ bool close(unsigned ms, int repeat = 1) { _handler->flush(); @@ -1812,14 +2063,14 @@ class RMQPublisher return _handler->waitToClose(ms, repeat); } - ~RMQPublisher() {} + ~RMQPublisher() = default; }; // class RMQPublisher /** * @brief Class that manages a RabbitMQ broker and handles connection, event * loop and set up various handlers. - * @details This class manages a specific type of database backend in AMSLib. + * @details This class handles a specific type of database backend in AMSLib. * Instead of writing inputs/outputs directly to files (CSV or HDF5), we * send these elements (a collection of inputs and their corresponding outputs) * to a service called RabbitMQ which is listening on a given IP and port. @@ -1865,7 +2116,7 @@ class RMQPublisher * updates to rank regarding the ML surrrogate model. RMQConsumer will automatically populate a std::vector with all * messages received since the execution of AMS started. * - * Glabal note: Most calls dealing with RabbitMQ (to establish a RMQ connection, opening a channel, publish data etc) + * Global note: Most calls dealing with RabbitMQ (to establish a RMQ connection, opening a channel, publish data etc) * are asynchronous callbacks (similar to asyncio in Python or future in C++). * So, the simulation can have already started and the RMQ connection might not be valid which is why most part * of the code that deals with RMQ are wrapped into callbacks that will get run only in case of success. @@ -1882,6 +2133,10 @@ class RabbitMQDB final : public BaseDB std::string _queue_sender; /** @brief name of the queue to receive data */ std::string _queue_receiver; + /** @brief Address of the RabbitMQ server */ + std::shared_ptr _address; + /** @brief TLS certificate path */ + std::string _cacert; /** @brief MPI rank (if MPI is used, otherwise 0) */ int _rank; /** @brief Represent the ID of the last message sent */ @@ -1985,17 +2240,18 @@ class RabbitMQDB final : public BaseDB "sure rabbitmq-inbound-queue and rabbitmq-outbound-queue exist") return; } + _cacert = rmq_config["rabbitmq-cert"]; AMQP::Login login(rmq_config["rabbitmq-user"], rmq_config["rabbitmq-password"]); - AMQP::Address address(rmq_config["service-host"], - port, - login, - rmq_config["rabbitmq-vhost"], - is_secure); + _address = std::make_shared(rmq_config["service-host"], + port, + login, + rmq_config["rabbitmq-vhost"], + is_secure); - std::string cacert = rmq_config["rabbitmq-cert"]; - _publisher = std::make_shared(address, cacert, _queue_sender); + _publisher = + std::make_shared(*_address, _cacert, _queue_sender); _publisher_thread = std::thread([&]() { _publisher->start(); }); @@ -2014,7 +2270,8 @@ class RabbitMQDB final : public BaseDB /** * @brief Takes an input and an output vector each holding 1-D vectors data, and push - * it onto the libevent buffer. + * it onto the libevent buffer. If the underlying connection is not valid anymore, a + new connection will be set up and unacknowledged messages will be (re) sent. * @param[in] num_elements Number of elements of each 1-D vector * @param[in] inputs Vector of 1-D vectors containing the inputs to be sent * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains @@ -2033,12 +2290,49 @@ class RabbitMQDB final : public BaseDB num_elements, inputs.size(), outputs.size()) - - _publisher->release_messages(); + if (!_publisher->connection_valid()) { + restart(); + bool status = _publisher->waitToEstablish(100, 10); + if (!status) { + _publisher->stop(); + _publisher_thread.join(); + FATAL(RabbitMQDB, "Could not establish connection"); + } + } _publisher->publish(AMSMessage(_msg_tag, num_elements, inputs, outputs)); _msg_tag++; } + void restart() + { + std::vector messages = _publisher->get_buffer_msgs(); + + AMSMessage& msg_min = + *(std::min_element(messages.begin(), + messages.end(), + [](const AMSMessage& a, const AMSMessage& b) { + return a.id() < b.id(); + })); + + WARNING(RMQPublisher, + "[rank=%d] we have %d buffered messages that will get re-send " + "(starting from msg #%d).", + _rank, + messages.size(), + msg_min.id()) + + // Stop the faulty publisher + _publisher->stop(); + _publisher_thread.join(); + _publisher.reset(); + + _publisher = std::make_shared(*_address, + _cacert, + _queue_sender, + std::move(messages)); + _publisher_thread = std::thread([&]() { _publisher->start(); }); + } + /** * @brief Return the type of this broker * @return The type of the broker @@ -2050,119 +2344,157 @@ class RabbitMQDB final : public BaseDB */ AMSDBType dbType() { return AMSDBType::RMQ; }; - ~RabbitMQDB() + void close() { - + if (!_publisher_thread.joinable()) { + return; + } bool status = _publisher->close(100, 10); CWARNING(RabbitMQDB, !status, "Could not gracefully close TCP connection") + DBG(RabbitMQDB, "Number of messages sent: %d", _msg_tag) DBG(RabbitMQDB, "Number of unacknowledged messages are %d", _publisher->unacknowledged()) _publisher->stop(); - //_publisher->release_messages(); //_consumer->stop(); _publisher_thread.join(); //_consumer_thread.join(); } + + ~RabbitMQDB() { close(); } }; // class RabbitMQDB #endif // __ENABLE_RMQ__ +namespace ams +{ /** - * @brief Create an object of the respective database. - * This should never be used for large scale simulations as txt/csv format will - * be extremely slow. - * @param[in] dbPath path to the directory storing the data - * @param[in] dbType Type of the database to create - * @param[in] rId a unique Id for each process taking part in a distributed - * execution (rank-id) + * @brief Class that manages all DB attached to AMS workflows. + * Each DB can overload its method close() that will get called by + * the DB manager when the last workflow using a DB will be destructed. */ template -BaseDB* createDB(char* dbPath, AMSDBType dbType, uint64_t rId = 0) +class DBManager { - DBG(DB, "Instantiating data base"); -#ifdef __ENABLE_DB__ - if (dbPath == nullptr) { - std::cerr << " [WARNING] Path of DB is NULL, Please provide a valid path " - "to enable db\n"; - std::cerr << " [WARNING] Continueing\n"; - return nullptr; +public: + static auto& getInstance() + { + static DBManager instance; + return instance; + } + +private: + std::unordered_map>> + db_instances; + DBManager() = default; + +public: + ~DBManager() + { + for (auto& e : db_instances) { + DBG(DBManager, + "Closing DB %s (#client=%d)", + e.first.c_str(), + e.second.use_count() - 1); + e.second->close(); + } } + DBManager(const DBManager&) = delete; + DBManager(DBManager&&) = delete; + DBManager& operator=(const DBManager&) = delete; + DBManager& operator=(DBManager&&) = delete; - switch (dbType) { - case AMSDBType::CSV: - return new csvDB(dbPath, rId); + /** + * @brief Create an object of the respective database. + * This should never be used for large scale simulations as txt/csv format will + * be extremely slow. + * @param[in] dbPath path to the directory storing the data + * @param[in] dbType Type of the database to create + * @param[in] rId a unique Id for each process taking part in a distributed + * execution (rank-id) + */ + std::shared_ptr> createDB(char* dbPath, + AMSDBType dbType, + uint64_t rId = 0) + { +#ifdef __ENABLE_DB__ + DBG(DBManager, "Instantiating data base"); + if (dbPath == nullptr) { + WARNING(DBManager, + "Path of DB is NULL, Please provide a valid path to enable DB.") + return nullptr; + } + + switch (dbType) { + case AMSDBType::CSV: + return std::make_shared>(dbPath, rId); #ifdef __ENABLE_REDIS__ - case AMSDBType::REDIS: - return new RedisDB(dbPath, rId); + case AMSDBType::REDIS: + return std::make_shared>(dbPath, rId); #endif #ifdef __ENABLE_HDF5__ - case AMSDBType::HDF5: - return new hdf5DB(dbPath, rId); + case AMSDBType::HDF5: + return std::make_shared>(dbPath, rId); #endif #ifdef __ENABLE_RMQ__ - case AMSDBType::RMQ: - return new RabbitMQDB(dbPath, rId); + case AMSDBType::RMQ: + return std::make_shared>(dbPath, rId); #endif - default: - return nullptr; - } -#else - return nullptr; + default: + return nullptr; + } #endif -} - - -/** - * @brief get a data base object referred by this string. - * This should never be used for large scale simulations as txt/csv format will - * be extremely slow. - * @param[in] dbPath path to the directory storing the data - * @param[in] dbType Type of the database to create - * @param[in] rId a unique Id for each process taking part in a distributed - * execution (rank-id) - */ -template -std::shared_ptr> getDB(char* dbPath, - AMSDBType dbType, - uint64_t rId = 0) -{ - static std::unordered_map>> - instances; - if (dbPath == nullptr) { - std::cerr << " [WARNING] Path of DB is NULL, Please provide a valid path " - "to enable db\n"; - std::cerr << " [WARNING] Continueing\n"; return nullptr; } - auto db_iter = instances.find(std::string(dbPath)); - if (db_iter == instances.end()) { - DBG(DB, "Creating new Database writting to file: %s", dbPath); - std::shared_ptr> db = std::shared_ptr>( - createDB(dbPath, dbType, rId)); - instances.insert(std::make_pair(std::string(dbPath), db)); - return db; - } + /** + * @brief get a data base object referred by this string. + * This should never be used for large scale simulations as txt/csv format will + * be extremely slow. + * @param[in] dbPath path to the directory storing the data + * @param[in] dbType Type of the database to create + * @param[in] rId a unique Id for each process taking part in a distributed + * execution (rank-id) + */ + std::shared_ptr> getDB(char* dbPath, + AMSDBType dbType, + uint64_t rId = 0) + { + if (dbPath == nullptr) { + WARNING(DBManager, + "Path of DB is NULL, Please provide a valid path to enable DB.") + return nullptr; + } - auto db = db_iter->second; - // Corner case where creation of the db failed and someone is requesting - // the same entry point - if (db == nullptr) { - return db; - } + auto db_iter = db_instances.find(std::string(dbPath)); + if (db_iter == db_instances.end()) { + auto db = createDB(dbPath, dbType, rId); + db_instances.insert(std::make_pair(std::string(dbPath), db)); + DBG(DBManager, "Creating new Database writting to file: %s", dbPath); + return db; + } - if (db->dbType() != dbType) { - throw std::runtime_error("Requesting databases of different types"); - } + auto db = db_iter->second; + // Corner case where creation of the db failed and someone is requesting + // the same entry point + if (db == nullptr) { + return db; + } - if (db->getId() != rId) { - throw std::runtime_error("Requesting databases from different ranks"); + if (db->dbType() != dbType) { + throw std::runtime_error("Requesting databases of different types"); + } + + if (db->getId() != rId) { + throw std::runtime_error("Requesting databases from different ranks"); + } + DBG(DBManager, "Using existing Database writting to file: %s", dbPath); + + return db; } - DBG(DB, "Using existing Database writting to file: %s", dbPath); +}; - return db; -} +} // namespace ams #endif // __AMS_BASE_DB__ diff --git a/src/AMSlib/wf/cuda/utilities.cuh b/src/AMSlib/wf/cuda/utilities.cuh index 6c944f32..3969688a 100644 --- a/src/AMSlib/wf/cuda/utilities.cuh +++ b/src/AMSlib/wf/cuda/utilities.cuh @@ -287,29 +287,30 @@ int compact(bool cond, bool isReverse = false) { int numBlocks = divup(length, blockSize); + auto& rm = ams::ResourceManager::getInstance(); int* d_BlocksCount = - ams::ResourceManager::allocate(numBlocks, AMSResourceType::DEVICE); + rm.allocate(numBlocks, AMSResourceType::DEVICE); int* d_BlocksOffset = - ams::ResourceManager::allocate(numBlocks, AMSResourceType::DEVICE); + rm.allocate(numBlocks, AMSResourceType::DEVICE); // determine number of elements in the compacted list int* h_BlocksCount = - ams::ResourceManager::allocate(numBlocks, AMSResourceType::HOST); + rm.allocate(numBlocks, AMSResourceType::HOST); int* h_BlocksOffset = - ams::ResourceManager::allocate(numBlocks, AMSResourceType::HOST); + rm.allocate(numBlocks, AMSResourceType::HOST); T** d_dense = - ams::ResourceManager::allocate(dims, AMSResourceType::DEVICE); + rm.allocate(dims, AMSResourceType::DEVICE); T** d_sparse = - ams::ResourceManager::allocate(dims, AMSResourceType::DEVICE); + rm.allocate(dims, AMSResourceType::DEVICE); - ams::ResourceManager::registerExternal(dense, + rm.registerExternal(dense, sizeof(T*) * dims, AMSResourceType::HOST); - ams::ResourceManager::registerExternal(sparse, + rm.registerExternal(sparse, sizeof(T*) * dims, AMSResourceType::HOST); - ams::ResourceManager::copy(dense, d_dense); - ams::ResourceManager::copy(const_cast(sparse), d_sparse); + rm.copy(dense, d_dense); + rm.copy(const_cast(sparse), d_sparse); thrust::device_ptr thrustPrt_bCount(d_BlocksCount); thrust::device_ptr thrustPrt_bOffset(d_BlocksOffset); @@ -338,22 +339,22 @@ int compact(bool cond, cudaDeviceSynchronize(); CUDACHECKERROR(); - ams::ResourceManager::copy(d_BlocksCount, h_BlocksCount); - ams::ResourceManager::copy(d_BlocksOffset, h_BlocksOffset); + rm.copy(d_BlocksCount, h_BlocksCount); + rm.copy(d_BlocksOffset, h_BlocksOffset); int compact_length = h_BlocksOffset[numBlocks - 1] + thrustPrt_bCount[numBlocks - 1]; - ams::ResourceManager::deallocate(d_BlocksCount, AMSResourceType::DEVICE); - ams::ResourceManager::deallocate(d_BlocksOffset, AMSResourceType::DEVICE); + rm.deallocate(d_BlocksCount, AMSResourceType::DEVICE); + rm.deallocate(d_BlocksOffset, AMSResourceType::DEVICE); - ams::ResourceManager::deallocate(h_BlocksCount, AMSResourceType::HOST); - ams::ResourceManager::deallocate(h_BlocksOffset, AMSResourceType::HOST); + rm.deallocate(h_BlocksCount, AMSResourceType::HOST); + rm.deallocate(h_BlocksOffset, AMSResourceType::HOST); - ams::ResourceManager::deallocate(d_dense, AMSResourceType::DEVICE); - ams::ResourceManager::deallocate(d_sparse, AMSResourceType::DEVICE); + rm.deallocate(d_dense, AMSResourceType::DEVICE); + rm.deallocate(d_sparse, AMSResourceType::DEVICE); - ams::ResourceManager::deregisterExternal(dense); - ams::ResourceManager::deregisterExternal(sparse); + rm.deregisterExternal(dense); + rm.deregisterExternal(sparse); cudaDeviceSynchronize(); CUDACHECKERROR(); @@ -432,8 +433,9 @@ void cuda_rand_init(bool* predicate, const size_t length, T threshold) const int TS = 4096; const int BS = 128; int numBlocks = divup(TS, BS); + auto& rm = ams::ResourceManager::getInstance(); if (!dev_random) { - dev_random = ams::ResourceManager::allocate(4096, AMSResourceType::DEVICE); + dev_random = rm.allocate(4096, AMSResourceType::DEVICE); srand_dev<<>>(dev_random, TS); } diff --git a/src/AMSlib/wf/data_handler.hpp b/src/AMSlib/wf/data_handler.hpp index d7c61f71..6089fba0 100644 --- a/src/AMSlib/wf/data_handler.hpp +++ b/src/AMSlib/wf/data_handler.hpp @@ -68,7 +68,8 @@ class DataHandler std::enable_if_t::value>* = nullptr> static inline TypeValue* cast_to_typevalue(AMSResourceType resource, const size_t n, TypeInValue* data) { - TypeValue* fdata = ams::ResourceManager::allocate(resource, n); + auto& rm = ams::ResourceManager::getInstance(); + TypeValue* fdata = rm.allocate(resource, n); std::transform(data, data + n, fdata, [&](const TypeInValue& v) { return static_cast(v); }); @@ -143,7 +144,8 @@ PERFFASPECT() const size_t nfeatures = features.size(); const size_t nvalues = n * nfeatures; - TypeValue* data = ams::ResourceManager::allocate(nvalues, resource); + auto& rm = ams::ResourceManager::getInstance(); + TypeValue* data = rm.allocate(nvalues, resource); if (resource == AMSResourceType::HOST) { for (size_t d = 0; d < nfeatures; d++) { diff --git a/src/AMSlib/wf/debug.h b/src/AMSlib/wf/debug.h index 52d04739..4927c32e 100644 --- a/src/AMSlib/wf/debug.h +++ b/src/AMSlib/wf/debug.h @@ -101,18 +101,19 @@ inline uint32_t getVerbosityLevel() do { \ double vm, rs; \ size_t watermark, current_size, actual_size; \ + auto& rm = ams::ResourceManager::getInstance(); \ memUsage(vm, rs); \ - DBG(id, "Memory usage at %s is VM:%g RS:%g\n", phase, vm, rs); \ + DBG(id, "Memory usage at %s is VM:%g RS:%g", phase, vm, rs); \ \ for (int i = 0; i < AMSResourceType::RSEND; i++) { \ - if (ams::ResourceManager::isActive((AMSResourceType)i)) { \ - ams::ResourceManager::getAllocatorStats((AMSResourceType)i, \ + if (rm.isActive((AMSResourceType)i)) { \ + rm.getAllocatorStats((AMSResourceType)i, \ watermark, \ current_size, \ actual_size); \ DBG(id, \ "Allocator: %s HWM:%lu CS:%lu AS:%lu) ", \ - ams::ResourceManager::getAllocatorName((AMSResourceType)i) \ + rm.getAllocatorName((AMSResourceType)i) \ .c_str(), \ watermark, \ current_size, \ diff --git a/src/AMSlib/wf/redist_load.hpp b/src/AMSlib/wf/redist_load.hpp index f7e461e8..bbedf3b5 100644 --- a/src/AMSlib/wf/redist_load.hpp +++ b/src/AMSlib/wf/redist_load.hpp @@ -116,11 +116,12 @@ class AMSLoadBalancer */ void init(int numIn, int numOut, AMSResourceType resource) { + auto& rm = ams::ResourceManager::getInstance(); // We need to store information if (rId == root) { dataElements = - ams::ResourceManager::allocate(worldSize, AMSResourceType::HOST); - displs = ams::ResourceManager::allocate(worldSize + 1, + rm.allocate(worldSize, AMSResourceType::HOST); + displs = rm.allocate(worldSize + 1, AMSResourceType::HOST); } @@ -149,9 +150,9 @@ class AMSLoadBalancer if (rId == root) { balancedElements = - ResourceManager::allocate(worldSize, AMSResourceType::HOST); + rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST); balancedDispls = - ResourceManager::allocate(worldSize, AMSResourceType::HOST); + rm.ResourceManager::allocate(worldSize, AMSResourceType::HOST); for (int i = 0; i < worldSize; i++) { balancedElements[i] = (globalLoad / worldSize) + static_cast(i < (globalLoad % worldSize)); @@ -164,12 +165,12 @@ class AMSLoadBalancer for (int i = 0; i < numIn; i++) { distInputs.push_back( - ams::ResourceManager::allocate(balancedLoad, resource)); + rm.allocate(balancedLoad, resource)); } for (int i = 0; i < numOut; i++) { distOutputs.push_back( - ams::ResourceManager::allocate(balancedLoad, resource)); + rm.allocate(balancedLoad, resource)); } } @@ -265,9 +266,10 @@ class AMSLoadBalancer AMSResourceType resource) { FPTypeValue *temp_data; + auto& rm = ams::ResourceManager::getInstance(); if (rId == root) { - temp_data = ResourceManager::allocate(globalLoad, resource); + temp_data = rm.ResourceManager::allocate(globalLoad, resource); } for (int i = 0; i < src.size(); i++) { @@ -284,7 +286,7 @@ class AMSLoadBalancer } if (rId == root) { - ResourceManager::deallocate(temp_data, resource); + rm.ResourceManager::deallocate(temp_data, resource); } return; @@ -330,20 +332,21 @@ class AMSLoadBalancer /** @brief deallocates all objects of this load balancing transcation */ ~AMSLoadBalancer() { + auto& rm = ams::ResourceManager::getInstance(); CINFO(LoadBalance, root==rId, "Total data %d Data per rank %d", globalLoad, balancedLoad); - if (displs) ams::ResourceManager::deallocate(displs, AMSResourceType::HOST); + if (displs) rm.deallocate(displs, AMSResourceType::HOST); if (dataElements) - ams::ResourceManager::deallocate(dataElements, AMSResourceType::HOST); + rm.deallocate(dataElements, AMSResourceType::HOST); if (balancedElements) - ams::ResourceManager::deallocate(balancedElements, AMSResourceType::HOST); + rm.deallocate(balancedElements, AMSResourceType::HOST); if (balancedDispls) - ams::ResourceManager::deallocate(balancedDispls, AMSResourceType::HOST); + rm.deallocate(balancedDispls, AMSResourceType::HOST); for (int i = 0; i < distOutputs.size(); i++) - ams::ResourceManager::deallocate(distOutputs[i], resource); + rm.deallocate(distOutputs[i], resource); for (int i = 0; i < distInputs.size(); i++) { - ams::ResourceManager::deallocate(distInputs[i], resource); + rm.deallocate(distInputs[i], resource); } }; diff --git a/src/AMSlib/wf/resource_manager.cpp b/src/AMSlib/wf/resource_manager.cpp index 2d35573f..d941c745 100644 --- a/src/AMSlib/wf/resource_manager.cpp +++ b/src/AMSlib/wf/resource_manager.cpp @@ -46,10 +46,6 @@ void AMSAllocator::getAllocatorStats(size_t &wm, size_t &cs, size_t &as) as = allocator.getActualSize(); } - -std::vector ResourceManager::RMAllocators = {nullptr, - nullptr, - nullptr}; // ----------------------------------------------------------------------------- // set up the resource manager // ----------------------------------------------------------------------------- diff --git a/src/AMSlib/wf/resource_manager.hpp b/src/AMSlib/wf/resource_manager.hpp index a26b7d9f..536af956 100644 --- a/src/AMSlib/wf/resource_manager.hpp +++ b/src/AMSlib/wf/resource_manager.hpp @@ -33,6 +33,11 @@ struct AMSAllocator { { auto& rm = umpire::ResourceManager::getInstance(); allocator = rm.getAllocator(alloc_name); + DBG(AMSAllocator, "in AMSAllocator(%d, %s, %p)", id, alloc_name.c_str(), this) + } + + ~AMSAllocator() { + DBG(AMSAllocator, "in ~AMSAllocator(%d, %p)", id, this) } void* allocate(size_t num_bytes); @@ -54,21 +59,25 @@ struct AMSAllocator { class ResourceManager { -public: private: /** @brief Used internally to map resource types (Device, host, pinned memory) to * umpire allocator ids. */ - static std::vector RMAllocators; - + std::vector RMAllocators; + ResourceManager() : RMAllocators({nullptr,nullptr,nullptr}) {}; public: - ResourceManager() = delete; + ~ResourceManager() = default; ResourceManager(const ResourceManager&) = delete; ResourceManager(ResourceManager&&) = delete; ResourceManager& operator=(const ResourceManager&) = delete; ResourceManager& operator=(ResourceManager&&) = delete; + static ResourceManager& getInstance() { + static ResourceManager instance; + return instance; + } + /** @brief return the name of an allocator */ - static std::string getAllocatorName(AMSResourceType resource) + std::string getAllocatorName(AMSResourceType resource) { return RMAllocators[resource]->getName(); } @@ -81,7 +90,7 @@ class ResourceManager */ template PERFFASPECT() - static TypeInValue* allocate(size_t nvalues, AMSResourceType dev) + TypeInValue* allocate(size_t nvalues, AMSResourceType dev) { return static_cast( RMAllocators[dev]->allocate(nvalues * sizeof(TypeInValue))); @@ -95,7 +104,7 @@ class ResourceManager */ template PERFFASPECT() - static void deallocate(TypeInValue* data, AMSResourceType dev) + void deallocate(TypeInValue* data, AMSResourceType dev) { RMAllocators[dev]->deallocate(data); } @@ -107,7 +116,7 @@ class ResourceManager * @return void. */ PERFFASPECT() - static void registerExternal(void* ptr, size_t nBytes, AMSResourceType dev) + void registerExternal(void* ptr, size_t nBytes, AMSResourceType dev) { RMAllocators[dev]->registerPtr(ptr, nBytes); } @@ -116,7 +125,7 @@ class ResourceManager * @param[in] ptr pointer to memory to de-register. * @return void. */ - static void deregisterExternal(void* ptr) + void deregisterExternal(void* ptr) { AMSAllocator::deregisterPtr(ptr); } @@ -130,7 +139,7 @@ class ResourceManager */ template PERFFASPECT() - static void copy(TypeInValue* src, TypeInValue* dest, size_t size = 0) + void copy(TypeInValue* src, TypeInValue* dest, size_t size = 0) { static auto& rm = umpire::ResourceManager::getInstance(); rm.copy(dest, src, size); @@ -142,13 +151,13 @@ class ResourceManager * @return void. */ template - static void deallocate(std::vector& dPtr, AMSResourceType resource) + void deallocate(std::vector& dPtr, AMSResourceType resource) { for (auto* I : dPtr) RMAllocators[resource]->deallocate(I); } - static void init() + void init() { DBG(ResourceManager, "Default initialization of allocators"); if (!RMAllocators[AMSResourceType::HOST]) @@ -162,7 +171,7 @@ class ResourceManager #endif } - static void setAllocator(std::string alloc_name, AMSResourceType resource) + void setAllocator(std::string alloc_name, AMSResourceType resource) { if (RMAllocators[resource]) { delete RMAllocators[resource]; @@ -175,7 +184,7 @@ class ResourceManager RMAllocators[resource]->getName().c_str()); } - static bool isActive(AMSResourceType resource){ + bool isActive(AMSResourceType resource){ return RMAllocators[resource] != nullptr; } @@ -186,7 +195,7 @@ class ResourceManager * @param[out] as The actual size of the pool.. * @return void. */ - static void getAllocatorStats(AMSResourceType resource, + void getAllocatorStats(AMSResourceType resource, size_t& wm, size_t& cs, size_t& as) diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp index a20b4aaa..0e35da4b 100644 --- a/src/AMSlib/wf/workflow.hpp +++ b/src/AMSlib/wf/workflow.hpp @@ -95,6 +95,7 @@ class AMSWorkflow static const long bSize = 1 * 1024 * 1024; const int numIn = inputs.size(); const int numOut = outputs.size(); + auto &rm = ams::ResourceManager::getInstance(); // No database, so just de-allocate and return if (!DB) return; @@ -107,8 +108,7 @@ class AMSWorkflow // Compute number of elements that fit inside the buffer size_t bElements = bSize / sizeof(FPTypeValue); FPTypeValue *pPtr = - ams::ResourceManager::allocate(bElements, - AMSResourceType::PINNED); + rm.allocate(bElements, AMSResourceType::PINNED); // Total inner vector dimensions (inputs and outputs) size_t totalDims = inputs.size() + outputs.size(); // Compute number of elements of each outer dimension that fit in buffer @@ -125,22 +125,18 @@ class AMSWorkflow size_t actualElems = std::min(elPerDim, num_elements - i); // Copy input data to host for (int k = 0; k < numIn; k++) { - ams::ResourceManager::copy(&inputs[k][i], - hInputs[k], - actualElems * sizeof(FPTypeValue)); + rm.copy(&inputs[k][i], hInputs[k], actualElems * sizeof(FPTypeValue)); } // Copy output data to host for (int k = 0; k < numIn; k++) { - ams::ResourceManager::copy(&outputs[k][i], - hOutputs[k], - actualElems * sizeof(FPTypeValue)); + rm.copy(&outputs[k][i], hOutputs[k], actualElems * sizeof(FPTypeValue)); } // Store to database DB->store(actualElems, hInputs, hOutputs); } - ams::ResourceManager::deallocate(pPtr, AMSResourceType::PINNED); + rm.deallocate(pPtr, AMSResourceType::PINNED); return; } @@ -153,9 +149,9 @@ class AMSWorkflow appDataLoc(AMSResourceType::HOST), ePolicy(AMSExecPolicy::UBALANCED) { - #ifdef __ENABLE_DB__ - DB = createDB("miniApp_data.txt", dbType, 0); + auto &dbm = ams::DBManager::getInstance(); + DB = dbm.createDB("miniApp_data.txt", dbType, 0); CFATAL(WORKFLOW, !DB, "Cannot create database"); #endif } @@ -183,7 +179,8 @@ class AMSWorkflow DB = nullptr; if (db_path) { DBG(Workflow, "Creating Database"); - DB = getDB(db_path, dbType, rId); + auto &dbm = ams::DBManager::getInstance(); + DB = dbm.getDB(db_path, dbType, rId); } UQModel = std::make_unique>( @@ -194,7 +191,6 @@ class AMSWorkflow ~AMSWorkflow() { DBG(Workflow, "Destroying Workflow Handler"); } - /** @brief This is the main entry point of AMSLib and replaces the original * execution path of the application. * @param[in] probDescr an opaque type that will be forwarded to the @@ -260,6 +256,7 @@ class AMSWorkflow // To move around the inputs, outputs we bundle them as std::vectors std::vector origInputs(inputs, inputs + inputDim); std::vector origOutputs(outputs, outputs + outputDim); + auto &rm = ams::ResourceManager::getInstance(); REPORT_MEM_USAGE(Workflow, "Start") @@ -280,8 +277,7 @@ class AMSWorkflow return; } // The predicate with which we will split the data on a later step - bool *p_ml_acceptable = - ams::ResourceManager::allocate(totalElements, appDataLoc); + bool *p_ml_acceptable = rm.allocate(totalElements, appDataLoc); // ------------------------------------------------------------- // STEP 1: call the UQ module to look at input uncertainties @@ -299,8 +295,7 @@ class AMSWorkflow for (int i = 0; i < inputDim; i++) { packedInputs.emplace_back( - ams::ResourceManager::allocate(totalElements, - appDataLoc)); + rm.allocate(totalElements, appDataLoc)); } DBG(Workflow, "Allocated input resources") @@ -319,8 +314,7 @@ class AMSWorkflow std::vector packedOutputs; for (int i = 0; i < outputDim; i++) { packedOutputs.emplace_back( - ams::ResourceManager::allocate(packedElements, - appDataLoc)); + rm.allocate(packedElements, appDataLoc)); } { @@ -376,11 +370,11 @@ class AMSWorkflow // Deallocate temporal data // ----------------------------------------------------------------- for (int i = 0; i < inputDim; i++) - ams::ResourceManager::deallocate(packedInputs[i], appDataLoc); + rm.deallocate(packedInputs[i], appDataLoc); for (int i = 0; i < outputDim; i++) - ams::ResourceManager::deallocate(packedOutputs[i], appDataLoc); + rm.deallocate(packedOutputs[i], appDataLoc); - ams::ResourceManager::deallocate(p_ml_acceptable, appDataLoc); + rm.deallocate(p_ml_acceptable, appDataLoc); DBG(Workflow, "Finished AMSExecution") CINFO(Workflow, diff --git a/tests/AMSlib/ams_allocate.cpp b/tests/AMSlib/ams_allocate.cpp index d5ea1631..ed6bd5aa 100644 --- a/tests/AMSlib/ams_allocate.cpp +++ b/tests/AMSlib/ams_allocate.cpp @@ -18,34 +18,35 @@ int test_allocation(AMSResourceType resource, std::string pool_name) { std::cout << "Testing Pool: " << pool_name << "\n"; auto& rm = umpire::ResourceManager::getInstance(); - double* data = ams::ResourceManager::allocate(1, resource); + auto& ams_rm = ams::ResourceManager::getInstance(); + double* data = ams_rm.allocate(1, resource); auto found_allocator = rm.getAllocator(data); - if (ams::ResourceManager::getAllocatorName(resource) != + if (ams_rm.getAllocatorName(resource) != found_allocator.getName()) { std::cout << "Allocator Name" - << ams::ResourceManager::getAllocatorName(resource) + << ams_rm.getAllocatorName(resource) << "Actual Allocation " << found_allocator.getName() << "\n"; return 1; } - if (ams::ResourceManager::getAllocatorName(resource) != pool_name) { + if (ams_rm.getAllocatorName(resource) != pool_name) { std::cout << "Allocator Name" - << ams::ResourceManager::getAllocatorName(resource) + << ams_rm.getAllocatorName(resource) << "is not equal to pool name " << pool_name << "\n"; return 1; } found_allocator = rm.getAllocator(data); - if (ams::ResourceManager::getAllocatorName(resource) != + if (ams_rm.getAllocatorName(resource) != found_allocator.getName().data()) { std::cout << "Device Allocator Name" - << ams::ResourceManager::getAllocatorName(resource) + << ams_rm.getAllocatorName(resource) << "Actual Allocation " << found_allocator.getName() << "\n"; return 3; } - ams::ResourceManager::deallocate(data, resource); + ams_rm.deallocate(data, resource); return 0; } @@ -54,7 +55,8 @@ int main(int argc, char* argv[]) int device = std::atoi(argv[1]); // Testing with global umpire allocators - ams::ResourceManager::init(); + auto& ams_rm = ams::ResourceManager::getInstance(); + ams_rm.init(); if (device == 1) { if (test_allocation(AMSResourceType::DEVICE, "DEVICE") != 0) return 1; } else if (device == 0) { @@ -67,13 +69,13 @@ int main(int argc, char* argv[]) auto& rm = umpire::ResourceManager::getInstance(); auto alloc_resource = rm.makeAllocator( "test-device", rm.getAllocator("DEVICE")); - ams::ResourceManager::setAllocator("test-device", AMSResourceType::DEVICE); + ams_rm.setAllocator("test-device", AMSResourceType::DEVICE); if (test_allocation(AMSResourceType::DEVICE, "test-device") != 0) return 1; } else if (device == 0) { auto& rm = umpire::ResourceManager::getInstance(); auto alloc_resource = rm.makeAllocator( "test-host", rm.getAllocator("HOST")); - ams::ResourceManager::setAllocator("test-host", AMSResourceType::HOST); + ams_rm.setAllocator("test-host", AMSResourceType::HOST); if (test_allocation(AMSResourceType::HOST, "test-host") != 0) return 1; } diff --git a/tests/AMSlib/cpu_packing_test.cpp b/tests/AMSlib/cpu_packing_test.cpp index 508ed7a8..65b05e54 100644 --- a/tests/AMSlib/cpu_packing_test.cpp +++ b/tests/AMSlib/cpu_packing_test.cpp @@ -51,13 +51,14 @@ int main(int argc, char* argv[]) using data_handler = DataHandler; const size_t size = SIZE; int device = std::atoi(argv[1]); - ams::ResourceManager::init(); + auto& rm = ams::ResourceManager::getInstance(); + rm.init(); if (device == 0) { AMSResourceType resource = AMSResourceType::HOST; - bool* predicate = ams::ResourceManager::allocate(SIZE, resource); - double* dense = ams::ResourceManager::allocate(SIZE, resource); - double* sparse = ams::ResourceManager::allocate(SIZE, resource); - double* rsparse = ams::ResourceManager::allocate(SIZE, resource); + bool* predicate = rm.allocate(SIZE, resource); + double* dense = rm.allocate(SIZE, resource); + double* sparse = rm.allocate(SIZE, resource); + double* rsparse = rm.allocate(SIZE, resource); initPredicate(predicate, sparse, SIZE); std::vector s_data({const_cast(sparse)}); @@ -88,31 +89,31 @@ int main(int argc, char* argv[]) } } - ResourceManager::deallocate(predicate, AMSResourceType::HOST); - ResourceManager::deallocate(dense, AMSResourceType::HOST); - ResourceManager::deallocate(sparse, AMSResourceType::HOST); - ResourceManager::deallocate(rsparse, AMSResourceType::HOST); + rm.deallocate(predicate, AMSResourceType::HOST); + rm.deallocate(dense, AMSResourceType::HOST); + rm.deallocate(sparse, AMSResourceType::HOST); + rm.deallocate(rsparse, AMSResourceType::HOST); } else if (device == 1) { AMSResourceType resource = AMSResourceType::DEVICE; bool* h_predicate = - ams::ResourceManager::allocate(SIZE, AMSResourceType::HOST); + rm.allocate(SIZE, AMSResourceType::HOST); double* h_dense = - ams::ResourceManager::allocate(SIZE, AMSResourceType::HOST); + rm.allocate(SIZE, AMSResourceType::HOST); double* h_sparse = - ams::ResourceManager::allocate(SIZE, AMSResourceType::HOST); + rm.allocate(SIZE, AMSResourceType::HOST); double* h_rsparse = - ams::ResourceManager::allocate(SIZE, AMSResourceType::HOST); + rm.allocate(SIZE, AMSResourceType::HOST); initPredicate(h_predicate, h_sparse, SIZE); - bool* predicate = ams::ResourceManager::allocate(SIZE, resource); - double* dense = ams::ResourceManager::allocate(SIZE, resource); - double* sparse = ams::ResourceManager::allocate(SIZE, resource); - double* rsparse = ams::ResourceManager::allocate(SIZE, resource); - int* reindex = ams::ResourceManager::allocate(SIZE, resource); + bool* predicate = rm.allocate(SIZE, resource); + double* dense = rm.allocate(SIZE, resource); + double* sparse = rm.allocate(SIZE, resource); + double* rsparse = rm.allocate(SIZE, resource); + int* reindex = rm.allocate(SIZE, resource); - ResourceManager::copy(h_predicate, predicate); - ResourceManager::copy(h_sparse, sparse); + rm.copy(h_predicate, predicate); + rm.copy(h_sparse, sparse); std::vector s_data({const_cast(sparse)}); std::vector sr_data({rsparse}); @@ -129,7 +130,7 @@ int main(int argc, char* argv[]) return 1; } - ams::ResourceManager::copy(dense, h_dense); + rm.copy(dense, h_dense); if (verify(h_dense, elements, flag)) { std::cout << "Dense elements do not have the correct values\n"; @@ -138,7 +139,7 @@ int main(int argc, char* argv[]) data_handler::unpack(resource, predicate, size, d_data, sr_data, flag); - ams::ResourceManager::copy(rsparse, h_rsparse); + rm.copy(rsparse, h_rsparse); if (verify(h_predicate, h_sparse, h_rsparse, size, flag)) { // for ( int k = 0; k < SIZE; k++){ @@ -150,15 +151,15 @@ int main(int argc, char* argv[]) } } - ams::ResourceManager::deallocate(predicate, AMSResourceType::DEVICE); - ams::ResourceManager::deallocate(h_predicate, AMSResourceType::HOST); - ams::ResourceManager::deallocate(dense, AMSResourceType::DEVICE); - ams::ResourceManager::deallocate(h_dense, AMSResourceType::HOST); - ams::ResourceManager::deallocate(sparse, AMSResourceType::DEVICE); - ams::ResourceManager::deallocate(h_sparse, AMSResourceType::HOST); - ams::ResourceManager::deallocate(rsparse, AMSResourceType::DEVICE); - ams::ResourceManager::deallocate(h_rsparse, AMSResourceType::HOST); - ams::ResourceManager::deallocate(reindex, AMSResourceType::DEVICE); + rm.deallocate(predicate, AMSResourceType::DEVICE); + rm.deallocate(h_predicate, AMSResourceType::HOST); + rm.deallocate(dense, AMSResourceType::DEVICE); + rm.deallocate(h_dense, AMSResourceType::HOST); + rm.deallocate(sparse, AMSResourceType::DEVICE); + rm.deallocate(h_sparse, AMSResourceType::HOST); + rm.deallocate(rsparse, AMSResourceType::DEVICE); + rm.deallocate(h_rsparse, AMSResourceType::HOST); + rm.deallocate(reindex, AMSResourceType::DEVICE); } return 0; diff --git a/tests/AMSlib/gpu_packing_test.cpp b/tests/AMSlib/gpu_packing_test.cpp index 658662e7..7caa8da5 100644 --- a/tests/AMSlib/gpu_packing_test.cpp +++ b/tests/AMSlib/gpu_packing_test.cpp @@ -49,25 +49,26 @@ int main(int argc, char* argv[]) using namespace ams; using data_handler = DataHandler; auto& rm = umpire::ResourceManager::getInstance(); + auto& ams_rm = ams::ResourceManager::getInstance(); const size_t size = SIZE; bool* h_predicate = - ams::ResourceManager::allocate(SIZE, + ams_rm.allocate(SIZE, ResourceManager::ResourceType::HOST); - double* h_dense = ams::ResourceManager::allocate( + double* h_dense = ams_rm.allocate( SIZE, ResourceManager::ResourceType::HOST); - double* h_sparse = ams::ResourceManager::allocate( + double* h_sparse = ams_rm.allocate( SIZE, ResourceManager::ResourceType::HOST); - double* h_rsparse = ams::ResourceManager::allocate( + double* h_rsparse = ams_rm.allocate( SIZE, ResourceManager::ResourceType::HOST); initPredicate(h_predicate, h_sparse, SIZE); - bool* predicate = ams::ResourceManager::allocate(SIZE); - double* dense = ams::ResourceManager::allocate(SIZE); - double* sparse = ams::ResourceManager::allocate(SIZE); - double* rsparse = ams::ResourceManager::allocate(SIZE); - int* reindex = ams::ResourceManager::allocate(SIZE); + bool* predicate = ams_rm.allocate(SIZE); + double* dense = ams_rm.allocate(SIZE); + double* sparse = ams_rm.allocate(SIZE); + double* rsparse = ams_rm.allocate(SIZE); + int* reindex = ams_rm.allocate(SIZE); rm.copy(predicate, h_predicate); rm.copy(sparse, h_sparse); @@ -104,19 +105,19 @@ int main(int argc, char* argv[]) return 1; } - ams::ResourceManager::deallocate(predicate); - ams::ResourceManager::deallocate(h_predicate, + ams_rm.deallocate(predicate); + ams_rm.deallocate(h_predicate, ResourceManager::ResourceType::HOST); - ams::ResourceManager::deallocate(dense); - ams::ResourceManager::deallocate(h_dense, + ams_rm.deallocate(dense); + ams_rm.deallocate(h_dense, ResourceManager::ResourceType::HOST); - ams::ResourceManager::deallocate(sparse); - ams::ResourceManager::deallocate(h_sparse, + ams_rm.deallocate(sparse); + ams_rm.deallocate(h_sparse, ResourceManager::ResourceType::HOST); - ams::ResourceManager::deallocate(rsparse); - ams::ResourceManager::deallocate(h_rsparse, + ams_rm.deallocate(rsparse); + ams_rm.deallocate(h_rsparse, ResourceManager::ResourceType::HOST); - ams::ResourceManager::deallocate(reindex); + ams_rm.deallocate(reindex); return 0; } diff --git a/tests/AMSlib/lb.cpp b/tests/AMSlib/lb.cpp index 0f3dd690..d34f88eb 100644 --- a/tests/AMSlib/lb.cpp +++ b/tests/AMSlib/lb.cpp @@ -21,7 +21,8 @@ void init(double *data, int elements, double value) void evaluate(double *data, double *src, int elements) { - ams::ResourceManager::copy(src, data, elements * sizeof(double)); + auto& rm = ams::ResourceManager::getInstance(); + rm.copy(src, data, elements * sizeof(double)); } int verify(double *data, double *src, int elements, int rId) diff --git a/tests/AMSlib/test_hdcache.cpp b/tests/AMSlib/test_hdcache.cpp index 2416717b..84b7ed64 100644 --- a/tests/AMSlib/test_hdcache.cpp +++ b/tests/AMSlib/test_hdcache.cpp @@ -22,13 +22,14 @@ std::vector generate_vectors(const int num_clusters, int dims) { std::vector v_data; + auto& rm = ams::ResourceManager::getInstance(); // This are fixed to mimic the way the faiss was generated // The code below generates data values that are either within // the distance of the faiss index or just outside of it. const T distance = 10.0; const T offset = 5.0; for (int i = 0; i < dims; i++) { - T *data = ams::ResourceManager::allocate(num_clusters * elements, + T *data = rm.allocate(num_clusters * elements, AMSResourceType::HOST); for (int j = 0; j < elements; j++) { // Generate a value for every cluster center @@ -88,15 +89,16 @@ bool do_faiss(std::shared_ptr> &index, std::vector orig_data = generate_vectors(nClusters, nElements, nDims); std::vector data = orig_data; + auto& rm = ams::ResourceManager::getInstance(); bool *predicates = - ams::ResourceManager::allocate(nClusters * nElements, resource); + rm.allocate(nClusters * nElements, resource); if (resource == AMSResourceType::DEVICE) { for (int i = 0; i < orig_data.size(); i++) { T *d_data = - ams::ResourceManager::allocate(nClusters * nElements, resource); - ams::ResourceManager::copy(const_cast(orig_data[i]), + rm.allocate(nClusters * nElements, resource); + rm.copy(const_cast(orig_data[i]), d_data, nClusters * nElements * sizeof(T)); data[i] = d_data; @@ -109,24 +111,24 @@ bool do_faiss(std::shared_ptr> &index, bool *h_predicates = predicates; if (resource == AMSResourceType::DEVICE) { - h_predicates = ams::ResourceManager::allocate(nClusters * nElements, + h_predicates = rm.allocate(nClusters * nElements, AMSResourceType::HOST); - ams::ResourceManager::copy(predicates, h_predicates, nClusters * nElements); + rm.copy(predicates, h_predicates, nClusters * nElements); for (auto d : data) { - ams::ResourceManager::deallocate(const_cast(d), + rm.deallocate(const_cast(d), AMSResourceType::DEVICE); } - ams::ResourceManager::deallocate(predicates, AMSResourceType::DEVICE); + rm.deallocate(predicates, AMSResourceType::DEVICE); } for (auto h_d : orig_data) - ams::ResourceManager::deallocate(const_cast(h_d), + rm.deallocate(const_cast(h_d), AMSResourceType::HOST); bool res = validate(nClusters, nElements, h_predicates); - ams::ResourceManager::deallocate(h_predicates, AMSResourceType::HOST); + rm.deallocate(h_predicates, AMSResourceType::HOST); return res; } @@ -156,7 +158,8 @@ int main(int argc, char *argv[]) AMSResourceType resource = AMSResourceType::HOST; if (use_device == 1) resource = AMSResourceType::DEVICE; - ams::ResourceManager::init(); + auto& urm = ams::ResourceManager::getInstance(); + urm.init(); if (std::strcmp("double", data_type) == 0) { std::shared_ptr> cache = HDCache::getInstance( diff --git a/tests/AMSlib/torch_model.cpp b/tests/AMSlib/torch_model.cpp index 736049e5..723caf9e 100644 --- a/tests/AMSlib/torch_model.cpp +++ b/tests/AMSlib/torch_model.cpp @@ -24,28 +24,30 @@ void inference(SurrogateModel &model, AMSResourceType resource) std::vector inputs; std::vector outputs; + auto& ams_rm = ams::ResourceManager::getInstance(); for (int i = 0; i < 2; i++) - inputs.push_back(ams::ResourceManager::allocate(SIZE, resource)); + inputs.push_back(ams_rm.allocate(SIZE, resource)); for (int i = 0; i < 4; i++) - outputs.push_back(ams::ResourceManager::allocate(SIZE, resource)); + outputs.push_back(ams_rm.allocate(SIZE, resource)); model.evaluate( SIZE, inputs.size(), outputs.size(), inputs.data(), outputs.data()); for (int i = 0; i < 2; i++) - ResourceManager::deallocate(const_cast(inputs[i]), resource); + ams_rm.deallocate(const_cast(inputs[i]), resource); for (int i = 0; i < 4; i++) - ResourceManager::deallocate(outputs[i], resource); + ams_rm.deallocate(outputs[i], resource); } int main(int argc, char *argv[]) { using namespace ams; auto &rm = umpire::ResourceManager::getInstance(); + auto& ams_rm = ams::ResourceManager::getInstance(); int use_device = std::atoi(argv[1]); char *model_path = argv[2]; char *data_type = argv[3]; @@ -55,7 +57,7 @@ int main(int argc, char *argv[]) resource = AMSResourceType::DEVICE; } - ams::ResourceManager::init(); + ams_rm.init(); if (std::strcmp("double", data_type) == 0) { std::shared_ptr> model =