Skip to content

Commit

Permalink
fix streamer order
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Dec 12, 2024
1 parent 6880ac2 commit 8a4e450
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 11 deletions.
1 change: 0 additions & 1 deletion nvflare/private/fed/client/client_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(self, client: FederatedClient, args, rank, workers=5):
rank: local process rank
workers: number of workers
"""
super().__init__()
MessagingEngine.__init__(self, messenger=self)
self.client = client
self.client_name = client.client_name
Expand Down
1 change: 0 additions & 1 deletion nvflare/private/fed/client/client_run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def __init__(
handlers: available handlers.
conf: ClientJsonConfigurator object
"""
super().__init__()
MessagingEngine.__init__(self, messenger=self)
self.client = client
self.handlers = handlers
Expand Down
4 changes: 0 additions & 4 deletions nvflare/private/fed/server/run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ def __init__(
self.client_manager = client_manager
self.handlers = handlers
self.aux_runner = AuxRunner(self)
self.object_streamer = ObjectStreamer(self.aux_runner)
self.reliable_messenger = ReliableMessenger(self.aux_runner)
self.add_handler(self.aux_runner)
self.add_handler(self.object_streamer)
self.add_handler(self.reliable_messenger)

if job_id:
job_ctx_props = self.create_job_processing_context_properties(workspace, job_id)
Expand Down
40 changes: 35 additions & 5 deletions nvflare/private/msg_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import List

from nvflare.apis.aux_spec import AuxMessenger
from nvflare.apis.fl_context import FLContext
from nvflare.apis.rm import RMEngine
from nvflare.apis.shareable import Shareable
from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamableEngine, StreamContext
from nvflare.apis.fl_component import FLComponent

from .rm_runner import ReliableMessenger
from .stream_runner import ObjectStreamer


class MessagingEngine(StreamableEngine, RMEngine):
class MessagingEngine(StreamableEngine, RMEngine, FLComponent):
def __init__(self, messenger: AuxMessenger):
FLComponent.__init__(self)
self.messenger = messenger
self.streamer = ObjectStreamer(messenger)
self.reliable_messenger = ReliableMessenger(messenger)
self._lock = threading.Lock()

# We do not immediately create ObjectStreamer and ReliableMessenger here since they need to
# register aux CBs with the AuxMessenger, but the AuxMessenger may not be ready now.
# Instead, we will create them later when needed.
self.streamer = None
self.reliable_messenger = None

def _open_streamer(self):
with self._lock:
if not self.streamer:
self.streamer = ObjectStreamer(self.messenger)

def _open_reliable_messenger(self):
self.logger.info(f"trying to open reliable messenger for {id(self)}")
with self._lock:
if not self.reliable_messenger:
self.reliable_messenger = ReliableMessenger(self.messenger)
self.logger.info(f"reliable_messenger is opened for engine {id(self)}, aux {id(self.messenger)}")
else:
self.logger.info(f"reliable messenger for {id(self)} is already opened")

def register_reliable_request_handler(self, channel: str, topic: str, handler_f, **handler_kwargs):
self._open_reliable_messenger()
self.reliable_messenger.register_request_handler(channel, topic, handler_f, **handler_kwargs)

def send_reliable_request(
Expand All @@ -44,12 +67,15 @@ def send_reliable_request(
optional=False,
secure=False,
) -> Shareable:
self._open_reliable_messenger()
return self.reliable_messenger.send_request(
target, channel, topic, request, per_msg_timeout, tx_timeout, fl_ctx, secure, optional
)

def shutdown_reliable_messenger(self):
self.reliable_messenger.shutdown()
with self._lock:
if self.reliable_messenger:
self.reliable_messenger.shutdown()

def register_stream_processing(
self,
Expand All @@ -59,6 +85,7 @@ def register_stream_processing(
stream_done_cb=None,
**cb_kwargs,
):
self._open_streamer()
return self.streamer.register_stream_processing(channel, topic, factory, stream_done_cb, **cb_kwargs)

def stream_objects(
Expand All @@ -72,6 +99,7 @@ def stream_objects(
optional=False,
secure=False,
):
self._open_streamer()
return self.streamer.stream(
channel,
topic,
Expand All @@ -84,7 +112,9 @@ def stream_objects(
)

def shutdown_streamer(self):
self.streamer.shutdown()
with self._lock:
if self.streamer:
self.streamer.shutdown()

def shutdown_messaging(self):
self.shutdown_streamer()
Expand Down

0 comments on commit 8a4e450

Please sign in to comment.