Skip to content

Commit

Permalink
pw_transfer: Tidy up bazel proto imports
Browse files Browse the repository at this point in the history
This change switches from relying on `copy_file` to using try/except to
import generated protos from the location bazel expects to find them.
This is less brittle: we're actually providing the generated code as a
dependency rather than tricking the build system into treating it as
opaque "data" that the Python interpreter will hopefully find at
runtime. In particular, it should actually work internally.

See the associated bug for a discussion of other options. This one has
the smallest blast radius.

Bug: 642
Bug: b/232310150
Change-Id: I575bf96dd06346a1a024c08d2aef0914563a9776
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/94264
Reviewed-by: Anthony DiGirolamo <[email protected]>
Commit-Queue: Ted Pudlik <[email protected]>
  • Loading branch information
tpudlik authored and CQ Bot Account committed May 12, 2022
1 parent 1d8f780 commit af32de5
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 149 deletions.
3 changes: 1 addition & 2 deletions pw_transfer/integration_test/python_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from google.protobuf import text_format
from pw_hdlc.rpc import HdlcRpcClient, default_channels
import pw_transfer
from pw_transfer import transfer_pb2

from pigweed.pw_transfer import transfer_pb2
from pigweed.pw_transfer.integration_test import config_pb2

_LOG = logging.getLogger('pw_transfer_integration_test_python_client')
Expand Down
15 changes: 0 additions & 15 deletions pw_transfer/py/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# License for the specific language governing permissions and limitations under
# the License.

load("@bazel_skylib//rules:copy_file.bzl", "copy_file")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])
Expand All @@ -25,9 +23,6 @@ py_library(
"pw_transfer/client.py",
"pw_transfer/transfer.py",
],
data = [
":copy_transfer_pb2",
],
imports = ["."],
deps = [
"//pw_rpc/py:pw_rpc",
Expand All @@ -36,22 +31,12 @@ py_library(
],
)

copy_file(
name = "copy_transfer_pb2",
src = "//pw_transfer:transfer_proto_pb2",
out = "pw_transfer/transfer_pb2.py",
allow_symlink = True,
)

