From 48ec3c5795c3f1bac74bb1ce908a310d0cdfa66b Mon Sep 17 00:00:00 2001 From: Ian Rose Date: Thu, 15 Jul 2021 16:44:10 +0200 Subject: [PATCH] Reduce default websocket frame size and make configurable (#5070) Supersedes #5052 . In addition to making the default websocket maximum-frame-size smaller, this makes the specific value configurable. It's somewhat redundant with distributed.comm.shard, but the constraints on websockets are sufficiently different that a separate config seems okay. This does not implement the fix in #5061, as that would read a config value for every frame, which is costly. So the config value will in general not be changed after import time. --- distributed/comm/utils.py | 11 +++++------ distributed/comm/ws.py | 8 ++++++++ distributed/distributed-schema.yaml | 14 ++++++++++++++ distributed/distributed.yaml | 3 +++ distributed/protocol/core.py | 10 ++++++++-- distributed/protocol/serialize.py | 6 ++++-- distributed/protocol/tests/test_serialize.py | 17 +++++++++++++++++ distributed/protocol/utils.py | 1 + 8 files changed, 60 insertions(+), 10 deletions(-) diff --git a/distributed/comm/utils.py b/distributed/comm/utils.py index 5301265caf5..0ce4f8f891b 100644 --- a/distributed/comm/utils.py +++ b/distributed/comm/utils.py @@ -19,19 +19,18 @@ async def to_frames( - msg, serializers=None, on_error="message", context=None, allow_offload=True + msg, + allow_offload=True, + **kwargs, ): """ Serialize a message into a list of Distributed protocol frames. + Any kwargs are forwarded to protocol.dumps(). """ def _to_frames(): try: - return list( - protocol.dumps( - msg, serializers=serializers, on_error=on_error, context=context - ) - ) + return list(protocol.dumps(msg, **kwargs)) except Exception as e: logger.info("Unserializable Message: %s", msg) logger.exception(e) diff --git a/distributed/comm/ws.py b/distributed/comm/ws.py index be7679513aa..8a64f209dbd 100644 --- a/distributed/comm/ws.py +++ b/distributed/comm/ws.py @@ -12,6 +12,8 @@ from tornado.iostream import StreamClosedError from tornado.websocket import WebSocketClosedError, WebSocketHandler, websocket_connect +import dask + from ..utils import ensure_bytes, nbytes from .addressing import parse_host_port, unparse_host_port from .core import Comm, CommClosedError, Connector, FatalCommClosedError, Listener @@ -22,6 +24,11 @@ logger = logging.getLogger(__name__) +BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes( + dask.config.get("distributed.comm.websockets.shard") +) + + class WSHandler(WebSocketHandler): def __init__( self, @@ -106,6 +113,7 @@ async def write(self, msg, serializers=None, on_error=None): "recipient": self.remote_info, **self.handshake_options, }, + frame_split_size=BIG_BYTES_SHARD_SIZE, ) n = struct.pack("Q", len(frames)) try: diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 440a39fb2f0..f0beee814e3 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -759,6 +759,20 @@ properties: Alternatively, the key can be appended to the cert file above, and this field left blank + websockets: + type: object + properties: + shard: + type: + - string + description: | + The maximum size of a websocket frame to send through a comm. + + This is somewhat duplicative of distributed.comm.shard, but websockets + often have much smaller maximum message sizes than othe protocols, so + this attribute is used to set a smaller default shard size and to + allow separate control of websocket message sharding. + diagnostics: type: object properties: diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 28ca8cf7bf5..86e2e5489d6 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -196,6 +196,9 @@ distributed: key: null cert: null + websockets: + shard: 8MiB + diagnostics: nvml: True diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index c4bd909ecb3..1be2d761e35 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -16,7 +16,9 @@ logger = logging.getLogger(__name__) -def dumps(msg, serializers=None, on_error="message", context=None) -> list: +def dumps( + msg, serializers=None, on_error="message", context=None, frame_split_size=None +) -> list: """Transform Python message to bytestream suitable for communication Developer Notes @@ -53,7 +55,11 @@ def _encode_default(obj): sub_header, sub_frames = obj.header, obj.frames else: sub_header, sub_frames = serialize_and_split( - obj, serializers=serializers, on_error=on_error, context=context + obj, + serializers=serializers, + on_error=on_error, + context=context, + size=frame_split_size, ) _inplace_compress_frames(sub_header, sub_frames) sub_header["num-sub-frames"] = len(sub_frames) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index c88d36a8995..51815677f5a 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -406,7 +406,9 @@ def deserialize(header, frames, deserializers=None): return loads(header, frames) -def serialize_and_split(x, serializers=None, on_error="message", context=None): +def serialize_and_split( + x, serializers=None, on_error="message", context=None, size=None +): """Serialize and split compressable frames This function is a drop-in replacement of `serialize()` that calls `serialize()` @@ -428,7 +430,7 @@ def serialize_and_split(x, serializers=None, on_error="message", context=None): frames, header.get("compression") or [None] * len(frames) ): if compression is None: # default behavior - sub_frames = frame_split_size(frame) + sub_frames = frame_split_size(frame, n=size) num_sub_frames.append(len(sub_frames)) offsets.append(len(out_frames)) out_frames.extend(sub_frames) diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index d946b01496b..36359304830 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -11,6 +11,8 @@ except ImportError: np = None +import dask + from distributed import Nanny, wait from distributed.comm.utils import from_frames, to_frames from distributed.protocol import ( @@ -442,6 +444,21 @@ def _(x): assert header["compression"] == [False, False] +@gen_test() +async def test_frame_split(): + data = b"1234abcd" * (2 ** 20) # 8 MiB + assert dask.sizeof.sizeof(data) == dask.utils.parse_bytes("8MiB") + + size = dask.utils.parse_bytes("3MiB") + split_frames = await to_frames({"x": to_serialize(data)}, frame_split_size=size) + print(split_frames) + assert len(split_frames) == 3 + 2 # Three splits and two headers + + size = dask.utils.parse_bytes("5MiB") + split_frames = await to_frames({"x": to_serialize(data)}, frame_split_size=size) + assert len(split_frames) == 2 + 2 # Two splits and two headers + + @pytest.mark.parametrize( "data,is_serializable", [ diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 3f5a2f8f500..cf4f1815ea5 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -25,6 +25,7 @@ def frame_split_size(frame, n=BIG_BYTES_SHARD_SIZE) -> list: >>> frame_split_size([b'12345', b'678'], n=3) # doctest: +SKIP [b'123', b'45', b'678'] """ + n = n or BIG_BYTES_SHARD_SIZE frame = memoryview(frame) if frame.nbytes <= n: