Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ReliableMessage from 2.4 #2717

Merged
merged 2 commits into from
Jul 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions nvflare/apis/utils/reliable_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
self.tx_timeout = request.get_header(HEADER_TX_TIMEOUT)

# start processing
ReliableMessage.info(fl_ctx, f"started processing request of topic {self.topic}")
ReliableMessage.debug(fl_ctx, f"started processing request of topic {self.topic}")
self.executor.submit(self._do_request, request, fl_ctx)
return _status_reply(STATUS_IN_PROCESS) # ack
elif self.result:
Expand Down Expand Up @@ -143,14 +143,14 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
ReliableMessage.error(fl_ctx, f"aborting processing since exceeded max tx time {self.tx_timeout}")
return _status_reply(STATUS_ABORTED)
else:
ReliableMessage.info(fl_ctx, "got query: request is in-process")
ReliableMessage.debug(fl_ctx, "got query: request is in-process")
return _status_reply(STATUS_IN_PROCESS)

def _try_reply(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
self.replying = True
start_time = time.time()
ReliableMessage.info(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}")
ReliableMessage.debug(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}")
ack = engine.send_aux_request(
targets=[self.source],
topic=TOPIC_RELIABLE_REPLY,
Expand All @@ -175,7 +175,7 @@ def _try_reply(self, fl_ctx: FLContext):

def _do_request(self, request: Shareable, fl_ctx: FLContext):
start_time = time.time()
ReliableMessage.info(fl_ctx, "invoking request handler")
ReliableMessage.debug(fl_ctx, "invoking request handler")
try:
result = self.request_handler_f(self.topic, request, fl_ctx)
except Exception as e:
Expand All @@ -187,7 +187,7 @@ def _do_request(self, request: Shareable, fl_ctx: FLContext):
result.set_header(HEADER_OP, OP_REPLY)
result.set_header(HEADER_TOPIC, self.topic)
self.result = result
ReliableMessage.info(fl_ctx, f"finished request handler in {time.time()-start_time} secs")
ReliableMessage.debug(fl_ctx, f"finished request handler in {time.time()-start_time} secs")
self._try_reply(fl_ctx)


Expand Down Expand Up @@ -277,12 +277,14 @@ def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext):
cls.error(fl_ctx, f"no handler registered for request {rm_topic=}")
return make_reply(ReturnCode.TOPIC_UNKNOWN)
receiver = cls._get_or_create_receiver(rm_topic, request, handler_f)
cls.info(fl_ctx, f"received request {rm_topic=}")
cls.debug(fl_ctx, f"received request {rm_topic=}")
return receiver.process(request, fl_ctx)
elif op == OP_QUERY:
receiver = cls._req_receivers.get(tx_id)
if not receiver:
cls.error(fl_ctx, f"received query but the request ({rm_topic=}) is not received or already done!")
cls.warning(
fl_ctx, f"received query but the request ({rm_topic=} {tx_id=}) is not received or already done!"
)
return _status_reply(STATUS_NOT_RECEIVED) # meaning the request wasn't received
else:
return receiver.process(request, fl_ctx)
Expand All @@ -300,7 +302,7 @@ def _receive_reply(cls, topic: str, request: Shareable, fl_ctx: FLContext):
cls.error(fl_ctx, "received reply but we are no longer waiting for it")
else:
assert isinstance(receiver, _ReplyReceiver)
cls.info(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter")
cls.debug(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter")
receiver.process(request)
return make_reply(ReturnCode.OK)

Expand Down Expand Up @@ -415,6 +417,10 @@ def _log_msg(cls, fl_ctx: FLContext, msg: str):
def info(cls, fl_ctx: FLContext, msg: str):
cls._logger.info(cls._log_msg(fl_ctx, msg))

@classmethod
def warning(cls, fl_ctx: FLContext, msg: str):
cls._logger.warning(cls._log_msg(fl_ctx, msg))

@classmethod
def error(cls, fl_ctx: FLContext, msg: str):
cls._logger.error(cls._log_msg(fl_ctx, msg))
Expand Down Expand Up @@ -511,7 +517,7 @@ def _send_request(
return make_reply(ReturnCode.COMMUNICATION_ERROR)

if num_tries > 0:
cls.info(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}")
cls.debug(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}")

ack = engine.send_aux_request(
targets=[target],
Expand All @@ -528,23 +534,23 @@ def _send_request(
# the reply is already the result - we are done!
# this could happen when we didn't get positive ack for our first request, and the result was
# already produced when we did the 2nd request (this request).
cls.info(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}")
cls.debug(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}")
return ack

# the ack is a status report - check status
status = ack.get_header(HEADER_STATUS)
if status and status != STATUS_NOT_RECEIVED:
# status should never be STATUS_NOT_RECEIVED, unless there is a bug in the receiving logic
# STATUS_NOT_RECEIVED is only possible during "query" phase.
cls.info(fl_ctx, f"received status ack: {rc=} {status=}")
cls.debug(fl_ctx, f"received status ack: {rc=} {status=}")
break

if time.time() + cls._query_interval - receiver.tx_start_time >= tx_timeout:
cls.error(fl_ctx, f"aborting send_request since it will exceed {tx_timeout=}")
return make_reply(ReturnCode.COMMUNICATION_ERROR)

# we didn't get a positive ack - wait a short time and re-send the request.
cls.info(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs")
cls.debug(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs")
num_tries += 1
start = time.time()
while time.time() - start < cls._query_interval:
Expand All @@ -553,7 +559,7 @@ def _send_request(
return make_reply(ReturnCode.TASK_ABORTED)
time.sleep(0.1)

cls.info(fl_ctx, "request was received by the peer - will query for result")
cls.debug(fl_ctx, "request was received by the peer - will query for result")
return cls._query_result(target, abort_signal, fl_ctx, receiver)

@classmethod
Expand Down Expand Up @@ -585,7 +591,7 @@ def _query_result(
# we already received result sent by the target.
# Note that we don't wait forever here - we only wait for _query_interval, so we could
# check other condition and/or send query to ask for result.
cls.info(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds")
cls.debug(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds")
return receiver.result

if abort_signal and abort_signal.triggered:
Expand All @@ -599,21 +605,26 @@ def _query_result(
# send a query. The ack of the query could be the result itself, or a status report.
# Note: the ack could be the result because we failed to receive the result sent by the target earlier.
num_tries += 1
cls.info(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}")
cls.debug(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}")
ack = engine.send_aux_request(
targets=[target],
topic=TOPIC_RELIABLE_REQUEST,
request=query,
timeout=per_msg_timeout,
fl_ctx=fl_ctx,
)

# Ignore query result if result is already received
if receiver.result_ready.is_set():
return receiver.result

last_query_time = time.time()
ack, rc = _extract_result(ack, target)
if ack and rc not in [ReturnCode.COMMUNICATION_ERROR]:
op = ack.get_header(HEADER_OP)
if op == OP_REPLY:
# the ack is result itself!
cls.info(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds")
cls.debug(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds")
return ack

status = ack.get_header(HEADER_STATUS)
Expand All @@ -625,6 +636,6 @@ def _query_result(
cls.error(fl_ctx, f"peer {target} aborted processing!")
return _error_reply(ReturnCode.EXECUTION_EXCEPTION, "Aborted")

cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}")
cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}")
else:
cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}")
cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}")
Loading