diff --git a/nvflare/fuel/f3/streaming/byte_receiver.py b/nvflare/fuel/f3/streaming/byte_receiver.py index 7e70a8113b..d1bae93079 100644 --- a/nvflare/fuel/f3/streaming/byte_receiver.py +++ b/nvflare/fuel/f3/streaming/byte_receiver.py @@ -309,22 +309,22 @@ def close(self): class ByteReceiver: + + received_stream_counter_pool = StatsPoolManager.add_counter_pool( + name="Received_Stream_Counters", + description="Counters of received streams", + counter_names=[COUNTER_NAME_RECEIVED], + ) + + received_stream_size_pool = StatsPoolManager.add_msg_size_pool( + "Received_Stream_Sizes", "Sizes of streams received (MBs)" + ) + def __init__(self, cell: CoreCell): self.cell = cell self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_DATA_TOPIC, cb=self._data_handler) self.registry = Registry() - self.received_stream_counter_pool = StatsPoolManager.add_counter_pool( - name="Received_Stream_Counters", - description="Counters of received streams", - counter_names=[COUNTER_NAME_RECEIVED], - scope=self.cell.my_info.fqcn, - ) - - self.received_stream_size_pool = StatsPoolManager.add_msg_size_pool( - "Received_Stream_Sizes", "Sizes of streams received (MBs)", scope=self.cell.my_info.fqcn - ) - def register_callback(self, channel: str, topic: str, stream_cb: Callable, *args, **kwargs): if not callable(stream_cb): raise StreamError(f"specified stream_cb {type(stream_cb)} is not callable") @@ -345,13 +345,15 @@ def _data_handler(self, message: Message): task.stop(StreamError(f"{task} No callback is registered for {task.channel}/{task.topic}")) return - self.received_stream_counter_pool.increment( - category=stream_stats_category(task.channel, task.topic, "stream"), + fqcn = self.cell.my_info.fqcn + ByteReceiver.received_stream_counter_pool.increment( + category=stream_stats_category(fqcn, task.channel, task.topic, "stream"), counter_name=COUNTER_NAME_RECEIVED, ) - self.received_stream_size_pool.record_value( - category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB + ByteReceiver.received_stream_size_pool.record_value( + category=stream_stats_category(fqcn, task.channel, task.topic, "stream"), + value=task.size / ONE_MB, ) stream_thread_pool.submit(self._callback_wrapper, task, callback) diff --git a/nvflare/fuel/f3/streaming/byte_streamer.py b/nvflare/fuel/f3/streaming/byte_streamer.py index be2ad46577..de82d38bb6 100644 --- a/nvflare/fuel/f3/streaming/byte_streamer.py +++ b/nvflare/fuel/f3/streaming/byte_streamer.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import threading -from typing import Optional +from typing import Callable, Optional from nvflare.fuel.f3.cellnet.core_cell import CoreCell from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey @@ -38,7 +38,7 @@ STREAM_CHUNK_SIZE = 1024 * 1024 STREAM_WINDOW_SIZE = 16 * STREAM_CHUNK_SIZE -STREAM_ACK_WAIT = 60 +STREAM_ACK_WAIT = 300 STREAM_TYPE_BYTE = "byte" STREAM_TYPE_BLOB = "blob" @@ -51,10 +51,21 @@ class TxTask: def __init__( - self, channel: str, topic: str, target: str, headers: dict, stream: Stream, secure: bool, optional: bool + self, + cell: CoreCell, + chunk_size: int, + channel: str, + topic: str, + target: str, + headers: dict, + stream: Stream, + secure: bool, + optional: bool, ): + self.cell = cell + self.chunk_size = chunk_size self.sid = gen_stream_id() - self.buffer = bytearray(ByteStreamer.get_chunk_size()) + self.buffer = wrap_view(bytearray(chunk_size)) # Optimization to send the original buffer without copying self.direct_buf: Optional[bytes] = None self.buffer_size = 0 @@ -71,200 +82,233 @@ def __init__( self.offset_ack = 0 self.secure = secure self.optional = optional + self.stopped = False - def __str__(self): - return f"Tx[SID:{self.sid} to {self.target} for {self.channel}/{self.topic}]" + self.stream_future = StreamFuture(self.sid) + self.stream_future.set_size(stream.get_size()) + self.window_size = CommConfigurator().get_streaming_window_size(STREAM_WINDOW_SIZE) + self.ack_wait = CommConfigurator().get_streaming_ack_wait(STREAM_ACK_WAIT) -class ByteStreamer: - def __init__(self, cell: CoreCell): - self.cell = cell - self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_ACK_TOPIC, cb=self._ack_handler) - self.tx_task_map = {} - self.map_lock = threading.Lock() - - self.sent_stream_counter_pool = StatsPoolManager.add_counter_pool( - name="Sent_Stream_Counters", - description="Counters of sent streams", - counter_names=[COUNTER_NAME_SENT], - scope=self.cell.my_info.fqcn, - ) - - self.sent_stream_size_pool = StatsPoolManager.add_msg_size_pool( - "Sent_Stream_Sizes", "Sizes of streams sent (MBs)", scope=self.cell.my_info.fqcn - ) - - @staticmethod - def get_chunk_size(): - return CommConfigurator().get_streaming_chunk_size(STREAM_CHUNK_SIZE) - - def send( - self, - channel: str, - topic: str, - target: str, - headers: dict, - stream: Stream, - stream_type=STREAM_TYPE_BYTE, - secure=False, - optional=False, - ) -> StreamFuture: - tx_task = TxTask(channel, topic, target, headers, stream, secure, optional) - with self.map_lock: - self.tx_task_map[tx_task.sid] = tx_task - - future = StreamFuture(tx_task.sid) - future.set_size(stream.get_size()) - tx_task.stream_future = future - tx_task.task_future = stream_thread_pool.submit(self._transmit_task, tx_task) - - self.sent_stream_counter_pool.increment( - category=stream_stats_category(channel, topic, stream_type), counter_name=COUNTER_NAME_SENT - ) - - self.sent_stream_size_pool.record_value( - category=stream_stats_category(channel, topic, stream_type), value=stream.get_size() / ONE_MB - ) - - return future + def __str__(self): + return f"Tx[SID:{self.sid} to {self.target} for {self.channel}/{self.topic}]" - def _transmit_task(self, task: TxTask): + def send_loop(self): + """Read/send loop to transmit the whole stream with flow control""" - chunk_size = self.get_chunk_size() - while True: - buf = task.stream.read(chunk_size) + while not self.stopped: + buf = self.stream.read(self.chunk_size) if not buf: # End of Stream - self._transmit(task, final=True) - self._stop_task(task) + self.send_pending_buffer(final=True) + self.stop() return # Flow control - window = task.offset - task.offset_ack + window = self.offset - self.offset_ack # It may take several ACKs to clear up the window - window_size = CommConfigurator().get_streaming_window_size(STREAM_WINDOW_SIZE) - while window > window_size: - log.debug(f"{task} window size {window} exceeds limit: {window_size}") - task.ack_waiter.clear() - ack_wait = CommConfigurator().get_streaming_ack_wait(STREAM_ACK_WAIT) - if not task.ack_waiter.wait(timeout=ack_wait): - self._stop_task(task, StreamError(f"{task} ACK timeouts after {ack_wait} seconds")) + while window > self.window_size: + log.debug(f"{self} window size {window} exceeds limit: {self.window_size}") + self.ack_waiter.clear() + + if not self.ack_waiter.wait(timeout=self.ack_wait): + self.stop(StreamError(f"{self} ACK timeouts after {self.ack_wait} seconds")) return - window = task.offset - task.offset_ack + window = self.offset - self.offset_ack size = len(buf) - if size > chunk_size: - raise StreamError(f"Stream returns invalid size: {size} for {task}") - if size + task.buffer_size > chunk_size: - self._transmit(task) + if size > self.chunk_size: + raise StreamError(f"{self} Stream returns invalid size: {size}") + + # Don't push out chunk when it's equal, wait till next round to detect EOS + # For example, if the stream size is chunk size (1M), this avoids sending two chunks. + if size + self.buffer_size > self.chunk_size: + self.send_pending_buffer() - if size == chunk_size: - task.direct_buf = buf + if size == self.chunk_size: + self.direct_buf = buf else: - task.buffer[task.buffer_size : task.buffer_size + size] = buf - task.buffer_size += size + self.buffer[self.buffer_size : self.buffer_size + size] = buf + self.buffer_size += size - def _transmit(self, task: TxTask, final=False): + def send_pending_buffer(self, final=False): - if task.buffer_size == 0: + if self.buffer_size == 0: payload = bytes(0) - elif task.buffer_size == self.get_chunk_size(): - if task.direct_buf: - payload = task.direct_buf + elif self.buffer_size == self.chunk_size: + if self.direct_buf: + payload = self.direct_buf else: - payload = task.buffer + payload = self.buffer else: - payload = wrap_view(task.buffer)[0 : task.buffer_size] + payload = self.buffer[0 : self.buffer_size] message = Message(None, payload) - if task.headers: - message.add_headers(task.headers) + if self.headers: + message.add_headers(self.headers) message.add_headers( { - StreamHeaderKey.CHANNEL: task.channel, - StreamHeaderKey.TOPIC: task.topic, - StreamHeaderKey.SIZE: task.stream.get_size(), - StreamHeaderKey.STREAM_ID: task.sid, + StreamHeaderKey.CHANNEL: self.channel, + StreamHeaderKey.TOPIC: self.topic, + StreamHeaderKey.SIZE: self.stream.get_size(), + StreamHeaderKey.STREAM_ID: self.sid, StreamHeaderKey.DATA_TYPE: StreamDataType.FINAL if final else StreamDataType.CHUNK, - StreamHeaderKey.SEQUENCE: task.seq, - StreamHeaderKey.OFFSET: task.offset, - StreamHeaderKey.OPTIONAL: task.optional, + StreamHeaderKey.SEQUENCE: self.seq, + StreamHeaderKey.OFFSET: self.offset, + StreamHeaderKey.OPTIONAL: self.optional, } ) errors = self.cell.fire_and_forget( - STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message, secure=task.secure, optional=task.optional + STREAM_CHANNEL, STREAM_DATA_TOPIC, self.target, message, secure=self.secure, optional=self.optional ) - error = errors.get(task.target) + error = errors.get(self.target) if error: - msg = f"Message sending error to target {task.target}: {error}" - log.debug(msg) - self._stop_task(task, StreamError(msg)) + msg = f"{self} Message sending error to target {self.target}: {error}" + self.stop(StreamError(msg)) return # Update state - task.seq += 1 - task.offset += task.buffer_size - task.buffer_size = 0 - task.direct_buf = None + self.seq += 1 + self.offset += self.buffer_size + self.buffer_size = 0 + self.direct_buf = None # Update future - task.stream_future.set_progress(task.offset) + self.stream_future.set_progress(self.offset) - def _stop_task(self, task: TxTask, error: StreamError = None, notify=True): - with self.map_lock: - self.tx_task_map.pop(task.sid, None) + def stop(self, error: Optional[StreamError] = None, notify=True): - if error: - log.debug(f"Stream error: {error}") - if task.stream_future: - task.stream_future.set_exception(error) - - if notify: - message = Message(None, None) - - if task.headers: - message.add_headers(task.headers) - - message.add_headers( - { - StreamHeaderKey.STREAM_ID: task.sid, - StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR, - StreamHeaderKey.OFFSET: task.offset, - StreamHeaderKey.ERROR_MSG: str(error), - } - ) - self.cell.fire_and_forget( - STREAM_CHANNEL, STREAM_DATA_TOPIC, task.target, message, secure=task.secure, optional=True - ) - else: + self.stopped = True + + if not error: # Result is the number of bytes streamed - if task.stream_future: - task.stream_future.set_result(task.offset) + if self.stream_future: + self.stream_future.set_result(self.offset) + return - def _ack_handler(self, message: Message): - origin = message.get_header(MessageHeaderKey.ORIGIN) - sid = message.get_header(StreamHeaderKey.STREAM_ID) - offset = message.get_header(StreamHeaderKey.OFFSET, None) + # Error handling + log.debug(f"{self} Stream error: {error}") + if self.stream_future: + self.stream_future.set_exception(error) - with self.map_lock: - task = self.tx_task_map.get(sid, None) + if notify: + message = Message(None, None) - if not task: - # Last few ACKs always arrive late so this is normal - log.debug(f"ACK for stream {sid} received late from {origin} with offset {offset}") - return + if self.headers: + message.add_headers(self.headers) + + message.add_headers( + { + StreamHeaderKey.STREAM_ID: self.sid, + StreamHeaderKey.DATA_TYPE: StreamDataType.ERROR, + StreamHeaderKey.OFFSET: self.offset, + StreamHeaderKey.ERROR_MSG: str(error), + } + ) + self.cell.fire_and_forget( + STREAM_CHANNEL, STREAM_DATA_TOPIC, self.target, message, secure=self.secure, optional=True + ) + def handle_ack(self, message: Message): + + origin = message.get_header(MessageHeaderKey.ORIGIN) + offset = message.get_header(StreamHeaderKey.OFFSET, None) error = message.get_header(StreamHeaderKey.ERROR_MSG, None) + if error: - self._stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False) + self.stop(StreamError(f"{self} Received error from {origin}: {error}"), notify=False) return - if offset > task.offset_ack: - task.offset_ack = offset + if offset > self.offset_ack: + self.offset_ack = offset + + if not self.ack_waiter.is_set(): + self.ack_waiter.set() + + def start_task_thread(self, task_handler: Callable): + self.task_future = stream_thread_pool.submit(task_handler, self) + + +class ByteStreamer: + + tx_task_map = {} + map_lock = threading.Lock() + + sent_stream_counter_pool = StatsPoolManager.add_counter_pool( + name="Sent_Stream_Counters", + description="Counters of sent streams", + counter_names=[COUNTER_NAME_SENT], + ) + + sent_stream_size_pool = StatsPoolManager.add_msg_size_pool("Sent_Stream_Sizes", "Sizes of streams sent (MBs)") + + def __init__(self, cell: CoreCell): + self.cell = cell + self.cell.register_request_cb(channel=STREAM_CHANNEL, topic=STREAM_ACK_TOPIC, cb=self._ack_handler) + self.chunk_size = CommConfigurator().get_streaming_chunk_size(STREAM_CHUNK_SIZE) + + def get_chunk_size(self): + return self.chunk_size + + def send( + self, + channel: str, + topic: str, + target: str, + headers: dict, + stream: Stream, + stream_type=STREAM_TYPE_BYTE, + secure=False, + optional=False, + ) -> StreamFuture: + tx_task = TxTask(self.cell, self.chunk_size, channel, topic, target, headers, stream, secure, optional) + with ByteStreamer.map_lock: + ByteStreamer.tx_task_map[tx_task.sid] = tx_task + + tx_task.start_task_thread(self._transmit_task) + + fqcn = self.cell.my_info.fqcn + ByteStreamer.sent_stream_counter_pool.increment( + category=stream_stats_category(fqcn, channel, topic, stream_type), counter_name=COUNTER_NAME_SENT + ) + + ByteStreamer.sent_stream_size_pool.record_value( + category=stream_stats_category(fqcn, channel, topic, stream_type), value=stream.get_size() / ONE_MB + ) + + return tx_task.stream_future + + @staticmethod + def _transmit_task(task: TxTask): + + try: + task.send_loop() + except Exception as ex: + msg = f"{task} Error while sending: {ex}" + log.error(msg) + task.stop(StreamError(msg), True) + finally: + # Delete task after it's sent + with ByteStreamer.map_lock: + ByteStreamer.tx_task_map.pop(task.sid, None) + log.debug(f"{task} is removed") + + @staticmethod + def _ack_handler(message: Message): + + sid = message.get_header(StreamHeaderKey.STREAM_ID) + with ByteStreamer.map_lock: + tx_task = ByteStreamer.tx_task_map.get(sid, None) + + if not tx_task: + origin = message.get_header(MessageHeaderKey.ORIGIN) + offset = message.get_header(StreamHeaderKey.OFFSET, None) + # Last few ACKs always arrive late so this is normal + log.debug(f"ACK for stream {sid} received late from {origin} with offset {offset}") + return - if not task.ack_waiter.is_set(): - task.ack_waiter.set() + tx_task.handle_ack(message) diff --git a/nvflare/fuel/f3/streaming/stream_utils.py b/nvflare/fuel/f3/streaming/stream_utils.py index 88d90bc7b2..6be7d2097f 100644 --- a/nvflare/fuel/f3/streaming/stream_utils.py +++ b/nvflare/fuel/f3/streaming/stream_utils.py @@ -95,8 +95,8 @@ def __len__(self): return self.size -def stream_stats_category(channel: str, topic: str, stream_type: str = "byte"): - return f"{stream_type}:{channel}:{topic}" +def stream_stats_category(fqcn: str, channel: str, topic: str, stream_type: str = "byte"): + return f"{fqcn}:{stream_type}:{channel}:{topic}" def stream_shutdown():