Skip to content

Commit

Permalink
Fix RMQ to send multi-model data to server and client consumes it
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed Jun 21, 2024
1 parent 2ff2160 commit 981aca7
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 71 deletions.
38 changes: 24 additions & 14 deletions src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,21 @@ def header_format(self) -> str:
- 1 byte is the size of the header (here 12). Limit max: 255
- 1 byte is the precision (4 for float, 8 for double). Limit max: 255
- 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535
- 2 bytes to store the size of the MSG domain name. Limit max: 65535
- 4 bytes are the number of elements in the message. Limit max: 2^32 - 1
- 2 bytes are the input dimension. Limit max: 65535
- 2 bytes are the output dimension. Limit max: 65535
- 4 bytes are for aligning memory to 8
- 2 bytes are for aligning memory to 8
|__Header_size__|__Datatype__|__Rank__|__#elem__|__InDim__|__OutDim__|...real data...|
|_Header_|_Datatype_|___Rank___|__DomainSize__|__#elems__|___InDim____|___OutDim___|_Pad_|.real data.|
Then the data starts at 12 and is structered as pairs of input/outputs.
Then the data starts at 16 and is structered as pairs of input/outputs.
Let K be the total number of elements, then we have K pairs of inputs/outputs (either float or double):
|__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__|
|__Header_(16B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__|
"""
return "BBHIHHI"
return "BBHHIHHH"

def endianness(self) -> str:
"""
Expand Down Expand Up @@ -85,6 +86,7 @@ def _parse_header(self, body: str) -> dict:
res["hsize"],
res["datatype"],
res["mpirank"],
res["domain_size"],
res["num_element"],
res["input_dim"],
res["output_dim"],
Expand All @@ -103,17 +105,24 @@ def _parse_header(self, body: str) -> dict:
res["multiple_msg"] = len(body) != res["msg_size"]
return res

def _parse_data(self, body: str, header_info: dict) -> np.array:
def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.array]:
data = np.array([])
if len(body) == 0:
return data
hsize = header_info["hsize"]
dsize = header_info["dsize"]
domain_name_size = header_info["domain_size"]
domain_name = body[hsize : hsize + domain_name_size]
domain_name = domain_name.decode("utf-8")
try:
if header_info["datatype"] == 4: # if datatype takes 4 bytes (float)
data = np.frombuffer(body[hsize : hsize + dsize], dtype=np.float32)
data = np.frombuffer(
body[hsize + domain_name_size : hsize + domain_name_size + dsize], dtype=np.float32
)
else:
data = np.frombuffer(body[hsize : hsize + dsize], dtype=np.float64)
data = np.frombuffer(
body[hsize + domain_name_size : hsize + domain_name_size + dsize], dtype=np.float64
)
except ValueError as e:
print(f"Error: {e} => {header_info}")
return np.array([])
Expand All @@ -122,25 +131,26 @@ def _parse_data(self, body: str, header_info: dict) -> np.array:
odim = header_info["output_dim"]
data = data.reshape((-1, idim + odim))
# Return input, output
return data[:, :idim], data[:, idim:]
return (domain_name, data[:, :idim], data[:, idim:])

def _decode(self, body: str) -> Tuple[np.array]:
input = []
output = []
# Multiple AMS messages could be packed in one RMQ message
while body:
header_info = self._parse_header(body)
temp_input, temp_output = self._parse_data(body, header_info)
print(f"input shape {temp_input.shape} outpute shape {temp_output.shape}")
print("Received domain name ", header_info["domain_size"])
domain_name, temp_input, temp_output = self._parse_data(body, header_info)
print(f"MSG: {domain_name} input shape {temp_input.shape} outpute shape {temp_output.shape}")
# total size of byte we read for that message
chunk_size = header_info["hsize"] + header_info["dsize"]
chunk_size = header_info["hsize"] + header_info["dsize"] + header_info["domain_size"]
input.append(temp_input)
output.append(temp_output)
# We remove the current message and keep going
body = body[chunk_size:]
return np.concatenate(input), np.concatenate(output)
return domain_name, np.concatenate(input), np.concatenate(output)

def decode(self) -> Tuple[np.array]:
def decode(self) -> Tuple[str, np.array, np.array]:
return self._decode(self.body)


Expand Down
77 changes: 45 additions & 32 deletions src/AMSWorkflow/ams/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from queue import Queue as ser_queue
from threading import Thread
from typing import Callable
import warnings

