Skip to content

Commit

Permalink
More careful use of memoryviews
Browse files Browse the repository at this point in the history
This should fix some bugs where the code was assuming that bytes-like
objects were actually bytes|bytearray. It is also much more careful to
explicitly release memoryview wrappers rather than relying on the
garbage collector.

Specifically:
- Queue.write (and hence Socket.send) would function incorrectly given a
  buffer-protocol object whose element size was not byte or whose shape
  was not 1D.
- Socket.recv_into assumed the object supported slice assignment with
  byte indices, rather than using a memoryview to do the assignment.

Additionally, some checks were added on the nbytes parameter to
Socket.recv_into.
  • Loading branch information
bmerry committed Mar 23, 2024
1 parent 16cd8e2 commit 0addc36
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions src/async_solipsism/socket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 Bruce Merry
# Copyright 2020, 2024 Bruce Merry
#
# This file is part of async-solipsism.
#
Expand Down Expand Up @@ -30,6 +30,11 @@
__all__ = ('Socket', 'ListenSocket', 'Queue', 'socketpair')


def _view(bytes):
"""Get a 1D byte-typed memoryview of an object implementing the buffer protocol."""
return memoryview(bytes).cast("B")


class Queue:
def __init__(self, capacity=None):
self.capacity = capacity or DEFAULT_CAPACITY
Expand All @@ -50,11 +55,14 @@ def write(self, data):
raise BrokenPipeError(errno.EPIPE, 'Broken pipe')
if len(self) >= self.capacity:
return None
n = len(data)
if len(self) + n > self.capacity:
n = self.capacity - len(self)
data = memoryview(data)[:n]
self._buffer += data
with _view(data) as view:
n = len(view)
if len(self) + n > self.capacity:
n = self.capacity - len(self)
with view[:n] as prefix:
self._buffer += prefix
else:
self._buffer += view
return n

def read(self, size=-1):
Expand All @@ -68,7 +76,8 @@ def read(self, size=-1):
self._buffer = bytearray()
else:
n = min(size, len(self._buffer))
ret = bytes(memoryview(self._buffer))[:n]
with memoryview(self._buffer) as view, view[:n] as prefix:
ret = bytes(prefix)
self._buffer = self._buffer[n:]
return ret

Expand Down Expand Up @@ -164,11 +173,16 @@ def recv(self, bufsize, flags=0):

def recv_into(self, buffer, nbytes=0, flags=0):
# TODO: implement more efficiently?
if not nbytes:
nbytes = len(buffer)
data = self.recv(nbytes)
buffer[:len(data)] = data
return len(data)
with _view(buffer) as view:
if not nbytes:
nbytes = len(view)
if nbytes < 0:
raise ValueError("negative buffersize in recv_into")
if nbytes > len(view):
raise ValueError("buffer too small for requested bytes")
data = self.recv(nbytes)
view[:len(data)] = data
return len(data)

def send(self, bytes, flags=0):
self._check_closed()
Expand Down

0 comments on commit 0addc36

Please sign in to comment.