diff --git a/nvflare/apis/aux_spec.py b/nvflare/apis/aux_spec.py index a238668bab..e48ba2374c 100644 --- a/nvflare/apis/aux_spec.py +++ b/nvflare/apis/aux_spec.py @@ -13,6 +13,8 @@ # limitations under the License. import enum +from abc import ABC, abstractmethod +from typing import Dict from .fl_context import FLContext from .shareable import Shareable @@ -44,3 +46,85 @@ def aux_request_handle_func_signature(topic: str, request: Shareable, fl_ctx: FL """ pass + + +class AuxMessenger(ABC): + @abstractmethod + def register_aux_message_handler(self, topic: str, message_handle_func): + """Register aux message handling function with specified topics. + + Exception is raised when: + a handler is already registered for the topic; + bad topic - must be a non-empty string + bad message_handle_func - must be callable + + Implementation Note: + This method should simply call the ServerAuxRunner's register_aux_message_handler method. + + Args: + topic: the topic to be handled by the func + message_handle_func: the func to handle the message. Must follow aux_message_handle_func_signature. + + """ + pass + + @abstractmethod + def send_aux_request( + self, + targets: [], + topic: str, + request: Shareable, + timeout: float, + fl_ctx: FLContext, + optional=False, + secure=False, + ) -> dict: + """Send a request to specified clients via the aux channel. + + Implementation: simply calls the AuxRunner's send_aux_request method. + + Args: + targets: target clients. None or empty list means all clients. + topic: topic of the request. + request: request to be sent + timeout: number of secs to wait for replies. 0 means fire-and-forget. + fl_ctx: FL context + optional: whether this message is optional + secure: send the aux request in a secure way + + Returns: a dict of replies (client name => reply Shareable) + + """ + pass + + @abstractmethod + def multicast_aux_requests( + self, + topic: str, + target_requests: Dict[str, Shareable], + timeout: float, + fl_ctx: FLContext, + optional: bool = False, + secure: bool = False, + ) -> dict: + """Send requests to specified clients via the aux channel. + + Implementation: simply calls the AuxRunner's multicast_aux_requests method. + + Args: + topic: topic of the request + target_requests: requests of the target clients. Different target can have different request. + timeout: amount of time to wait for responses. 0 means fire and forget. + fl_ctx: FL context + optional: whether this request is optional + secure: whether to send the aux request in P2P secure + + Returns: a dict of replies (client name => reply Shareable) + + """ + pass + + def fire_and_forget_aux_request( + self, targets: [], topic: str, request: Shareable, fl_ctx: FLContext, optional=False, secure=False + ) -> dict: + return self.send_aux_request(targets, topic, request, 0.0, fl_ctx, optional, secure=secure) diff --git a/nvflare/apis/rm.py b/nvflare/apis/rm.py index 9a93215cdd..b7dfa7a337 100644 --- a/nvflare/apis/rm.py +++ b/nvflare/apis/rm.py @@ -30,8 +30,8 @@ def send_reliable_request( per_msg_timeout: float, tx_timeout: float, fl_ctx: FLContext, - secure=False, optional=False, + secure=False, ) -> Shareable: """Send a reliable request. diff --git a/nvflare/apis/server_engine_spec.py b/nvflare/apis/server_engine_spec.py index 4c4013f0c1..af5eb9d9c3 100644 --- a/nvflare/apis/server_engine_spec.py +++ b/nvflare/apis/server_engine_spec.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple -from nvflare.apis.shareable import Shareable +from nvflare.apis.aux_spec import AuxMessenger from nvflare.widgets.widget import Widget from .client import Client @@ -26,7 +26,7 @@ from .workspace import Workspace -class ServerEngineSpec(EngineSpec, ABC): +class ServerEngineSpec(EngineSpec, AuxMessenger, ABC): @abstractmethod def fire_event(self, event_type: str, fl_ctx: FLContext): pass @@ -84,86 +84,6 @@ def get_component(self, component_id: str) -> object: """ pass - @abstractmethod - def register_aux_message_handler(self, topic: str, message_handle_func): - """Register aux message handling function with specified topics. - - Exception is raised when: - a handler is already registered for the topic; - bad topic - must be a non-empty string - bad message_handle_func - must be callable - - Implementation Note: - This method should simply call the ServerAuxRunner's register_aux_message_handler method. - - Args: - topic: the topic to be handled by the func - message_handle_func: the func to handle the message. Must follow aux_message_handle_func_signature. - - """ - pass - - @abstractmethod - def send_aux_request( - self, - targets: [], - topic: str, - request: Shareable, - timeout: float, - fl_ctx: FLContext, - optional=False, - secure=False, - ) -> dict: - """Send a request to specified clients via the aux channel. - - Implementation: simply calls the AuxRunner's send_aux_request method. - - Args: - targets: target clients. None or empty list means all clients. - topic: topic of the request. - request: request to be sent - timeout: number of secs to wait for replies. 0 means fire-and-forget. - fl_ctx: FL context - optional: whether this message is optional - secure: send the aux request in a secure way - - Returns: a dict of replies (client name => reply Shareable) - - """ - pass - - @abstractmethod - def multicast_aux_requests( - self, - topic: str, - target_requests: Dict[str, Shareable], - timeout: float, - fl_ctx: FLContext, - optional: bool = False, - secure: bool = False, - ) -> dict: - """Send requests to specified clients via the aux channel. - - Implementation: simply calls the AuxRunner's multicast_aux_requests method. - - Args: - topic: topic of the request - target_requests: requests of the target clients. Different target can have different request. - timeout: amount of time to wait for responses. 0 means fire and forget. - fl_ctx: FL context - optional: whether this request is optional - secure: whether to send the aux request in P2P secure - - Returns: a dict of replies (client name => reply Shareable) - - """ - pass - - def fire_and_forget_aux_request( - self, targets: [], topic: str, request: Shareable, fl_ctx: FLContext, optional=False, secure=False - ) -> dict: - return self.send_aux_request(targets, topic, request, 0.0, fl_ctx, optional, secure=secure) - @abstractmethod def get_widget(self, widget_id: str) -> Widget: """Get the widget with the specified ID. diff --git a/nvflare/app_opt/job_launcher/docker_launcher.py b/nvflare/app_opt/job_launcher/docker_launcher.py index 48627db81a..ca4a03ce50 100644 --- a/nvflare/app_opt/job_launcher/docker_launcher.py +++ b/nvflare/app_opt/job_launcher/docker_launcher.py @@ -45,7 +45,6 @@ class DOCKER_STATE: class DockerJobHandle(JobHandleSpec): - def __init__(self, container, timeout=None): super().__init__() diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index baf944a956..4e248233c6 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -17,15 +17,14 @@ import shutil import sys import threading -from typing import List +from typing import Dict +from nvflare.apis.aux_spec import AuxMessenger from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey, MachineStatus, ProcessType, SystemComponents, WorkspaceConstants from nvflare.apis.fl_context import FLContext, FLContextManager -from nvflare.apis.rm import RMEngine from nvflare.apis.shareable import Shareable -from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.apis.workspace import Workspace from nvflare.fuel.f3.cellnet.cell import Cell @@ -38,8 +37,7 @@ from nvflare.private.fed.server.job_meta_validator import JobMetaValidator from nvflare.private.fed.utils.app_deployer import AppDeployer from nvflare.private.fed.utils.fed_utils import security_close -from nvflare.private.rm_runner import ReliableMessenger -from nvflare.private.stream_runner import ObjectStreamer +from nvflare.private.msg_engine import MessagingEngine from nvflare.security.logging import secure_format_exception, secure_log_traceback from .client_engine_internal_spec import ClientEngineInternalSpec @@ -56,7 +54,7 @@ def _remove_custom_path(): sys.path.remove(path) -class ClientEngine(ClientEngineInternalSpec, StreamableEngine, RMEngine): +class ClientEngine(ClientEngineInternalSpec, AuxMessenger, MessagingEngine): """ClientEngine runs in the client parent process (CP).""" def __init__(self, client: FederatedClient, args, rank, workers=5): @@ -69,6 +67,7 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): workers: number of workers """ super().__init__() + MessagingEngine.__init__(self, messenger=self) self.client = client self.client_name = client.client_name self.args = args @@ -76,8 +75,6 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): self.client_executor = JobExecutor(client, os.path.join(args.workspace, "startup")) self.admin_agent = None self.aux_runner = AuxRunner(self) - self.object_streamer = ObjectStreamer(self.aux_runner) - self.reliable_messenger = ReliableMessenger(self.aux_runner) self.cell = None self.fl_ctx_mgr = FLContextManager( @@ -168,6 +165,7 @@ def register_aux_message_handler(self, topic: str, message_handle_func): def send_aux_request( self, + targets: [], topic: str, request: Shareable, timeout: float, @@ -180,6 +178,7 @@ def send_aux_request( Implementation: simply calls the AuxRunner's send_aux_request method. Args: + targets: not used topic: topic of the request. request: request to be sent timeout: number of secs to wait for replies. 0 means fire-and-forget. @@ -207,132 +206,29 @@ def send_aux_request( else: return Shareable() - def stream_objects( + def multicast_aux_requests( self, - channel: str, topic: str, - stream_ctx: StreamContext, - targets: List[str], - producer: ObjectProducer, + target_requests: Dict[str, Shareable], + timeout: float, fl_ctx: FLContext, - optional=False, - secure=False, - ): - """Send a stream of Shareable objects to receivers. - - Args: - channel: the channel for this stream - topic: topic of the stream - stream_ctx: context of the stream - targets: receiving sites - producer: the ObjectProducer that can produces the stream of Shareable objects - fl_ctx: the FLContext object - optional: whether the stream is optional - secure: whether to use P2P security - - Returns: result from the generator's reply processing - - """ - if not self.object_streamer: - raise RuntimeError("object streamer has not been created") - - # We are CP: can only stream to SP - if targets: - for t in targets: - self.logger.debug(f"ignored target: {t}") - - return self.object_streamer.stream( - channel=channel, - topic=topic, - stream_ctx=stream_ctx, - targets=[AuxMsgTarget.server_target()], - producer=producer, - fl_ctx=fl_ctx, - secure=secure, - optional=optional, - ) - - def register_stream_processing( - self, - channel: str, - topic: str, - factory: ConsumerFactory, - stream_done_cb=None, - **cb_kwargs, - ): - """Register a ConsumerFactory for specified app channel and topic. - Once a new streaming request is received for the channel/topic, the registered factory will be used - to create an ObjectConsumer object to handle the new stream. - - Note: the factory should generate a new ObjectConsumer every time get_consumer() is called. This is because - multiple streaming sessions could be going on at the same time. Each streaming session should have its - own ObjectConsumer. + optional: bool = False, + secure: bool = False, + ) -> dict: + """No need for this since targets can only be server. Args: - channel: app channel - topic: app topic - factory: the factory to be registered - stream_done_cb: the callback to be called when streaming is done on receiving side + topic: + target_requests: + timeout: + fl_ctx: + optional: + secure: - Returns: None + Returns: """ - if not self.object_streamer: - raise RuntimeError("object streamer has not been created") - - self.object_streamer.register_stream_processing( - topic=topic, channel=channel, factory=factory, stream_done_cb=stream_done_cb, **cb_kwargs - ) - - def shutdown_streamer(self): - if self.object_streamer: - self.object_streamer.shutdown() - - def register_reliable_request_handler(self, channel: str, topic: str, handler_f, **handler_kwargs): - if not self.reliable_messenger: - raise RuntimeError("reliable messenger has not been created") - - self.reliable_messenger.register_request_handler( - channel=channel, - topic=topic, - handler_f=handler_f, - **handler_kwargs, - ) - - def send_reliable_request( - self, - target: str, - channel: str, - topic: str, - request: Shareable, - per_msg_timeout: float, - tx_timeout: float, - fl_ctx: FLContext, - secure=False, - optional=False, - ) -> Shareable: - if not self.reliable_messenger: - raise RuntimeError("reliable messenger has not been created") - - # We are CP: can only stream to SP - if target: - self.logger.debug(f"ignored target '{target}'") - - return self.reliable_messenger.send_request( - target=AuxMsgTarget.server_target(), - channel=channel, - topic=topic, - request=request, - per_msg_timeout=per_msg_timeout, - tx_timeout=tx_timeout, - fl_ctx=fl_ctx, - secure=secure, - optional=optional, - ) - - def shutdown_reliable_messenger(self): - if self.reliable_messenger: - self.reliable_messenger.shutdown() + pass def set_agent(self, admin_agent): self.admin_agent = admin_agent @@ -478,8 +374,7 @@ def shutdown(self) -> str: thread = threading.Thread(target=shutdown_client, args=(self.client, touch_file)) thread.start() - self.shutdown_streamer() - self.shutdown_reliable_messenger() + self.shutdown_messaging() return "Shutdown the client..." def restart(self) -> str: diff --git a/nvflare/private/fed/client/client_engine_executor_spec.py b/nvflare/private/fed/client/client_engine_executor_spec.py index 062955d505..e7a5a94bcc 100644 --- a/nvflare/private/fed/client/client_engine_executor_spec.py +++ b/nvflare/private/fed/client/client_engine_executor_spec.py @@ -14,8 +14,8 @@ import time from abc import ABC, abstractmethod -from typing import Dict, List, Union +from nvflare.apis.aux_spec import AuxMessenger from nvflare.apis.client_engine_spec import ClientEngineSpec from nvflare.apis.engine_spec import EngineSpec from nvflare.apis.fl_context import FLContext @@ -42,7 +42,7 @@ def __init__(self, name: str, task_id: str, data: Shareable): self.receive_time = time.time() -class ClientEngineExecutorSpec(ClientEngineSpec, EngineSpec, ABC): +class ClientEngineExecutorSpec(AuxMessenger, ClientEngineSpec, EngineSpec, ABC): """The ClientEngineExecutorSpec defines the ClientEngine APIs running in the child process.""" @abstractmethod @@ -65,101 +65,6 @@ def get_widget(self, widget_id: str) -> Widget: def get_all_components(self) -> dict: pass - @abstractmethod - def register_aux_message_handler(self, topic: str, message_handle_func): - """Register aux message handling function with specified topics. - - Exception is raised when: - a handler is already registered for the topic; - bad topic - must be a non-empty string - bad message_handle_func - must be callable - - Implementation Note: - This method should simply call the ClientAuxRunner's register_aux_message_handler method. - - Args: - topic: the topic to be handled by the func - message_handle_func: the func to handle the message. Must follow aux_message_handle_func_signature. - - """ - pass - - @abstractmethod - def send_aux_request( - self, - targets: Union[None, str, List[str]], - topic: str, - request: Shareable, - timeout: float, - fl_ctx: FLContext, - optional=False, - secure: bool = False, - ) -> dict: - """Send a request to Server via the aux channel. - - Implementation: simply calls the ClientAuxRunner's send_aux_request method. - - Args: - targets: aux messages targets. None or empty list means the server. - topic: topic of the request - request: request to be sent - timeout: number of secs to wait for replies. 0 means fire-and-forget. - fl_ctx: FL context - optional: whether the request is optional - secure: should the request sent in the secure way - - Returns: - a dict of reply Shareable in the format of: - { site_name: reply_shareable } - - """ - pass - - @abstractmethod - def multicast_aux_requests( - self, - topic: str, - target_requests: Dict[str, Shareable], - timeout: float, - fl_ctx: FLContext, - optional: bool = False, - secure: bool = False, - ) -> dict: - """Send requests to specified targets (server or other clients) via the aux channel. - - Implementation: simply calls the AuxRunner's multicast_aux_requests method. - - Args: - topic: topic of the request - target_requests: requests of the targets. Different target can have different request. - timeout: amount of time to wait for responses. 0 means fire and forget. - fl_ctx: FL context - optional: whether this request is optional - secure: whether to send the aux request in P2P secure - - Returns: a dict of replies (client name => reply Shareable) - - """ - pass - - @abstractmethod - def fire_and_forget_aux_request( - self, topic: str, request: Shareable, fl_ctx: FLContext, optional=False, secure=False - ) -> Shareable: - """Send an async request to Server via the aux channel. - - Args: - topic: topic of the request. - request: request to be sent - fl_ctx: FL context - optional: whether the request is optional - secure: whether to send the message in P2P secure mode - - Returns: - - """ - pass - @abstractmethod def build_component(self, config_dict): """Build a component from the config_dict. diff --git a/nvflare/private/fed/client/client_run_manager.py b/nvflare/private/fed/client/client_run_manager.py index 6324bcbc43..bf83f98ef9 100644 --- a/nvflare/private/fed/client/client_run_manager.py +++ b/nvflare/private/fed/client/client_run_manager.py @@ -18,9 +18,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey, ProcessType, ServerCommandKey, ServerCommandNames, SiteType from nvflare.apis.fl_context import FLContext, FLContextManager -from nvflare.apis.rm import RMEngine from nvflare.apis.shareable import Shareable -from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext from nvflare.apis.workspace import Workspace from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey @@ -30,8 +28,7 @@ from nvflare.private.defs import CellChannel, CellMessageHeaderKeys, new_cell_message from nvflare.private.event import fire_event from nvflare.private.fed.utils.fed_utils import create_job_processing_context_properties -from nvflare.private.rm_runner import ReliableMessenger -from nvflare.private.stream_runner import ObjectStreamer +from nvflare.private.msg_engine import MessagingEngine from nvflare.widgets.fed_event import ClientFedEventRunner from nvflare.widgets.info_collector import InfoCollector from nvflare.widgets.widget import Widget, WidgetID @@ -60,7 +57,7 @@ def __init__(self, job_id): GET_CLIENTS_RETRY = 300 -class ClientRunManager(ClientEngineExecutorSpec, StreamableEngine, RMEngine): +class ClientRunManager(ClientEngineExecutorSpec, MessagingEngine): """ClientRunManager provides the ClientEngine APIs implementation running in the child process (CJ).""" def __init__( @@ -81,21 +78,18 @@ def __init__( workspace: workspace client: FL client object components: available FL components - handlers: available handlers + handlers: available handlers. conf: ClientJsonConfigurator object """ super().__init__() - + MessagingEngine.__init__(self, messenger=self) self.client = client self.handlers = handlers self.workspace = workspace self.components = components self.aux_runner = AuxRunner(self) - self.object_streamer = ObjectStreamer(self.aux_runner) - self.reliable_messenger = ReliableMessenger(self.aux_runner) self.add_handler(self.aux_runner) - self.add_handler(self.object_streamer) - self.add_handler(self.reliable_messenger) + self.conf = conf self.cell = None @@ -334,105 +328,6 @@ def get_all_clients_from_server(self, fl_ctx, retry=0): def register_aux_message_handler(self, topic: str, message_handle_func): self.aux_runner.register_aux_message_handler(topic, message_handle_func) - def fire_and_forget_aux_request( - self, topic: str, request: Shareable, fl_ctx: FLContext, optional=False, secure=False - ) -> dict: - return self.send_aux_request( - targets=None, topic=topic, request=request, timeout=0.0, fl_ctx=fl_ctx, optional=optional, secure=secure - ) - - def stream_objects( - self, - channel: str, - topic: str, - stream_ctx: StreamContext, - targets: List[str], - producer: ObjectProducer, - fl_ctx: FLContext, - optional=False, - secure=False, - ): - if not self.object_streamer: - raise RuntimeError("object streamer has not been created") - - return self.object_streamer.stream( - channel=channel, - topic=topic, - stream_ctx=stream_ctx, - targets=self._to_aux_msg_targets(targets), - producer=producer, - fl_ctx=fl_ctx, - secure=secure, - optional=optional, - ) - - def register_stream_processing( - self, - channel: str, - topic: str, - factory: ConsumerFactory, - stream_done_cb=None, - **cb_kwargs, - ): - if not self.object_streamer: - raise RuntimeError("object streamer has not been created") - - self.object_streamer.register_stream_processing(channel, topic, factory, stream_done_cb, **cb_kwargs) - - def shutdown_streamer(self): - if self.object_streamer: - self.object_streamer.shutdown() - - def register_reliable_request_handler(self, channel: str, topic: str, handler_f, **handler_kwargs): - if not self.reliable_messenger: - raise RuntimeError("reliable messenger has not been created") - - self.reliable_messenger.register_request_handler( - channel=channel, - topic=topic, - handler_f=handler_f, - **handler_kwargs, - ) - - def send_reliable_request( - self, - target: str, - channel: str, - topic: str, - request: Shareable, - per_msg_timeout: float, - tx_timeout: float, - fl_ctx: FLContext, - secure=False, - optional=False, - ) -> Shareable: - if not self.reliable_messenger: - raise RuntimeError("reliable messenger has not been created") - - if not target: - target = AuxMsgTarget.server_target() - else: - target = self._get_aux_msg_target(target) - - if not target: - raise ValueError(f"invalid target '{target}'") - - return self.reliable_messenger.send_request( - target=target, - channel=channel, - topic=topic, - request=request, - per_msg_timeout=per_msg_timeout, - tx_timeout=tx_timeout, - fl_ctx=fl_ctx, - secure=secure, - optional=optional, - ) - - def shutdown_reliable_messenger(self): - if self.reliable_messenger: - self.reliable_messenger.shutdown() - def abort_app(self, job_id: str, fl_ctx: FLContext): runner = fl_ctx.get_prop(key=FLContextKey.RUNNER, default=None) if isinstance(runner, ClientRunner): diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index 94234e84fe..593aeafddc 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -29,10 +29,8 @@ ) from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeJobError -from nvflare.apis.rm import RMEngine from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply from nvflare.apis.signal import Signal -from nvflare.apis.streaming import StreamableEngine from nvflare.apis.utils.fl_context_utils import add_job_audit_event from nvflare.apis.utils.task_utils import apply_filters from nvflare.fuel.f3.cellnet.fqcn import FQCN @@ -40,6 +38,7 @@ from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment from nvflare.private.fed.tbi import TBI from nvflare.private.json_configer import ConfigError +from nvflare.private.msg_engine import MessagingEngine from nvflare.private.privacy_manager import Scope from nvflare.security.logging import secure_format_exception from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector @@ -591,11 +590,8 @@ def run(self, app_root, args): self.log_exception(fl_ctx, f"processing error in RUN execution: {secure_format_exception(e)}") finally: self.end_run_events_sequence() - assert isinstance(self.engine, StreamableEngine) - self.engine.shutdown_streamer() - - assert isinstance(self.engine, RMEngine) - self.engine.shutdown_reliable_messenger() + assert isinstance(self.engine, MessagingEngine) + self.engine.shutdown_messaging() with self.task_lock: self.running_tasks = {} diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index 03485aa3d0..54dab1d66d 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -41,9 +41,7 @@ from nvflare.apis.impl.job_def_manager import JobDefManagerSpec from nvflare.apis.job_def import Job from nvflare.apis.job_launcher_spec import JobLauncherSpec -from nvflare.apis.rm import RMEngine from nvflare.apis.shareable import ReturnCode, Shareable, make_reply -from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx, get_serializable_data from nvflare.apis.workspace import Workspace from nvflare.fuel.f3.cellnet.cell import Cell @@ -64,6 +62,7 @@ security_close, set_message_security_data, ) +from nvflare.private.msg_engine import MessagingEngine from nvflare.private.scheduler_constants import ShareableHeader from nvflare.security.logging import secure_format_exception from nvflare.widgets.info_collector import InfoCollector @@ -79,7 +78,7 @@ from .server_status import ServerStatus -class ServerEngine(ServerEngineInternalSpec, StreamableEngine, RMEngine): +class ServerEngine(ServerEngineInternalSpec, MessagingEngine): def __init__(self, server, args, client_manager: ClientManager, snapshot_persistor, workers=3): """Server engine. @@ -90,6 +89,7 @@ def __init__(self, server, args, client_manager: ClientManager, snapshot_persist workers: number of worker threads. """ # TODO:: clean up the server function / requirement here should be BaseServer + MessagingEngine.__init__(self, messenger=self) self.server = server self.args = args self.run_processes = {} @@ -524,6 +524,7 @@ def send_aux_request( return self.send_aux_to_targets(targets, topic, request, timeout, fl_ctx, optional, secure) except Exception as e: self.logger.error(f"Failed to send the aux_message: {topic} with exception: {secure_format_exception(e)}.") + raise e def multicast_aux_requests( self, @@ -604,112 +605,6 @@ def send_aux_to_targets(self, targets, topic, request, timeout, fl_ctx, optional else: return {} - def stream_objects( - self, - channel: str, - topic: str, - stream_ctx: StreamContext, - targets: List[str], - producer: ObjectProducer, - fl_ctx: FLContext, - optional=False, - secure=False, - ): - if not self.run_manager: - raise RuntimeError("run_manager has not been created") - - if not self.run_manager.object_streamer: - raise RuntimeError("object_streamer has not been created") - - return self.run_manager.object_streamer.stream( - channel=channel, - topic=topic, - stream_ctx=stream_ctx, - targets=self._to_aux_msg_targets(targets), - producer=producer, - fl_ctx=fl_ctx, - secure=secure, - optional=optional, - ) - - def register_stream_processing( - self, - channel: str, - topic: str, - factory: ConsumerFactory, - stream_done_cb=None, - **cb_kwargs, - ): - if not self.run_manager: - raise RuntimeError("run_manager has not been created") - - if not self.run_manager.object_streamer: - raise RuntimeError("object_streamer has not been created") - - self.run_manager.object_streamer.register_stream_processing( - channel=channel, topic=topic, factory=factory, stream_done_cb=stream_done_cb, **cb_kwargs - ) - - def shutdown_streamer(self): - if self.run_manager and self.run_manager.object_streamer: - self.run_manager.object_streamer.shutdown() - - def register_reliable_request_handler(self, channel: str, topic: str, handler_f, **handler_kwargs): - if not self.run_manager: - raise RuntimeError("run_manager has not been created") - - if not self.run_manager.reliable_messenger: - raise RuntimeError("reliable_messenger has not been created") - - self.run_manager.reliable_messenger.register_request_handler( - channel=channel, - topic=topic, - handler_f=handler_f, - **handler_kwargs, - ) - - def send_reliable_request( - self, - target: str, - channel: str, - topic: str, - request: Shareable, - per_msg_timeout: float, - tx_timeout: float, - fl_ctx: FLContext, - secure=False, - optional=False, - ) -> Shareable: - if not self.run_manager: - raise RuntimeError("run_manager has not been created") - - if not self.run_manager.reliable_messenger: - raise RuntimeError("reliable_messenger has not been created") - - if not target: - target = AuxMsgTarget.server_target() - else: - target = self._get_aux_msg_target(target) - - if not target: - raise ValueError(f"invalid target '{target}'") - - return self.run_manager.reliable_messenger.send_request( - target=target, - channel=channel, - topic=topic, - request=request, - per_msg_timeout=per_msg_timeout, - tx_timeout=tx_timeout, - fl_ctx=fl_ctx, - secure=secure, - optional=optional, - ) - - def shutdown_reliable_messenger(self): - if self.run_manager and self.run_manager.reliable_messenger: - self.run_manager.reliable_messenger.shutdown() - def sync_clients_from_main_process(self): # repeatedly ask the parent process to get participating clients until we receive the result # or timed out after 30 secs (already tried 30 times). @@ -1012,8 +907,7 @@ def pause_server_jobs(self): def close(self): self.executor.shutdown() - self.shutdown_streamer() - self.shutdown_reliable_messenger() + self.shutdown_messaging() def server_shutdown(server, touch_file): diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index ebc8193b0e..392a167b5a 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -20,15 +20,14 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FilterKey, FLContextKey, ReservedKey, ReservedTopic, ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.rm import RMEngine from nvflare.apis.server_engine_spec import ServerEngineSpec from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply from nvflare.apis.signal import Signal -from nvflare.apis.streaming import StreamableEngine from nvflare.apis.utils.fl_context_utils import add_job_audit_event from nvflare.apis.utils.task_utils import apply_filters from nvflare.private.defs import SpecialTaskName, TaskConstant from nvflare.private.fed.tbi import TBI +from nvflare.private.msg_engine import MessagingEngine from nvflare.private.privacy_manager import Scope from nvflare.security.logging import secure_format_exception from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector @@ -217,11 +216,8 @@ def run(self): self.fire_event(EventType.END_RUN, fl_ctx) self.log_info(fl_ctx, "END_RUN fired") - assert isinstance(self.engine, StreamableEngine) - self.engine.shutdown_streamer() - - assert isinstance(self.engine, RMEngine) - self.engine.shutdown_reliable_messenger() + assert isinstance(self.engine, MessagingEngine) + self.engine.shutdown_messaging() self.log_info(fl_ctx, "Server runner finished.") diff --git a/nvflare/private/msg_engine.py b/nvflare/private/msg_engine.py new file mode 100644 index 0000000000..5c3ef74063 --- /dev/null +++ b/nvflare/private/msg_engine.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +from nvflare.apis.aux_spec import AuxMessenger +from nvflare.apis.fl_context import FLContext +from nvflare.apis.rm import RMEngine +from nvflare.apis.shareable import Shareable +from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext + +from .rm_runner import ReliableMessenger +from .stream_runner import ObjectStreamer + + +class MessagingEngine(StreamableEngine, RMEngine): + def __init__(self, messenger: AuxMessenger): + self.messenger = messenger + self.streamer = ObjectStreamer(messenger) + self.reliable_messenger = ReliableMessenger(messenger) + + def register_reliable_request_handler(self, channel: str, topic: str, handler_f, **handler_kwargs): + self.reliable_messenger.register_request_handler(channel, topic, handler_f, **handler_kwargs) + + def send_reliable_request( + self, + target: str, + channel: str, + topic: str, + request: Shareable, + per_msg_timeout: float, + tx_timeout: float, + fl_ctx: FLContext, + optional=False, + secure=False, + ) -> Shareable: + return self.reliable_messenger.send_request( + target, channel, topic, request, per_msg_timeout, tx_timeout, fl_ctx, secure, optional + ) + + def shutdown_reliable_messenger(self): + self.reliable_messenger.shutdown() + + def register_stream_processing( + self, + channel: str, + topic: str, + factory: ConsumerFactory, + stream_done_cb=None, + **cb_kwargs, + ): + return self.streamer.register_stream_processing(channel, topic, factory, stream_done_cb, **cb_kwargs) + + def stream_objects( + self, + channel: str, + topic: str, + stream_ctx: StreamContext, + targets: List[str], + producer: ObjectProducer, + fl_ctx: FLContext, + optional=False, + secure=False, + ): + return self.streamer.stream( + channel, + topic, + stream_ctx, + targets, + producer, + fl_ctx, + secure=secure, + optional=optional, + ) + + def shutdown_streamer(self): + self.streamer.shutdown() + + def shutdown_messaging(self): + self.shutdown_streamer() + self.shutdown_reliable_messenger() diff --git a/nvflare/private/rm_runner.py b/nvflare/private/rm_runner.py index 46af73be4e..5ee0d8bbe9 100644 --- a/nvflare/private/rm_runner.py +++ b/nvflare/private/rm_runner.py @@ -17,6 +17,7 @@ import time import uuid +from nvflare.apis.aux_spec import AuxMessenger from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ConfigVarName, SystemConfigs from nvflare.apis.fl_context import FLContext @@ -29,8 +30,7 @@ from nvflare.fuel.f3.message import Message from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import get_obj_logger -from nvflare.fuel.utils.validation_utils import check_callable, check_object_type, check_positive_number, check_str -from nvflare.private.aux_runner import AuxMsgTarget, AuxRunner +from nvflare.fuel.utils.validation_utils import check_callable, check_positive_number, check_str from nvflare.security.logging import secure_format_exception, secure_format_traceback # Operation Types @@ -65,11 +65,11 @@ PROP_KEY_OP = "RM.OP" -def _extract_result(reply: dict, target: AuxMsgTarget): +def _extract_result(reply: dict, target: str): err_rc = ReturnCode.COMMUNICATION_ERROR if not isinstance(reply, dict): return make_reply(err_rc), err_rc - result = reply.get(target.name) + result = reply.get(target) if not result: return make_reply(err_rc), err_rc return result, result.get_return_code() @@ -118,11 +118,9 @@ def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable: op = request.get_header(HEADER_OP) peer_ctx = fl_ctx.get_peer_context() assert isinstance(peer_ctx, FLContext) - source_name = peer_ctx.get_identity_name() + self.source = peer_ctx.get_identity_name() msg = request.get_cell_message() assert isinstance(msg, Message) - source_fqcn = msg.get_header(MessageHeaderKey.ORIGIN) - self.source = AuxMsgTarget(source_name, source_fqcn) self.msg_secure = msg.get_header(MessageHeaderKey.SECURE, False) self.msg_optional = msg.get_header(MessageHeaderKey.OPTIONAL, False) @@ -244,7 +242,7 @@ def process(self, reply: Shareable) -> Shareable: class ReliableMessenger(FLComponent): - def __init__(self, aux_runner: AuxRunner): + def __init__(self, aux_runner: AuxMessenger): FLComponent.__init__(self) self.aux_runner = aux_runner self.registry = Registry() @@ -457,15 +455,15 @@ def debug(self, fl_ctx: FLContext, msg: str): def send_request( self, - target: AuxMsgTarget, + target: str, channel: str, topic: str, request: Shareable, per_msg_timeout: float, tx_timeout: float, fl_ctx: FLContext, - secure=False, optional=False, + secure=False, ) -> Shareable: """Send a request reliably. @@ -488,7 +486,7 @@ def send_request( the request will be sent only once without retrying. """ - check_object_type("target", target, AuxMsgTarget) + check_str("target", target) check_positive_number("per_msg_timeout", per_msg_timeout) if tx_timeout: check_positive_number("tx_timeout", tx_timeout) @@ -520,7 +518,7 @@ def send_request( def _send_request( self, - target: AuxMsgTarget, + target: str, request: Shareable, fl_ctx: FLContext, receiver: _ReplyReceiver, @@ -602,7 +600,7 @@ def _send_request( def _query_result( self, - target: AuxMsgTarget, + target: str, abort_signal: Signal, fl_ctx: FLContext, receiver: _ReplyReceiver, diff --git a/nvflare/private/stream_runner.py b/nvflare/private/stream_runner.py index 30a0cf7a3b..ab5ea79696 100644 --- a/nvflare/private/stream_runner.py +++ b/nvflare/private/stream_runner.py @@ -17,6 +17,7 @@ from threading import Lock from typing import Any, List, Tuple +from nvflare.apis.aux_spec import AuxMessenger from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext @@ -26,7 +27,6 @@ from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.validation_utils import check_callable, check_object_type, check_str -from nvflare.private.aux_runner import AuxMsgTarget, AuxRunner from nvflare.security.logging import secure_format_exception # Topics for streaming messages @@ -103,9 +103,9 @@ def stream_done(self, rc: str, fl_ctx: FLContext): class ObjectStreamer(FLComponent): - def __init__(self, aux_runner: AuxRunner): + def __init__(self, messenger: AuxMessenger): FLComponent.__init__(self) - self.aux_runner = aux_runner + self.messenger = messenger self.registry = Registry() self.tx_lock = Lock() self.tx_table = {} # tx_id => _ProcessorInfo @@ -115,11 +115,11 @@ def __init__(self, aux_runner: AuxRunner): max_concurrent_streaming_sessions = ConfigService.get_int_var("max_concurrent_streaming_sessions", default=20) self.streaming_executor = ThreadPoolExecutor(max_workers=max_concurrent_streaming_sessions) - aux_runner.register_aux_message_handler( + messenger.register_aux_message_handler( topic=TOPIC_STREAM_REQUEST, message_handle_func=self._handle_request, ) - aux_runner.register_aux_message_handler( + messenger.register_aux_message_handler( topic=TOPIC_STREAM_ABORT, message_handle_func=self._handle_abort, ) @@ -307,7 +307,7 @@ def _handle_request(self, topic: str, request: Shareable, fl_ctx: FLContext) -> def _notify_abort_streaming( self, - targets: List[AuxMsgTarget], + targets: List[str], tx_id: str, secure: bool, fl_ctx: FLContext, @@ -325,7 +325,7 @@ def _notify_abort_streaming( """ msg = make_reply(ReturnCode.TASK_ABORTED) msg.set_header(HeaderKey.TX_ID, tx_id) - self.aux_runner.send_aux_request( + self.messenger.send_aux_request( targets=targets, topic=TOPIC_STREAM_ABORT, request=msg, @@ -340,7 +340,7 @@ def stream( channel: str, topic: str, stream_ctx: StreamContext, - targets: List[AuxMsgTarget], + targets: List[str], producer: ObjectProducer, fl_ctx: FLContext, secure=False, @@ -420,7 +420,7 @@ def stream( seq += 1 # broadcast the message to all targets - replies = self.aux_runner.send_aux_request( + replies = self.messenger.send_aux_request( topic=TOPIC_STREAM_REQUEST, targets=targets, request=request, @@ -448,7 +448,7 @@ def stream_no_wait( channel: str, topic: str, stream_ctx: StreamContext, - targets: List[AuxMsgTarget], + targets: List[str], producer: ObjectProducer, fl_ctx: FLContext, secure=False,