import numpy as np
from ams.config import AMSInstance
Expand Down Expand Up @@ -45,7 +46,8 @@ class DataBlob:
outputs: A ndarray of the outputs.
"""

def __init__(self, inputs, outputs):
def __init__(self, inputs, outputs, domain_name=None):
self._domain_name = domain_name
self._inputs = inputs
self._outputs = outputs

Expand All @@ -57,6 +59,10 @@ def inputs(self):
def outputs(self):
return self._outputs

@property
def domain_name(self):
return self._domain_name


class QueueMessage:
"""
Expand Down Expand Up @@ -277,7 +283,7 @@ def callback_message(self, ch, basic_deliver, properties, body):
the connection (or if a problem happened with the connection).
"""
start_time = time.time()
input_data, output_data = AMSMessage(body).decode()
domain_name, input_data, output_data = AMSMessage(body).decode()
row_size = input_data[0, :].nbytes + output_data[0, :].nbytes
rows_per_batch = int(np.ceil(BATCH_SIZE / row_size))
num_batches = int(np.ceil(input_data.shape[0] / rows_per_batch))
Expand All @@ -287,7 +293,7 @@ def callback_message(self, ch, basic_deliver, properties, body):
self.datasize += input_data.nbytes + output_data.nbytes

for j, (i, o) in enumerate(zip(input_batches, output_batches)):
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o)))
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o, domain_name)))

self.total_time += time.time() - start_time

Expand Down Expand Up @@ -346,35 +352,40 @@ def __call__(self):
"""

start = time.time()
while True:
fn = get_unique_fn()
fn = f"{self.out_dir}/{fn}.{self.suffix}"
is_terminate = False
total_bytes_written = 0
with self.data_writer_cls(fn) as fd:
bytes_written = 0
with AMSMonitor(obj=self, tag="internal_loop", accumulate=False):
while True:
# This is a blocking call
item = self.i_queue.get(block=True)
if item.is_terminate():
is_terminate = True
elif item.is_process():
data = item.data()
bytes_written += data.inputs.size * data.inputs.itemsize
bytes_written += data.outputs.size * data.outputs.itemsize
fd.store(data.inputs, data.outputs)
total_bytes_written += data.inputs.size * data.inputs.itemsize
total_bytes_written += data.outputs.size * data.outputs.itemsize
# FIXME: We currently decide to chunk files to 2GB
# of contents. Is this a good size?
if is_terminate or bytes_written >= 2 * 1024 * 1024 * 1024:
break

self.o_queue.put(QueueMessage(MessageType.Process, fn))
if is_terminate:
self.o_queue.put(QueueMessage(MessageType.Terminate, None))
break
total_bytes_written = 0
data_files = dict()
# with self.data_writer_cls(fn) as fd:
with AMSMonitor(obj=self, tag="internal_loop", accumulate=False):
while True:
# This is a blocking call
item = self.i_queue.get(block=True)
if item.is_terminate():
for k, v in data_files.items():
v[0].close()
self.o_queue.put(QueueMessage(MessageType.Process, v[0].file_name))
del data_files
self.o_queue.put(QueueMessage(MessageType.Terminate, None))
break
elif item.is_process():
data = item.data()
if data.domain_name not in data_files:
fn = get_unique_fn()
fn = f"{self.out_dir}/{data.domain_name}_{fn}.{self.suffix}"
# TODO: bytes_written should be an attribute of the file
# to keep track of the size of the current file. Currently we keep track of this
# by keeping a value in a list
data_files[data.domain_name] = [self.data_writer_cls(fn).open(), 0]
bytes_written = data.inputs.size * data.inputs.itemsize
bytes_written += data.outputs.size * data.outputs.itemsize
data_files[data.domain_name][0].store(data.inputs, data.outputs)
data_files[data.domain_name][1] += bytes_written
total_bytes_written += data.inputs.size * data.inputs.itemsize
total_bytes_written += data.outputs.size * data.outputs.itemsize

if data_files[data.domain_name][1] >= 2 * 1024 * 1024 * 1024:
data_files[data.domain_name][0].close()
self.o_queue.put(QueueMessage(MessageType.Process, data_files[data.domain_name][0].file_name))
del data_files[data.domain_name]

end = time.time()
self.datasize = total_bytes_written
Expand Down Expand Up @@ -432,6 +443,8 @@ def __call__(self):
dest_file = self.dir / src_fn.name
if src_fn != dest_file:
shutil.move(src_fn, dest_file)
# TODO: Fix me candidates now will be "indexed by the name"
warnings.warn("AMS Kosh manager does not operate with multi-models")
if self._store:
db_store.add_candidates([str(dest_file)])

