Skip to content

Commit

Permalink
Merge branch 'NVIDIA:main' into admin_big_file
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv authored Sep 12, 2023
2 parents 5dd6f9e + b862abc commit 434f116
Show file tree
Hide file tree
Showing 12 changed files with 461 additions and 119 deletions.
185 changes: 102 additions & 83 deletions nvflare/fuel/f3/cellnet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,34 @@

import logging
import threading
import uuid
from typing import Dict, List, Union

from nvflare.apis.fl_constant import ServerCommandNames
from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, MessageType
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, MessageType, ReturnCode
from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, make_reply
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stream_cell import StreamCell
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
from nvflare.fuel.f3.streaming.stream_types import StreamFuture
from nvflare.private.defs import CellChannel


class SimpleWaiter:
    """Bookkeeping record for one in-flight streamed request.

    An instance is stored under its request id while the requester waits
    for the reply stream to start arriving.

    Args:
        req_id: identifier of the request this waiter belongs to.
        result: value to hand back to the requester; callers seed it with a
            default that is later overwritten once a reply is received.
    """

    def __init__(self, req_id, result):
        super().__init__()
        self.req_id = req_id
        self.result = result
        # Future of the incoming reply stream; stays None until a reply
        # starts arriving. (The misspelled duplicate attribute
        # "receiving_futre" was never read anywhere and has been dropped.)
        self.receiving_future = None
        # Signalled after receiving_future has been populated, so a waiting
        # requester can safely pick the future up.
        self.in_receiving = threading.Event()


class Adapter:
def __init__(self, cb, my_info, nice_cell):
def __init__(self, cb, my_info, cell):
self.cb = cb
self.my_info = my_info
self.nice_cell = nice_cell
self.cell = cell
self.logger = logging.getLogger(self.__class__.__name__)

def call(self, future): # this will be called by StreamCell upon receiving the first byte of blob
Expand All @@ -46,11 +50,15 @@ def call(self, future): # this will be called by StreamCell upon receiving the
origin = headers.get(MessageHeaderKey.ORIGIN, None)
result = future.result()
request = Message(headers, result)

decode_payload(request, StreamHeaderKey.PAYLOAD_ENCODING)

channel = request.get_header(StreamHeaderKey.CHANNEL)
request.set_header(MessageHeaderKey.CHANNEL, channel)
topic = request.get_header(StreamHeaderKey.TOPIC)
request.set_header(MessageHeaderKey.TOPIC, topic)
req_id = request.get_header(MessageHeaderKey.REQ_ID, "")
secure = request.get_header(MessageHeaderKey.SECURE, False)
response = self.cb(request)
response.add_headers(
{
Expand All @@ -59,7 +67,9 @@ def call(self, future): # this will be called by StreamCell upon receiving the
StreamHeaderKey.STREAM_REQ_ID: stream_req_id,
}
)
messagesend_future = self.nice_cell.send_blob(channel, topic, origin, response)

encode_payload(response, StreamHeaderKey.PAYLOAD_ENCODING)
self.cell.send_blob(channel, topic, origin, response, secure)


class Cell(StreamCell):
Expand All @@ -79,7 +89,7 @@ def method(*args, **kwargs):
return method

def fire_and_forget(
self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, optional=False
self, channel: str, topic: str, targets: Union[str, List[str]], message: Message, secure=False, optional=False
) -> Dict[str, str]:
"""
Send a message over a channel to specified destination cell(s), and do not wait for replies.
Expand All @@ -89,25 +99,31 @@ def fire_and_forget(
topic: topic of the message
targets: one or more destination cell IDs. None means all.
message: message to be sent
secure: End-end encryption if True
optional: whether the message is optional
Returns: None
"""
# if channel == CellChannel.SERVER_COMMAND and topic == ServerCommandNames.HANDLE_DEAD_JOB:
# if isinstance(targets, list):
# for target in targets:
# self.send_blob(channel=channel, topic=topic, target=target, message=message)
# else:
# self.send_blob(channel=channel, topic=topic, target=targets, message=message)
# else:
# self.core_cell.fire_and_forget(
# channel=channel, topic=topic, targets=targets, message=message, optional=optional
# )