py_test(
name = "transfer_test",
size = "small",
srcs = [
"tests/transfer_test.py",
],
data = [
":copy_transfer_pb2",
],
deps = [
":pw_transfer",
"//pw_rpc/py:pw_rpc",
Expand Down
13 changes: 9 additions & 4 deletions pw_transfer/py/pw_transfer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@

from pw_transfer.transfer import (ProgressCallback, ReadTransfer, Transfer,
WriteTransfer)
from pw_transfer.transfer_pb2 import Chunk
try:
from pw_transfer import transfer_pb2
except ImportError:
# For the bazel build, which puts generated protos in a different location.
from pigweed.pw_transfer import transfer_pb2 # type: ignore

_LOG = logging.getLogger(__package__)

Expand Down Expand Up @@ -156,11 +160,11 @@ def write(self,
if not transfer.status.ok():
raise Error(transfer.id, transfer.status)

def _send_read_chunk(self, chunk: Chunk) -> None:
def _send_read_chunk(self, chunk: transfer_pb2.Chunk) -> None:
assert self._read_stream is not None
self._read_stream.send(chunk)

def _send_write_chunk(self, chunk: Chunk) -> None:
def _send_write_chunk(self, chunk: transfer_pb2.Chunk) -> None:
assert self._write_stream is not None
self._write_stream.send(chunk)

Expand Down Expand Up @@ -216,7 +220,8 @@ async def _transfer_event_loop(self):
self._loop.stop()

@staticmethod
async def _handle_chunk(transfers: _TransferDict, chunk: Chunk) -> None:
async def _handle_chunk(transfers: _TransferDict,
chunk: transfer_pb2.Chunk) -> None:
"""Processes an incoming chunk from a stream.
The chunk is dispatched to an active transfer based on its ID. If the
Expand Down
92 changes: 52 additions & 40 deletions pw_transfer/py/pw_transfer/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from typing import Any, Callable, Optional

from pw_status import Status
from pw_transfer.transfer_pb2 import Chunk
try:
from pw_transfer import transfer_pb2
except ImportError:
# For the bazel build, which puts generated protos in a different location.
from pigweed.pw_transfer import transfer_pb2 # type: ignore

_LOG = logging.getLogger(__package__)

Expand Down Expand Up @@ -88,7 +92,7 @@ class Transfer(abc.ABC):
"""
def __init__(self,
session_id: int,
send_chunk: Callable[[Chunk], None],
send_chunk: Callable[[transfer_pb2.Chunk], None],
end_transfer: Callable[['Transfer'], None],
response_timeout_s: float,
initial_response_timeout_s: float,
Expand Down Expand Up @@ -119,10 +123,10 @@ def data(self) -> bytes:
"""Returns the data read or written in this transfer."""

@abc.abstractmethod
def _initial_chunk(self) -> Chunk:
def _initial_chunk(self) -> transfer_pb2.Chunk:
"""Returns the initial chunk to notify the sever of the transfer."""

async def handle_chunk(self, chunk: Chunk) -> None:
async def handle_chunk(self, chunk: transfer_pb2.Chunk) -> None:
"""Processes an incoming chunk from the server.
Handles terminating chunks (i.e. those with a status) and forwards
Expand All @@ -145,7 +149,7 @@ async def handle_chunk(self, chunk: Chunk) -> None:
self._response_timer.start()

@abc.abstractmethod
async def _handle_data_chunk(self, chunk: Chunk) -> None:
async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
"""Handles a chunk that contains or requests data."""

@abc.abstractmethod
Expand Down Expand Up @@ -197,9 +201,10 @@ def _update_progress(self, bytes_sent: int, bytes_confirmed_received: int,
def _send_error(self, error: Status) -> None:
"""Sends an error chunk to the server and finishes the transfer."""
self._send_chunk(
Chunk(session_id=self.id,
status=error.value,
type=Chunk.Type.TRANSFER_COMPLETION))
transfer_pb2.Chunk(
session_id=self.id,
status=error.value,
type=transfer_pb2.Chunk.Type.TRANSFER_COMPLETION))
self.finish(error)


Expand All @@ -209,7 +214,7 @@ def __init__(
self,
session_id: int,
data: bytes,
send_chunk: Callable[[Chunk], None],
send_chunk: Callable[[transfer_pb2.Chunk], None],
end_transfer: Callable[[Transfer], None],
response_timeout_s: float,
initial_response_timeout_s: float,
Expand Down Expand Up @@ -238,14 +243,14 @@ def __init__(
def data(self) -> bytes:
return self._data

def _initial_chunk(self) -> Chunk:
def _initial_chunk(self) -> transfer_pb2.Chunk:
# TODO(frolv): session_id should not be set here, but assigned by the
# server during an initial handshake.
return Chunk(session_id=self.id,
resource_id=self.id,
type=Chunk.Type.TRANSFER_START)
return transfer_pb2.Chunk(session_id=self.id,
resource_id=self.id,
type=transfer_pb2.Chunk.Type.TRANSFER_START)

async def _handle_data_chunk(self, chunk: Chunk) -> None:
async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
"""Processes an incoming chunk from the server.
In a write transfer, the server only sends transfer parameter updates
Expand Down Expand Up @@ -288,13 +293,14 @@ async def _handle_data_chunk(self, chunk: Chunk) -> None:

self._last_chunk = write_chunk

def _handle_parameters_update(self, chunk: Chunk) -> bool:
def _handle_parameters_update(self, chunk: transfer_pb2.Chunk) -> bool:
"""Updates transfer state based on a transfer parameters update."""

retransmit = True
if chunk.HasField('type'):
retransmit = (chunk.type == Chunk.Type.PARAMETERS_RETRANSMIT
or chunk.type == Chunk.Type.TRANSFER_START)
retransmit = (
chunk.type == transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT
or chunk.type == transfer_pb2.Chunk.Type.TRANSFER_START)

if chunk.offset > len(self.data):
# Bad offset; terminate the transfer.
Expand Down Expand Up @@ -328,7 +334,7 @@ def _handle_parameters_update(self, chunk: Chunk) -> bool:
len(self.data) - self._offset)
self._window_end_offset = self._offset + max_bytes_to_send
else:
assert chunk.type == Chunk.Type.PARAMETERS_CONTINUE
assert chunk.type == transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE

# Extend the window to the new end offset specified by the server.
self._window_end_offset = min(chunk.window_end_offset,
Expand All @@ -345,11 +351,11 @@ def _handle_parameters_update(self, chunk: Chunk) -> bool:
def _retry_after_timeout(self) -> None:
self._send_chunk(self._last_chunk)

def _next_chunk(self) -> Chunk:
def _next_chunk(self) -> transfer_pb2.Chunk:
"""Returns the next Chunk message to send in the data transfer."""
chunk = Chunk(session_id=self.id,
offset=self._offset,
type=Chunk.Type.TRANSFER_DATA)
chunk = transfer_pb2.Chunk(session_id=self.id,
offset=self._offset,
type=transfer_pb2.Chunk.Type.TRANSFER_DATA)
max_bytes_in_chunk = min(self._max_chunk_size,
self._window_end_offset - self._offset)

Expand Down Expand Up @@ -382,7 +388,7 @@ class ReadTransfer(Transfer):
def __init__( # pylint: disable=too-many-arguments
self,
session_id: int,
send_chunk: Callable[[Chunk], None],
send_chunk: Callable[[transfer_pb2.Chunk], None],
end_transfer: Callable[[Transfer], None],
response_timeout_s: float,
initial_response_timeout_s: float,
Expand All @@ -409,10 +415,11 @@ def data(self) -> bytes:
"""Returns an immutable copy of the data that has been read."""
return bytes(self._data)

def _initial_chunk(self) -> Chunk:
return self._transfer_parameters(Chunk.Type.TRANSFER_START)
def _initial_chunk(self) -> transfer_pb2.Chunk:
return self._transfer_parameters(
transfer_pb2.Chunk.Type.TRANSFER_START)

async def _handle_data_chunk(self, chunk: Chunk) -> None:
async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
"""Processes an incoming chunk from the server.
In a read transfer, the client receives data chunks from the server.
Expand All @@ -424,7 +431,8 @@ async def _handle_data_chunk(self, chunk: Chunk) -> None:
# If data is received out of order, request that the server
# retransmit from the previous offset.
self._send_chunk(
self._transfer_parameters(Chunk.Type.PARAMETERS_RETRANSMIT))
self._transfer_parameters(
transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
return

self._data += chunk.data
Expand All @@ -435,9 +443,10 @@ async def _handle_data_chunk(self, chunk: Chunk) -> None:
if chunk.remaining_bytes == 0:
# No more data to read. Acknowledge receipt and finish.
self._send_chunk(
Chunk(session_id=self.id,
status=Status.OK.value,
type=Chunk.Type.TRANSFER_COMPLETION))
transfer_pb2.Chunk(
session_id=self.id,
status=Status.OK.value,
type=transfer_pb2.Chunk.Type.TRANSFER_COMPLETION))
self.finish(Status.OK)
return

Expand Down Expand Up @@ -483,27 +492,30 @@ async def _handle_data_chunk(self, chunk: Chunk) -> None:
# All pending data was received. Send out a new parameters chunk for
# the next block.
self._send_chunk(
self._transfer_parameters(Chunk.Type.PARAMETERS_RETRANSMIT))
self._transfer_parameters(
transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
elif extend_window:
self._send_chunk(
self._transfer_parameters(Chunk.Type.PARAMETERS_CONTINUE))
self._transfer_parameters(
transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE))

def _retry_after_timeout(self) -> None:
self._send_chunk(
self._transfer_parameters(Chunk.Type.PARAMETERS_RETRANSMIT))
self._transfer_parameters(
transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))

def _transfer_parameters(self, chunk_type: Any) -> Chunk:
def _transfer_parameters(self, chunk_type: Any) -> transfer_pb2.Chunk:
"""Sends an updated transfer parameters chunk to the server."""

self._pending_bytes = self._max_bytes_to_receive
self._window_end_offset = self._offset + self._max_bytes_to_receive

chunk = Chunk(session_id=self.id,
pending_bytes=self._pending_bytes,
window_end_offset=self._window_end_offset,
max_chunk_size_bytes=self._max_chunk_size,
offset=self._offset,
type=chunk_type)
chunk = transfer_pb2.Chunk(session_id=self.id,
pending_bytes=self._pending_bytes,
window_end_offset=self._window_end_offset,
max_chunk_size_bytes=self._max_chunk_size,
offset=self._offset,
type=chunk_type)

if self._chunk_delay_us:
chunk.min_delay_microseconds = self._chunk_delay_us
Expand Down
Loading

0 comments on commit af32de5

Please sign in to comment.