Skip to content

Commit

Permalink
pw_transfer: Improve Python stream reopening and closing
Browse files Browse the repository at this point in the history
This makes several changes to the way RPC streams are handled in the
Python transfer client:

- Limits RPC stream reopen attempts to a maximum number, after which
  ongoing transfers should fail.
- Refactors stream operations into a wrapper class to minimize duplicate
  code between read/write transfers.
- Improves logging around stream reopening.
- Automatically close RPC streams when the last transfer running on them
  completes.

Tested: Verified successful operation of multiple back-to-back transfers
on real hardware.

Change-Id: Ie4b3a9faacce2f2916840c9e1e9aec2cd03d6f41
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/184931
Presubmit-Verified: CQ Bot Account <[email protected]>
Reviewed-by: Jordan Brauer <[email protected]>
Commit-Queue: Alexei Frolov <[email protected]>
  • Loading branch information
frolv authored and CQ Bot Account committed Dec 21, 2023
1 parent 2ace3ee commit baed4c8
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 84 deletions.
176 changes: 97 additions & 79 deletions pw_transfer/py/pw_transfer/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Pigweed Authors
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
Expand All @@ -17,7 +17,7 @@
import ctypes
import logging
import threading
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from pw_rpc.callback_client import BidirectionalStreamingCall
from pw_status import Status
Expand All @@ -42,6 +42,68 @@
_TransferDict = Dict[int, Transfer]


class _TransferStream:
    """Wrapper around a single bidirectional RPC stream used by transfers.

    Holds the active call object (if any), forwards received chunks and
    unrecoverable stream errors to the handlers supplied at construction, and
    re-invokes the RPC a bounded number of consecutive times when the stream
    reports FAILED_PRECONDITION.
    """

    def __init__(
        self,
        method,
        chunk_handler: Callable[[Chunk], Any],
        error_handler: Callable[[Status], Any],
        max_reopen_attempts=3,
    ):
        """Creates the wrapper in the closed state.

        Args:
          method: RPC method object; its invoke() is used to start the stream.
          chunk_handler: Called with each chunk received on the stream.
          error_handler: Called with a status when the stream cannot continue.
          max_reopen_attempts: Number of consecutive reopen attempts allowed
            before the stream gives up and reports an error.
        """
        self._method = method
        self._chunk_handler = chunk_handler
        self._error_handler = error_handler
        self._call: Optional[BidirectionalStreamingCall] = None
        # Consecutive reopen attempts; cleared whenever a chunk arrives.
        self._reopen_attempts = 0
        self._max_reopen_attempts = max_reopen_attempts

    def is_open(self) -> bool:
        """Returns True while the wrapper holds a call object."""
        return self._call is not None

    def open(self, force: bool = False) -> None:
        """Invokes the RPC if no call is held, or unconditionally if forced.

        Args:
          force: When True, a new call is started even if one is already
            held, replacing the previous call object.
        """
        if force or self._call is None:
            self._call = self._method.invoke(
                lambda _, chunk: self._on_chunk_received(chunk),
                on_error=lambda _, status: self._on_stream_error(status),
            )

    def close(self) -> None:
        """Cancels and forgets the held call; does nothing if already closed."""
        if self._call is not None:
            self._call.cancel()
            self._call = None

    def send(self, chunk: Chunk) -> None:
        """Sends a chunk over the stream; the stream must be open."""
        # Internal invariant: callers open the stream before sending.
        assert self._call is not None
        self._call.send(chunk.to_message())

    def _on_chunk_received(self, chunk: Chunk) -> None:
        """Passes a received chunk to the chunk handler."""
        # Any received chunk shows the stream is working, so the reopen
        # budget starts over.
        self._reopen_attempts = 0
        self._chunk_handler(chunk)

    def _on_stream_error(self, rpc_status: Status) -> None:
        """Handles an RPC error reported on the stream.

        FAILED_PRECONDITION triggers a reopen until the attempt limit is
        exceeded, at which point the error handler receives UNAVAILABLE. Any
        other status is passed to the error handler unchanged. In both
        failure cases the held call is dropped first.
        """
        if rpc_status is Status.FAILED_PRECONDITION:
            # FAILED_PRECONDITION indicates that the stream packet was not
            # recognized as the stream is not open. Attempt to re-open the
            # stream to allow pending transfers to continue.
            self._reopen_attempts += 1
            if self._reopen_attempts > self._max_reopen_attempts:
                _LOG.error(
                    'Failed to reopen transfer stream after %d tries',
                    self._max_reopen_attempts,
                )
                # Giving up: drop the dead call and reset the counter so
                # that is_open() reports False and a later open() starts a
                # fresh stream with a full reopen budget, instead of
                # silently reusing a call that is known to be unusable.
                self._call = None
                self._reopen_attempts = 0
                self._error_handler(Status.UNAVAILABLE)
            else:
                _LOG.info(
                    'Transfer stream failed to write; attempting to re-open'
                )
                self.open(force=True)
        else:
            # Other errors are unrecoverable; clear the stream.
            _LOG.error('Transfer stream shut down with status %s', rpc_status)
            self._call = None
            self._error_handler(rpc_status)