self.core_cell.fire_and_forget(
channel=channel, topic=topic, targets=targets, message=message, optional=optional
)

if channel == CellChannel.SERVER_COMMAND and topic == ServerCommandNames.HANDLE_DEAD_JOB:

encode_payload(message, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING)

result = {}
if isinstance(targets, list):
for target in targets:
self.send_blob(channel=channel, topic=topic, target=target, message=message, secure=secure)
result[target] = ""
else:
self.send_blob(channel=channel, topic=topic, target=targets, message=message, secure=secure)
result[targets] = ""

return result
else:
return self.core_cell.fire_and_forget(
channel=channel, topic=topic, targets=targets, message=message, optional=optional
)

def _get_result(self, req_id):
waiter = self.requests_dict.pop(req_id)
Expand All @@ -124,63 +140,68 @@ def _future_wait(self, future, timeout):
last_progress = current_progress
return True

def send_request(self, channel, target, topic, request, timeout=10.0, secure=False, optional=False):
    """Send a request to ``target`` and wait for its reply.

    Requests on any channel other than ``CellChannel.SERVER_COMMAND`` are
    delegated unchanged to ``self.core_cell.send_request``. Requests on the
    server-command channel are sent as a blob stream and the reply is
    awaited in three stages (sending, waiting for the first reply byte,
    receiving), each bounded by ``timeout``.

    Args:
        channel: channel of the request.
        target: destination cell.
        topic: topic of the request.
        request: request message; on the streaming path its payload is
            encoded in place and a stream request id header is added.
        timeout: per-stage wait limit passed to each of the three waits.
        secure: forwarded to the underlying send call (end-to-end
            encryption flag).
        optional: forwarded on the non-streaming path only; the streaming
            path does not use it.

    Returns:
        The value of ``core_cell.send_request`` on the non-streaming path;
        otherwise the value of ``self._get_result(req_id)`` -- presumably
        the waiter's result, i.e. the TIMEOUT reply it was seeded with or
        the decoded reply message (confirm against ``_get_result``).
    """
    self.logger.debug(f"send_request: {channel=}, {topic=}, {target=}, {timeout=}")

    if channel != CellChannel.SERVER_COMMAND:
        return self.core_cell.send_request(
            channel=channel,
            target=target,
            topic=topic,
            request=request,
            timeout=timeout,
            secure=secure,
            optional=optional,
        )

    encode_payload(request, StreamHeaderKey.PAYLOAD_ENCODING)

    req_id = str(uuid.uuid4())
    request.add_headers({StreamHeaderKey.STREAM_REQ_ID: req_id})

    # Register the waiter BEFORE sending: _process_reply looks the request
    # id up in requests_dict and discards replies it does not know, so a
    # fast reply must not be able to arrive ahead of the registration.
    waiter = SimpleWaiter(req_id=req_id, result=make_reply(ReturnCode.TIMEOUT))
    self.requests_dict[req_id] = waiter

    try:
        # this future can be used to check sending progress, but not for checking return blob
        future = self.send_blob(channel, topic, target, request, secure)
    except Exception:
        # Nothing was sent, so no reply can come: do not leave a stale entry.
        self.requests_dict.pop(req_id, None)
        raise

    # Three stages, sending, waiting for receiving first byte, receiving

    # sending with progress timeout
    self.logger.debug(f"{req_id=}: entering sending wait {timeout=}")
    sending_complete = self._future_wait(future, timeout)
    if not sending_complete:
        self.logger.debug(f"{req_id=}: sending timeout")
        return self._get_result(req_id)
    self.logger.debug(f"{req_id=}: sending complete")

    # waiting for receiving first byte
    self.logger.debug(f"{req_id=}: entering remote process wait {timeout=}")
    if not waiter.in_receiving.wait(timeout):
        self.logger.debug(f"{req_id=}: remote processing timeout")
        return self._get_result(req_id)
    self.logger.debug(f"{req_id=}: in receiving")

    # receiving with progress timeout
    r_future = waiter.receiving_future
    self.logger.debug(f"{req_id=}: entering receiving wait {timeout=}")
    receiving_complete = self._future_wait(r_future, timeout)
    if not receiving_complete:
        self.logger.debug(f"{req_id=}: receiving timeout")
        return self._get_result(req_id)
    self.logger.debug(f"{req_id=}: receiving complete")

    # Replace the seeded TIMEOUT reply with the real, decoded reply.
    waiter.result = Message(r_future.headers, r_future.result())
    decode_payload(waiter.result, encoding_key=StreamHeaderKey.PAYLOAD_ENCODING)
    return self._get_result(req_id)

