From 0205fd406ec36ba3a2f123fd508cd9a8f3a3c632 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 28 Jul 2021 15:49:26 -0700 Subject: [PATCH 1/9] Use `nbytes` attribute for `memoryview`s --- distributed/comm/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 0088ed6500..b3b25596d5 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -202,7 +202,7 @@ async def read(self, deserializers=None): 2, range(0, frames_nbytes + C_INT_MAX, C_INT_MAX) ): chunk = frames[i:j] - chunk_nbytes = len(chunk) + chunk_nbytes = chunk.nbytes n = await stream.read_into(chunk) assert n == chunk_nbytes, (n, chunk_nbytes) except StreamClosedError as e: From bf83431755d126dc237937bc9d5ed39f8c13e364 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 28 Jul 2021 15:49:27 -0700 Subject: [PATCH 2/9] Drop extra `memoryview` cast --- distributed/comm/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index b3b25596d5..7dfe78f5eb 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -274,7 +274,7 @@ async def write(self, msg, serializers=None, on_error="message"): if isinstance(each_frame, memoryview): # Make sure that `len(data) == data.nbytes` # See - each_frame = memoryview(each_frame).cast("B") + each_frame = each_frame.cast("B") stream._write_buffer.append(each_frame) stream._total_write_index += each_frame_nbytes From 1b974d0d048202da1f978849ba268d05ef9653b2 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 28 Jul 2021 15:49:28 -0700 Subject: [PATCH 3/9] Always cast frames to `memoryview`s --- distributed/comm/tcp.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 7dfe78f5eb..e984b00dfc 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -271,10 +271,11 @@ async def write(self, msg, serializers=None, 
on_error="message"): if stream._write_buffer is None: raise StreamClosedError() - if isinstance(each_frame, memoryview): - # Make sure that `len(data) == data.nbytes` - # See - each_frame = each_frame.cast("B") + each_frame = memoryview(each_frame) + + # Make sure that `len(data) == data.nbytes` + # See + each_frame = each_frame.cast("B") stream._write_buffer.append(each_frame) stream._total_write_index += each_frame_nbytes From 63b0c79342d402e96d0d48d8adf85fb96c16c8e9 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 28 Jul 2021 16:04:42 -0700 Subject: [PATCH 4/9] Write frames in chunks smaller than 2GB Works around the same OpenSSL issue already handled for reading, this time for writing. As individual frames may not be this large, this may be less of an issue. Still, this is a good preventative measure to protect users. --- distributed/comm/tcp.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index e984b00dfc..199f688fa8 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -277,8 +277,14 @@ async def write(self, msg, serializers=None, on_error="message"): # See each_frame = each_frame.cast("B") - stream._write_buffer.append(each_frame) - stream._total_write_index += each_frame_nbytes + # Workaround for OpenSSL 1.0.2 (can drop with OpenSSL 1.1.1) + for i, j in sliding_window( + 2, range(0, each_frame_nbytes + C_INT_MAX, C_INT_MAX) + ): + chunk = each_frame[i:j] + chunk_nbytes = chunk.nbytes + stream._write_buffer.append(chunk) + stream._total_write_index += chunk_nbytes # start writing frames stream.write(b"") From ae9df1844dd09663a818732a37c450c1d65ffe41 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 28 Jul 2021 16:07:55 -0700 Subject: [PATCH 5/9] Only send non-empty chunks --- distributed/comm/tcp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 199f688fa8..31d99cf5b2 100644 
--- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -283,8 +283,10 @@ async def write(self, msg, serializers=None, on_error="message"): ): chunk = each_frame[i:j] chunk_nbytes = chunk.nbytes - stream._write_buffer.append(chunk) - stream._total_write_index += chunk_nbytes + + if chunk_nbytes: + stream._write_buffer.append(chunk) + stream._total_write_index += chunk_nbytes # start writing frames stream.write(b"") From 4bc2eb174d7bcee9f684ffa0901bbecb3278491c Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 28 Jul 2021 16:09:03 -0700 Subject: [PATCH 6/9] Check stream is still open before each send --- distributed/comm/tcp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 31d99cf5b2..52775efc2c 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -268,9 +268,6 @@ async def write(self, msg, serializers=None, on_error="message"): # trick to enque all frames for writing beforehand for each_frame_nbytes, each_frame in zip(frames_nbytes, frames): if each_frame_nbytes: - if stream._write_buffer is None: - raise StreamClosedError() - each_frame = memoryview(each_frame) # Make sure that `len(data) == data.nbytes` @@ -285,6 +282,9 @@ async def write(self, msg, serializers=None, on_error="message"): chunk_nbytes = chunk.nbytes if chunk_nbytes: + if stream._write_buffer is None: + raise StreamClosedError() + stream._write_buffer.append(chunk) stream._total_write_index += chunk_nbytes From 445a38bb5ec2f0e1623d00154d1db3c79d3ac4b0 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 28 Jul 2021 16:19:25 -0700 Subject: [PATCH 7/9] Drop unneeded check All chunks will be of non-trivial size. If they are trivial, the loop would have already ended. 
--- distributed/comm/tcp.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 52775efc2c..2c09ec7103 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -281,12 +281,11 @@ async def write(self, msg, serializers=None, on_error="message"): chunk = each_frame[i:j] chunk_nbytes = chunk.nbytes - if chunk_nbytes: - if stream._write_buffer is None: - raise StreamClosedError() + if stream._write_buffer is None: + raise StreamClosedError() - stream._write_buffer.append(chunk) - stream._total_write_index += chunk_nbytes + stream._write_buffer.append(chunk) + stream._total_write_index += chunk_nbytes # start writing frames stream.write(b"") From c39339e90f01cf33d5de5d5799765140f8c6a630 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Fri, 10 Jun 2022 13:02:14 -0700 Subject: [PATCH 8/9] Apply suggestions from code review --- distributed/comm/tcp.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index eb207eee0b..df5c123667 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -299,16 +299,11 @@ async def write(self, msg, serializers=None, on_error="message"): # trick to enque all frames for writing beforehand for each_frame_nbytes, each_frame in zip(frames_nbytes, frames): if each_frame_nbytes: - each_frame = memoryview(each_frame) - - if isinstance(each_frame, memoryview): - # Make sure that `len(data) == data.nbytes` - # See - each_frame = ensure_memoryview(each_frame) - - # Workaround for OpenSSL 1.0.2 (can drop with OpenSSL 1.1.1) + # Make sure that `len(data) == data.nbytes` + # See + each_frame = ensure_memoryview(each_frame) for i, j in sliding_window( - 2, range(0, each_frame_nbytes + C_INT_MAX, C_INT_MAX) + 2, range(0, each_frame_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE) ): chunk = each_frame[i:j] chunk_nbytes = chunk.nbytes From 
4950d33916999eec40512f7a2c4b40dd533b922a Mon Sep 17 00:00:00 2001 From: Julia Signell Date: Wed, 22 Jun 2022 13:23:11 -0400 Subject: [PATCH 9/9] Lint --- distributed/comm/tcp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index df5c123667..10324aaaa6 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -303,7 +303,12 @@ async def write(self, msg, serializers=None, on_error="message"): # See each_frame = ensure_memoryview(each_frame) for i, j in sliding_window( - 2, range(0, each_frame_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE) + 2, + range( + 0, + each_frame_nbytes + OPENSSL_MAX_CHUNKSIZE, + OPENSSL_MAX_CHUNKSIZE, + ), ): chunk = each_frame[i:j] chunk_nbytes = chunk.nbytes