Expand Down
92 changes: 74 additions & 18 deletions src/AMSlib/AMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ class AMSWrap
std::vector<AMSAbstractModel> registered_models;
std::unordered_map<std::string, int> ams_candidate_models;
AMSDBType dbType = AMSDBType::AMS_NONE;
ams::ResourceManager &memManager;

private:
void dumpEnv()
Expand Down Expand Up @@ -330,6 +331,64 @@ class AMSWrap
}
}

void setupFSDB(json &entry, std::string &dbStrType)
{
if (!entry.contains("fs_path"))
THROW(std::runtime_error,
"JSON db-fields does not provide file system path");

std::string db_path = entry["fs_path"].get<std::string>();
auto &DB = ams::db::DBManager::getInstance();
DB.instantiate_fs_db(dbType, db_path);
DBG(AMS,
"Configured AMS File system database to point to %s using file "
"type %s",
db_path.c_str(),
dbStrType.c_str());
}

template <typename T>
T getEntry(json &entry, std::string field)
{
if (!entry.contains(field)) {
THROW(std::runtime_error,
("I was expecting entry '" + field + "' to exist in json").c_str())
}
return entry[field].get<T>();
}

void setupRMQ(json &entry, std::string &dbStrType)
{
if (!entry.contains("rmq_config")) {
THROW(std::runtime_error,
"JSON db-fields do not contain rmq_config entires")
}
auto rmq_entry = entry["rmq_config"];
int port = getEntry<int>(rmq_entry, "service-port");
std::string host = getEntry<std::string>(rmq_entry, "service-host");
std::string rmq_name = getEntry<std::string>(rmq_entry, "rabbitmq-name");
std::string rmq_pass =
getEntry<std::string>(rmq_entry, "rabbitmq-password");
std::string rmq_user = getEntry<std::string>(rmq_entry, "rabbitmq-user");
std::string rmq_vhost = getEntry<std::string>(rmq_entry, "rabbitmq-vhost");
std::string rmq_cert = getEntry<std::string>(rmq_entry, "rabbitmq-cert");
std::string rmq_in_queue =
getEntry<std::string>(rmq_entry, "rabbitmq-inbound-queue");
std::string rmq_out_queue =
getEntry<std::string>(rmq_entry, "rabbitmq-outbound-queue");

auto &DB = ams::db::DBManager::getInstance();
DB.instantiate_rmq_db(port,
host,
rmq_name,
rmq_pass,
rmq_user,
rmq_vhost,
rmq_cert,
rmq_in_queue,
rmq_out_queue);
}

void parseDatabase(json &jRoot)
{
DBG(AMS, "Parsing Data Base Fields")
Expand All @@ -341,24 +400,21 @@ class AMSWrap
"\"dbType\" "
"entry");
auto dbStrType = entry["dbType"].get<std::string>();
DBG(AMS, "DB Type is: %s", dbStrType.c_str())
AMSDBType dbType = ams::db::getDBType(dbStrType);
if (dbType == AMSDBType::AMS_NONE) return;

if (dbType == AMSDBType::AMS_CSV || dbType == AMSDBType::AMS_HDF5) {
if (!entry.contains("fs_path"))
THROW(std::runtime_error,
"JSON db-fiels does not provide file system path");

std::string db_path = entry["fs_path"].get<std::string>();
auto &DB = ams::db::DBManager::getInstance();
DB.instantiate_fs_db(dbType, db_path);
DBG(AMS,
"Configured AMS File system database to point to %s using file "
"type %s",
db_path.c_str(),
dbStrType.c_str());
dbType = ams::db::getDBType(dbStrType);
switch (dbType) {
case AMSDBType::AMS_NONE:
return;
case AMSDBType::AMS_CSV:
case AMSDBType::AMS_HDF5:
setupFSDB(entry, dbStrType);
break;
case AMSDBType::AMS_RMQ:
setupRMQ(entry, dbStrType);
break;
case AMSDBType::AMS_REDIS:
FATAL(AMS, "Cannot connect to REDIS database, missing implementation");
}
return;
}

std::pair<bool, std::string> setup_loggers()
Expand Down Expand Up @@ -427,7 +483,7 @@ class AMSWrap
}

public:
AMSWrap()
AMSWrap() : memManager(ams::ResourceManager::getInstance())
{
auto log_stats = setup_loggers();
DBG(AMS,
Expand Down
Loading

0 comments on commit 981aca7

Please sign in to comment.