def _process_reply(self, future: StreamFuture):
headers = future.headers
req_id = headers.get(StreamHeaderKey.STREAM_REQ_ID, -1)
try:
waiter = self.requests_dict[req_id]
except KeyError as e:
self.logger.warning(f"Receiving unknown {req_id=}, discarded")
self.logger.warning(f"Receiving unknown {req_id=}, discarded: {e}")
return
waiter.receiving_future = future
waiter.in_receiving.set()
Expand All @@ -199,19 +220,17 @@ def register_request_cb(self, channel: str, topic: str, cb, *args, **kwargs):
Returns:
"""

if not callable(cb):
raise ValueError(f"specified request_cb {type(cb)} is not callable")
# if channel == CellChannel.SERVER_COMMAND and topic in [
# "*",
# ServerCommandNames.GET_TASK,
# ServerCommandNames.SUBMIT_UPDATE,
# ]:
# self.logger.debug(f"Register blob CB for {channel=}, {topic=}")
# adapter = Adapter(cb, self.core_cell.my_info, self)
# self.register_blob_cb(channel, topic, adapter.call, *args, **kwargs)
# else:
# self.logger.debug(f"Register regular CB for {channel=}, {topic=}")
# self.core_cell.register_request_cb(channel, topic, cb, *args, **kwargs)

self.logger.debug(f"Register regular CB for {channel=}, {topic=}")
self.core_cell.register_request_cb(channel, topic, cb, *args, **kwargs)
if channel == CellChannel.SERVER_COMMAND and topic in [
"*",
ServerCommandNames.GET_TASK,
ServerCommandNames.SUBMIT_UPDATE,
]:
self.logger.debug(f"Register blob CB for {channel=}, {topic=}")
adapter = Adapter(cb, self.core_cell.my_info, self)
self.register_blob_cb(channel, topic, adapter.call, *args, **kwargs)
else:
self.logger.debug(f"Register regular CB for {channel=}, {topic=}")
self.core_cell.register_request_cb(channel, topic, cb, *args, **kwargs)
11 changes: 11 additions & 0 deletions nvflare/fuel/f3/cellnet/cell_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def _sign(k, m):


def _verify(k, m, s):

if not isinstance(m, bytes):
m = bytes(m)

if not isinstance(s, bytes):
s = bytes(s)

k.verify(
s,
m,
Expand Down Expand Up @@ -210,6 +217,10 @@ def decrypt(self, message: bytes, origin_cert: Certificate):
message[NONCE_LENGTH : NONCE_LENGTH + KEY_ENC_LENGTH],
message[NONCE_LENGTH + KEY_ENC_LENGTH : SIMPLE_HEADER_LENGTH],
)

if not isinstance(key_enc, bytes):
key_enc = bytes(key_enc)

key_hash = hash(key_enc)
dec = self._cached_dec.get(key_hash)
if dec is None:
Expand Down
Loading

0 comments on commit 434f116

Please sign in to comment.