diff --git a/nvflare/private/fed/client/client_engine.py b/nvflare/private/fed/client/client_engine.py index 4e248233c6..6e8d536409 100644 --- a/nvflare/private/fed/client/client_engine.py +++ b/nvflare/private/fed/client/client_engine.py @@ -66,7 +66,6 @@ def __init__(self, client: FederatedClient, args, rank, workers=5): rank: local process rank workers: number of workers """ - super().__init__() MessagingEngine.__init__(self, messenger=self) self.client = client self.client_name = client.client_name diff --git a/nvflare/private/fed/client/client_run_manager.py b/nvflare/private/fed/client/client_run_manager.py index bf83f98ef9..2c1234562b 100644 --- a/nvflare/private/fed/client/client_run_manager.py +++ b/nvflare/private/fed/client/client_run_manager.py @@ -81,7 +81,6 @@ def __init__( handlers: available handlers. conf: ClientJsonConfigurator object """ - super().__init__() MessagingEngine.__init__(self, messenger=self) self.client = client self.handlers = handlers diff --git a/nvflare/private/fed/server/run_manager.py b/nvflare/private/fed/server/run_manager.py index 2691fe494a..14002ef02a 100644 --- a/nvflare/private/fed/server/run_manager.py +++ b/nvflare/private/fed/server/run_manager.py @@ -59,11 +59,7 @@ def __init__( self.client_manager = client_manager self.handlers = handlers self.aux_runner = AuxRunner(self) - self.object_streamer = ObjectStreamer(self.aux_runner) - self.reliable_messenger = ReliableMessenger(self.aux_runner) self.add_handler(self.aux_runner) - self.add_handler(self.object_streamer) - self.add_handler(self.reliable_messenger) if job_id: job_ctx_props = self.create_job_processing_context_properties(workspace, job_id) diff --git a/nvflare/private/msg_engine.py b/nvflare/private/msg_engine.py index 5c3ef74063..6bc5905094 100644 --- a/nvflare/private/msg_engine.py +++ b/nvflare/private/msg_engine.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading from typing import List from nvflare.apis.aux_spec import AuxMessenger @@ -18,18 +19,40 @@ from nvflare.apis.rm import RMEngine from nvflare.apis.shareable import Shareable from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext +from nvflare.apis.fl_component import FLComponent from .rm_runner import ReliableMessenger from .stream_runner import ObjectStreamer -class MessagingEngine(StreamableEngine, RMEngine): +class MessagingEngine(StreamableEngine, RMEngine, FLComponent): def __init__(self, messenger: AuxMessenger): + FLComponent.__init__(self) self.messenger = messenger - self.streamer = ObjectStreamer(messenger) - self.reliable_messenger = ReliableMessenger(messenger) + self._lock = threading.Lock() + + # We do not immediately create ObjectStreamer and ReliableMessenger here since they need to + # register aux CBs with the AuxMessenger, but the AuxMessenger may not be ready now. + # Instead, we will create them later when needed. + self.streamer = None + self.reliable_messenger = None + + def _open_streamer(self): + with self._lock: + if not self.streamer: + self.streamer = ObjectStreamer(self.messenger) + + def _open_reliable_messenger(self): + self.logger.info(f"trying to open reliable messenger for {id(self)}") + with self._lock: + if not self.reliable_messenger: + self.reliable_messenger = ReliableMessenger(self.messenger) + self.logger.info(f"reliable_messenger is opened for engine {id(self)}, aux {id(self.messenger)}") + else: + self.logger.info(f"reliable messenger for {id(self)} is already opened") def register_reliable_request_handler(self, channel: str, topic: str, handler_f, **handler_kwargs): + self._open_reliable_messenger() self.reliable_messenger.register_request_handler(channel, topic, handler_f, **handler_kwargs) def send_reliable_request( @@ -44,12 +67,15 @@ def send_reliable_request( optional=False, secure=False, ) -> Shareable: + self._open_reliable_messenger() return self.reliable_messenger.send_request( target, channel, topic, request, per_msg_timeout, tx_timeout, fl_ctx, secure, optional ) def shutdown_reliable_messenger(self): - self.reliable_messenger.shutdown() + with self._lock: + if self.reliable_messenger: + self.reliable_messenger.shutdown() def register_stream_processing( self, @@ -59,6 +85,7 @@ def register_stream_processing( stream_done_cb=None, **cb_kwargs, ): + self._open_streamer() return self.streamer.register_stream_processing(channel, topic, factory, stream_done_cb, **cb_kwargs) def stream_objects( @@ -72,6 +99,7 @@ def stream_objects( optional=False, secure=False, ): + self._open_streamer() return self.streamer.stream( channel, topic, @@ -84,7 +112,9 @@ def stream_objects( ) def shutdown_streamer(self): - self.streamer.shutdown() + with self._lock: + if self.streamer: + self.streamer.shutdown() def shutdown_messaging(self): self.shutdown_streamer()