Skip to content

Commit

Permalink
Added JSON option db:update_surrogate (boolean) to control whether we…
Browse files Browse the repository at this point in the history
… update surrogate

Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier committed Aug 8, 2024
1 parent 1a30e15 commit 384f9f3
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 25 deletions.
4 changes: 3 additions & 1 deletion src/AMSlib/AMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ class AMSWrap
getEntry<std::string>(rmq_entry, "rabbitmq-exchange");
std::string routing_key =
getEntry<std::string>(rmq_entry, "rabbitmq-routing-key");
bool update_surrogate = getEntry<bool>(entry, "update_surrogate");

auto &DB = ams::db::DBManager::getInstance();
DB.instantiate_rmq_db(port,
Expand All @@ -391,7 +392,8 @@ class AMSWrap
rmq_cert,
rmq_out_queue,
exchange,
routing_key);
routing_key,
update_surrogate);
}

void parseDatabase(json &jRoot)
Expand Down
50 changes: 36 additions & 14 deletions src/AMSlib/wf/basedb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,16 @@ class BaseDB
{
/** @brief unique id of the process running this simulation */
uint64_t id;
/** @brief True if surrogate model update is allowed */
bool allowUpdate;

public:
BaseDB(const BaseDB&) = delete;
BaseDB& operator=(const BaseDB&) = delete;

BaseDB(uint64_t id) : id(id) {}
BaseDB(uint64_t id) : id(id), allowUpdate(false) {}

BaseDB(uint64_t id, bool allowUpdate) : id(id), allowUpdate(allowUpdate) {}

/**
 * @brief Close the database. The default implementation is a no-op;
 * it is virtual so that concrete backends can provide their own.
 */
virtual void close() {}

Expand Down Expand Up @@ -128,6 +132,8 @@ class BaseDB

/**
 * @brief Return the id this database was constructed with.
 * @return The unique id of the process running this simulation
 */
uint64_t getId() const { return id; }

bool allowModelUpdate() { return allowUpdate; }

/**
 * @brief Check whether a surrogate model update is available.
 * The default implementation always returns false; it is virtual so
 * that concrete backends can provide their own check.
 * @return False
 */
virtual bool updateModel() { return false; }

/**
 * @brief Return the latest surrogate model.
 * The default implementation returns an empty string; it is virtual so
 * that concrete backends can provide their own.
 * @return An empty string
 */
virtual std::string getLatestModel() { return {}; }
Expand Down Expand Up @@ -1522,8 +1528,9 @@ class RMQPublisher
* "service-port": 31495,
* "service-host": "url.czapps.llnl.gov",
* "rabbitmq-cert": "tls-cert.crt",
* "rabbitmq-inbound-queue": "test4",
* "rabbitmq-outbound-queue": "test3"
* "rabbitmq-outbound-queue": "test3",
* "rabbitmq-exchange": "ams-fanout",
* "rabbitmq-routing-key": "training"
* }
*
* The TLS certificate must be generated by the user; absolute paths are preferred.
Expand Down Expand Up @@ -1717,18 +1724,20 @@ class RMQInterface
class RabbitMQDB final : public BaseDB
{
private:
/** @brief the application domain that stores the data*/
/** @brief the application domain that stores the data */
std::string appDomain;

/** An interface to RMQ to push the data to*/
/** @brief An interface to RMQ to push the data to */
RMQInterface& interface;

public:
RabbitMQDB(const RabbitMQDB&) = delete;
RabbitMQDB& operator=(const RabbitMQDB&) = delete;

RabbitMQDB(RMQInterface& interface, std::string& domain, uint64_t id)
: BaseDB(id), appDomain(domain), interface(interface)
RabbitMQDB(RMQInterface& interface,
std::string& domain,
uint64_t id,
bool allowModelUpdate)
: BaseDB(id, allowModelUpdate), appDomain(domain), interface(interface)
{
/* We set manually the MPI rank here because when
* RMQInterface was statically initialized, MPI was not
Expand Down Expand Up @@ -1782,8 +1791,17 @@ class RabbitMQDB final : public BaseDB
*/
std::string type() override { return "rabbitmq"; }

/**
* @brief Check if the surrogate model can be updated (i.e., if
* RMQConsumer received a training message)
* @return True if the model can be updated
*/
bool updateModel() { return interface.updateModel(); }

/**
* @brief Return the path of the latest surrogate model if available
* @return The path of the latest available surrogate model
*/
std::string getLatestModel() { return interface.getLatestModel(); }

/**
Expand Down Expand Up @@ -1866,8 +1884,10 @@ class DBManager
std::unordered_map<std::string, std::shared_ptr<BaseDB>> db_instances;
AMSDBType dbType;
uint64_t rId;
/** @brief If True, the DB is allowed to update the surrogate model */
bool updateSurrogate;

DBManager() : dbType(AMSDBType::AMS_NONE){};
/**
 * @brief Build a manager with no database type selected, a zeroed rank id
 * and surrogate model updates disabled.
 */
DBManager() : dbType(AMSDBType::AMS_NONE), rId(0), updateSurrogate(false) {}

protected:
RMQInterface rmq_interface;
Expand All @@ -1880,7 +1900,6 @@ class DBManager
return instance;
}

public:
~DBManager()
{
for (auto& e : db_instances) {
Expand Down Expand Up @@ -1949,7 +1968,10 @@ class DBManager
#endif
#ifdef __ENABLE_RMQ__
case AMSDBType::AMS_RMQ:
return std::make_shared<RabbitMQDB>(rmq_interface, domainName, rId);
return std::make_shared<RabbitMQDB>(rmq_interface,
domainName,
rId,
updateSurrogate);
#endif
default:
return nullptr;
Expand All @@ -1958,7 +1980,6 @@ class DBManager
return nullptr;
}


/**
* @brief get a data base object referred by this string.
* This should never be used for large scale simulations as txt/csv format will
Expand Down Expand Up @@ -2042,7 +2063,8 @@ class DBManager
std::string& rmq_cert,
std::string& outbound_queue,
std::string& exchange,
std::string& routing_key)
std::string& routing_key,
bool update_surrogate)
{
fs::path Path(rmq_cert);
std::error_code ec;
Expand All @@ -2051,7 +2073,7 @@ class DBManager
"Certificate file '%s' for RMQ server does not exist",
rmq_cert.c_str());
dbType = AMSDBType::AMS_RMQ;

updateSurrogate = update_surrogate;
#ifdef __ENABLE_RMQ__
rmq_interface.connect(rmq_name,
rmq_pass,
Expand Down
9 changes: 3 additions & 6 deletions src/AMSlib/wf/workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ class AMSWorkflow
MPI_Comm comm;
#endif


/** @brief Is the evaluate a distributed execution **/
bool isDistributed;

Expand Down Expand Up @@ -168,12 +167,11 @@ class AMSWorkflow
/** \brief Check if we can perform a surrogate model update.
* AMS can update surrogate model only when all MPI ranks have received
* the latest model from RabbitMQ.
* @param[in] performUpdate Perform the model update if True
* @return True
* @return True if surrogate model can be updated
*/
bool updateModel(bool performUpdate = false)
bool updateModel()
{
if (!DB || !performUpdate) return false;
if (!DB || !DB->allowModelUpdate()) return false;
bool local = DB->updateModel();
#ifdef __ENABLE_MPI__
bool global = false;
Expand Down Expand Up @@ -217,7 +215,6 @@ class AMSWorkflow
comm(MPI_COMM_NULL),
#endif
ePolicy(AMSExecPolicy::AMS_UBALANCED)

{
DB = nullptr;
auto &dbm = ams::db::DBManager::getInstance();
Expand Down
10 changes: 6 additions & 4 deletions tests/AMSlib/json_configs/rmq.json.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"db" : {
"dbType" : "rmq",
"rmq_config" : {
"rmq_config" : {
"service-port": ,
"service-host": "",
"rabbitmq-erlang-cookie": "",
Expand All @@ -10,9 +10,11 @@
"rabbitmq-user": "",
"rabbitmq-vhost": "",
"rabbitmq-cert": "",
"rabbitmq-inbound-queue": "",
"rabbitmq-outbound-queue": ""
}
"rabbitmq-outbound-queue": "",
"rabbitmq-exchange": "",
"rabbitmq-routing-key": ""
},
"update_surrogate": false
},
"ml_models" : {
"random_50": {
Expand Down

0 comments on commit 384f9f3

Please sign in to comment.