Skip to content

Commit

Permalink
Reduce default websocket frame size and make configurable (#5070)
Browse files Browse the repository at this point in the history
Supersedes #5052 . In addition to making the default websocket maximum-frame-size smaller, this makes the specific value configurable. It's somewhat redundant with distributed.comm.shard, but the constraints on websockets are sufficiently different that a separate config seems okay.

This does not implement the fix in #5061, as that would read a config value for every frame, which is costly. So the config value will in general not be changed after import time.
  • Loading branch information
ian-r-rose authored Jul 15, 2021
1 parent 942f235 commit 48ec3c5
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 10 deletions.
11 changes: 5 additions & 6 deletions distributed/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@


async def to_frames(
msg, serializers=None, on_error="message", context=None, allow_offload=True
msg,
allow_offload=True,
**kwargs,
):
"""
Serialize a message into a list of Distributed protocol frames.
Any kwargs are forwarded to protocol.dumps().
"""

def _to_frames():
try:
return list(
protocol.dumps(
msg, serializers=serializers, on_error=on_error, context=context
)
)
return list(protocol.dumps(msg, **kwargs))
except Exception as e:
logger.info("Unserializable Message: %s", msg)
logger.exception(e)
Expand Down
8 changes: 8 additions & 0 deletions distributed/comm/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from tornado.iostream import StreamClosedError
from tornado.websocket import WebSocketClosedError, WebSocketHandler, websocket_connect

import dask

from ..utils import ensure_bytes, nbytes
from .addressing import parse_host_port, unparse_host_port
from .core import Comm, CommClosedError, Connector, FatalCommClosedError, Listener
Expand All @@ -22,6 +24,11 @@
logger = logging.getLogger(__name__)


BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(
dask.config.get("distributed.comm.websockets.shard")
)


class WSHandler(WebSocketHandler):
def __init__(
self,
Expand Down Expand Up @@ -106,6 +113,7 @@ async def write(self, msg, serializers=None, on_error=None):
"recipient": self.remote_info,
**self.handshake_options,
},
frame_split_size=BIG_BYTES_SHARD_SIZE,
)
n = struct.pack("Q", len(frames))
try:
Expand Down
14 changes: 14 additions & 0 deletions distributed/distributed-schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,20 @@ properties:
Alternatively, the key can be appended to the cert file
above, and this field left blank
websockets:
type: object
properties:
shard:
type:
- string
description: |
The maximum size of a websocket frame to send through a comm.
This is somewhat duplicative of distributed.comm.shard, but websockets
often have much smaller maximum message sizes than othe protocols, so
this attribute is used to set a smaller default shard size and to
allow separate control of websocket message sharding.
diagnostics:
type: object
properties:
Expand Down
3 changes: 3 additions & 0 deletions distributed/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ distributed:
key: null
cert: null

websockets:
shard: 8MiB

diagnostics:
nvml: True

Expand Down
10 changes: 8 additions & 2 deletions distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
logger = logging.getLogger(__name__)


def dumps(msg, serializers=None, on_error="message", context=None) -> list:
def dumps(
msg, serializers=None, on_error="message", context=None, frame_split_size=None
) -> list:
"""Transform Python message to bytestream suitable for communication
Developer Notes
Expand Down Expand Up @@ -53,7 +55,11 @@ def _encode_default(obj):
sub_header, sub_frames = obj.header, obj.frames
else:
sub_header, sub_frames = serialize_and_split(
obj, serializers=serializers, on_error=on_error, context=context
obj,
serializers=serializers,
on_error=on_error,
context=context,
size=frame_split_size,
)
_inplace_compress_frames(sub_header, sub_frames)
sub_header["num-sub-frames"] = len(sub_frames)
Expand Down
6 changes: 4 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,9 @@ def deserialize(header, frames, deserializers=None):
return loads(header, frames)


def serialize_and_split(x, serializers=None, on_error="message", context=None):
def serialize_and_split(
x, serializers=None, on_error="message", context=None, size=None
):
"""Serialize and split compressable frames
This function is a drop-in replacement of `serialize()` that calls `serialize()`
Expand All @@ -428,7 +430,7 @@ def serialize_and_split(x, serializers=None, on_error="message", context=None):
frames, header.get("compression") or [None] * len(frames)
):
if compression is None: # default behavior
sub_frames = frame_split_size(frame)
sub_frames = frame_split_size(frame, n=size)
num_sub_frames.append(len(sub_frames))
offsets.append(len(out_frames))
out_frames.extend(sub_frames)
Expand Down
17 changes: 17 additions & 0 deletions distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
except ImportError:
np = None

import dask

from distributed import Nanny, wait
from distributed.comm.utils import from_frames, to_frames
from distributed.protocol import (
Expand Down Expand Up @@ -442,6 +444,21 @@ def _(x):
assert header["compression"] == [False, False]


@gen_test()
async def test_frame_split():
data = b"1234abcd" * (2 ** 20) # 8 MiB
assert dask.sizeof.sizeof(data) == dask.utils.parse_bytes("8MiB")

size = dask.utils.parse_bytes("3MiB")
split_frames = await to_frames({"x": to_serialize(data)}, frame_split_size=size)
print(split_frames)
assert len(split_frames) == 3 + 2 # Three splits and two headers

size = dask.utils.parse_bytes("5MiB")
split_frames = await to_frames({"x": to_serialize(data)}, frame_split_size=size)
assert len(split_frames) == 2 + 2 # Two splits and two headers


@pytest.mark.parametrize(
"data,is_serializable",
[
Expand Down
1 change: 1 addition & 0 deletions distributed/protocol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def frame_split_size(frame, n=BIG_BYTES_SHARD_SIZE) -> list:
>>> frame_split_size([b'12345', b'678'], n=3) # doctest: +SKIP
[b'123', b'45', b'678']
"""
n = n or BIG_BYTES_SHARD_SIZE
frame = memoryview(frame)

if frame.nbytes <= n:
Expand Down

0 comments on commit 48ec3c5

Please sign in to comment.