class Manager: # pylint: disable=too-many-instance-attributes
"""A manager for transmitting data through an RPC TransferService.
Expand Down Expand Up @@ -86,11 +148,6 @@ def __init__(
self._write_transfers: _TransferDict = {}
self._next_session_id = ctypes.c_uint32(1)

# RPC streams for read and write transfers. These are shareable by
# multiple transfers of the same type.
self._read_stream: Optional[BidirectionalStreamingCall] = None
self._write_stream: Optional[BidirectionalStreamingCall] = None

self._loop = asyncio.new_event_loop()
# Set the event loop for the current thread.
asyncio.set_event_loop(self._loop)
Expand All @@ -106,6 +163,23 @@ def __init__(
target=self._start_event_loop_thread, daemon=True
)

# RPC streams for read and write transfers. These are shareable by
# multiple transfers of the same type.
self._read_stream = _TransferStream(
self._service.Read,
lambda chunk: self._loop.call_soon_threadsafe(
self._read_chunk_queue.put_nowait, chunk
),
self._on_read_error,
)
self._write_stream = _TransferStream(
self._service.Write,
lambda chunk: self._loop.call_soon_threadsafe(
self._write_chunk_queue.put_nowait, chunk
),
self._on_write_error,
)

self._thread.start()

def __del__(self):
Expand Down Expand Up @@ -158,7 +232,7 @@ def read(
transfer = ReadTransfer(
session_id,
resource_id,
self._send_read_chunk,
self._read_stream.send,
self._end_read_transfer,
chunk_timeout_s,
initial_timeout_s,
Expand Down Expand Up @@ -225,7 +299,7 @@ def write(
session_id,
resource_id,
data,
self._send_write_chunk,
self._write_stream.send,
self._end_write_transfer,
chunk_timeout_s,
initial_timeout_s,
Expand All @@ -241,14 +315,6 @@ def write(
if not transfer.status.ok():
raise Error(transfer.resource_id, transfer.status)

def _send_read_chunk(self, chunk: Chunk) -> None:
assert self._read_stream is not None
self._read_stream.send(chunk.to_message())

def _send_write_chunk(self, chunk: Chunk) -> None:
assert self._write_stream is not None
self._write_stream.send(chunk.to_message())

def assign_session_id(self) -> int:
new_id = self._next_session_id.value

Expand Down Expand Up @@ -352,71 +418,29 @@ async def _handle_chunk(

await transfer.handle_chunk(chunk)

def _open_read_stream(self) -> None:
self._read_stream = self._service.Read.invoke(
lambda _, chunk: self._loop.call_soon_threadsafe(
self._read_chunk_queue.put_nowait, chunk
),
on_error=lambda _, status: self._on_read_error(status),
)

def _on_read_error(self, status: Status) -> None:
"""Callback for an RPC error in the read stream."""

if status is Status.FAILED_PRECONDITION:
# FAILED_PRECONDITION indicates that the stream packet was not
# recognized as the stream is not open. This could occur if the
# server resets during an active transfer. Re-open the stream to
# allow pending transfers to continue.
self._open_read_stream()
else:
# Other errors are unrecoverable. Clear the stream and cancel any
# pending transfers with an INTERNAL status as this is a system
# error.
self._read_stream = None

for transfer in self._read_transfers.values():
transfer.finish(Status.INTERNAL, skip_callback=True)
self._read_transfers.clear()

_LOG.error('Read stream shut down: %s', status)
for transfer in self._read_transfers.values():
transfer.finish(Status.INTERNAL, skip_callback=True)
self._read_transfers.clear()

def _open_write_stream(self) -> None:
self._write_stream = self._service.Write.invoke(
lambda _, chunk: self._loop.call_soon_threadsafe(
self._write_chunk_queue.put_nowait, chunk
),
on_error=lambda _, status: self._on_write_error(status),
)
_LOG.error('Read stream shut down: %s', status)

def _on_write_error(self, status: Status) -> None:
"""Callback for an RPC error in the write stream."""

if status is Status.FAILED_PRECONDITION:
# FAILED_PRECONDITION indicates that the stream packet was not
# recognized as the stream is not open. This could occur if the
# server resets during an active transfer. Re-open the stream to
# allow pending transfers to continue.
self._open_write_stream()
else:
# Other errors are unrecoverable. Clear the stream and cancel any
# pending transfers with an INTERNAL status as this is a system
# error.
self._write_stream = None

for transfer in self._write_transfers.values():
transfer.finish(Status.INTERNAL, skip_callback=True)
self._write_transfers.clear()
for transfer in self._write_transfers.values():
transfer.finish(Status.INTERNAL, skip_callback=True)
self._write_transfers.clear()

_LOG.error('Write stream shut down: %s', status)
_LOG.error('Write stream shut down: %s', status)

def _start_read_transfer(self, transfer: Transfer) -> None:
"""Begins a new read transfer, opening the stream if it isn't."""

self._read_transfers[transfer.resource_id] = transfer

if not self._read_stream:
self._open_read_stream()
self._read_stream.open()

_LOG.debug('Starting new read transfer %d', transfer.id)
self._loop.call_soon_threadsafe(
Expand All @@ -434,19 +458,15 @@ def _end_read_transfer(self, transfer: Transfer) -> None:
transfer.status,
)

# TODO(frolv): This doesn't seem to work. Investigate why.
# If no more transfers are using the read stream, close it.
# if not self._read_transfers and self._read_stream:
# self._read_stream.cancel()
# self._read_stream = None
if not self._read_transfers:
self._read_stream.close()

def _start_write_transfer(self, transfer: Transfer) -> None:
"""Begins a new write transfer, opening the stream if it isn't."""

self._write_transfers[transfer.resource_id] = transfer

if not self._write_stream:
self._open_write_stream()
self._write_stream.open()

_LOG.debug('Starting new write transfer %d', transfer.id)
self._loop.call_soon_threadsafe(
Expand All @@ -464,11 +484,9 @@ def _end_write_transfer(self, transfer: Transfer) -> None:
transfer.status,
)

# TODO(frolv): This doesn't seem to work. Investigate why.
# If no more transfers are using the write stream, close it.
# if not self._write_transfers and self._write_stream:
# self._write_stream.cancel()
# self._write_stream = None
if not self._write_transfers:
self._write_stream.close()


class Error(Exception):
Expand Down
66 changes: 61 additions & 5 deletions pw_transfer/py/tests/transfer_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
Expand Down Expand Up @@ -61,7 +61,7 @@ def setUp(self) -> None:
self._service = self._client.channel(1).rpcs.pw.transfer.Transfer

self._sent_chunks: List[transfer_pb2.Chunk] = []
self._packets_to_send: List[List[bytes]] = []
self._packets_to_send: List[List[packet_pb2.RpcPacket]] = []

def _enqueue_server_responses(
self, method: _Method, responses: Iterable[Iterable[transfer_pb2.Chunk]]
Expand All @@ -77,7 +77,7 @@ def _enqueue_server_responses(
method_id=method.value,
status=Status.OK.value,
payload=response.SerializeToString(),
).SerializeToString()
)
)
self._packets_to_send.append(serialized_group)

Expand All @@ -90,7 +90,7 @@ def _enqueue_server_error(self, method: _Method, error: Status) -> None:
service_id=_TRANSFER_SERVICE_ID,
method_id=method.value,
status=error.value,
).SerializeToString()
)
]
)

Expand All @@ -106,7 +106,8 @@ def _handle_request(self, data: bytes) -> None:
if self._packets_to_send:
responses = self._packets_to_send.pop(0)
for response in responses:
self._client.process_packet(response)
response.call_id = packet.call_id
self._client.process_packet(response.SerializeToString())

def _received_data(self) -> bytearray:
data = bytearray()
Expand Down Expand Up @@ -401,6 +402,61 @@ def test_read_transfer_error(self) -> None:
self.assertEqual(exception.resource_id, 31)
self.assertEqual(exception.status, Status.NOT_FOUND)

def test_read_transfer_reopen(self) -> None:
    """A single FAILED_PRECONDITION is followed by a successful read."""
    # Server script: one stream error, then the only data chunk.
    self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION)
    final_chunk = transfer_pb2.Chunk(
        transfer_id=3,
        offset=0,
        data=b'xyz',
        remaining_bytes=0,
    )
    self._enqueue_server_responses(_Method.READ, ((final_chunk,),))

    manager = pw_transfer.Manager(
        self._service,
        initial_response_timeout_s=DEFAULT_TIMEOUT_S,
        default_response_timeout_s=DEFAULT_TIMEOUT_S,
    )

    # After the reopen the transfer runs to completion; the opening chunk
    # is sent twice (once before the error, once after).
    received = manager.read(3)
    self.assertEqual(received, b'xyz')

    sent = self._sent_chunks
    self.assertEqual(len(sent), 3)
    self.assertEqual(sent[0], sent[1])

    last_sent = sent[-1]
    self.assertTrue(last_sent.HasField('status'))
    self.assertEqual(last_sent.status, 0)

def test_read_transfer_reopen_max_attempts(self) -> None:
    """Repeated FAILED_PRECONDITION errors eventually fail the read."""
    # Queue five consecutive stream errors; each FAILED_PRECONDITION should
    # prompt a reopen attempt.
    for _ in range(5):
        self._enqueue_server_error(_Method.READ, Status.FAILED_PRECONDITION)

    manager = pw_transfer.Manager(
        self._service,
        initial_response_timeout_s=DEFAULT_TIMEOUT_S,
        default_response_timeout_s=DEFAULT_TIMEOUT_S,
    )

    with self.assertRaises(pw_transfer.Error) as context:
        manager.read(81)

    error = context.exception
    self.assertEqual(len(self._sent_chunks), 4)
    self.assertEqual(error.resource_id, 81)
    self.assertEqual(error.status, Status.INTERNAL)

def test_read_transfer_server_error(self) -> None:
manager = pw_transfer.Manager(
self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S
Expand Down

0 comments on commit baed4c8

Please sign in to comment.