Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EH Pyproto] Async recv perf improvement #23122

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ async def _set_state(self, new_state):
self.state = new_state
_LOGGER.info("Connection '%s' state changed: %r -> %r", self._container_id, previous_state, new_state)

await asyncio.gather(*[session._on_connection_state_change() for session in self.outgoing_endpoints.values()])
for session in self.outgoing_endpoints.values():
await session._on_connection_state_change()

async def _connect(self):
try:
Expand Down Expand Up @@ -205,11 +206,11 @@ def _get_next_outgoing_channel(self):

async def _outgoing_empty(self):
if self.network_trace:
_LOGGER.info("<- empty()", extra=self.network_trace_params)
_LOGGER.info("-> empty()", extra=self.network_trace_params)
try:
if self._can_write():
await self.transport.write(EMPTY_FRAME)
self._last_frame_sent_time = time.time()
self.last_frame_sent_time = time.time()
except (OSError, IOError, SSLError, socket.error) as exc:
self._error = AMQPConnectionError(
ErrorCondition.SocketError,
Expand Down Expand Up @@ -451,11 +452,10 @@ async def listen(self, wait=False, batch=1, **kwargs):
)
return
try:
tasks = [asyncio.ensure_future(self._listen_one_frame(**kwargs)) for _ in range(batch)]
await asyncio.gather(*tasks)
for _ in range(batch):
await asyncio.ensure_future(self._listen_one_frame(**kwargs))
except ValueError:
for task in tasks:
task.cancel()
pass
annatisch marked this conversation as resolved.
Show resolved Hide resolved
except (OSError, IOError, SSLError, socket.error) as exc:
self._error = AMQPConnectionError(
ErrorCondition.SocketError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,7 @@ async def _read(self, toread, initial=False, buffer=None,
try:
while toread:
try:
# TODO: await self.reader.readexactly would not return until it has received something which
# is problematic in the case timeout is required while no frame coming in.
# asyncio.wait_for is used here for timeout control
# set socket timeout does not work, not triggering socket error maybe should be a different config?
# also we could consider using a low level socket instead of high level reader/writer
# https://docs.python.org/3/library/asyncio-eventloop.html
view[nbytes:nbytes + toread] = await asyncio.wait_for(self.reader.readexactly(toread), timeout=1)
view[nbytes:nbytes + toread] = await self.reader.readexactly(toread)
nbytes = toread
except asyncio.IncompleteReadError as exc:
pbytes = len(exc.partial)
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# Licensed under the MIT License.
# ------------------------------------

VERSION = "5.8.0b3"
VERSION = "5.8.0a3"
annatisch marked this conversation as resolved.
Show resolved Hide resolved
111 changes: 64 additions & 47 deletions sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def __init__(self, client: "EventHubConsumerClient", source: str, **kwargs) -> N
)
self._message_buffer = deque() # type: Deque[Message]
self._last_received_event = None # type: Optional[EventData]
self._message_buffer_lock = asyncio.Lock()
self._last_callback_called_time = None
self._callback_task_run = None

def _create_handler(self, auth: "JWTTokenAuthAsync") -> None:
source = Source(self._source, filters={})
Expand Down Expand Up @@ -162,7 +165,8 @@ async def _open_with_retry(self) -> None:
await self._do_retryable_operation(self._open, operation_need_param=False)

async def _message_received(self, message: Message) -> None:
self._message_buffer.append(message)
async with self._message_buffer_lock:
self._message_buffer.append(message)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the lock around this achieve?
We're no longer processing the incoming message batch concurrently with gather - so shouldn't that remove the possibility of a race condition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two coroutines interacting with self._message_buffer concurrently: the receive task keeps appending messages while the callback task keeps draining them. The race condition between these two tasks is what I'm trying to avoid here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could investigate this further in terms of optimization - it seems expensive to await the lock for every individual message.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it is interesting: in my v0 I actually did try to batch things, but I guess I didn't do it correctly, because the perf was not good.

I can put a todo for this


def _next_message_in_buffer(self):
# pylint:disable=protected-access
Expand All @@ -171,54 +175,67 @@ def _next_message_in_buffer(self):
self._last_received_event = event_data
return event_data

async def receive(self, batch=False, max_batch_size=300, max_wait_time=None) -> None:
async def _callback_coroutine(self, batch, max_batch_size, max_wait_time):
    """Drain the message buffer and hand events to the user callback.

    Runs until ``self._callback_task_run`` becomes falsy. Each iteration
    takes messages off ``self._message_buffer`` under
    ``self._message_buffer_lock``, converts them with
    ``EventData._from_message`` and awaits ``self._on_event_received``.

    :param batch: When truthy the callback receives a list of events,
     otherwise it receives a single event.
    :param max_batch_size: Upper bound on the number of events handed to
     the callback at once in batch mode.
    :param max_wait_time: When truthy and no callback has been made for
     longer than this many seconds, the callback is invoked with ``[]``
     (batch mode) or ``None`` (single-event mode).
    """
    # In single-event mode exactly one event is delivered per iteration, so
    # only one message may be taken off the buffer; taking more would drop
    # every event after the first.
    take_limit = max_batch_size if batch else 1
    while self._callback_task_run:
        async with self._message_buffer_lock:
            messages = [
                self._message_buffer.popleft()
                for _ in range(min(take_limit, len(self._message_buffer)))
            ]
        events = [EventData._from_message(message) for message in messages]
        now_time = time.time()
        if events:
            # The retry path reads _last_received_event to decide which
            # offset to resume from, so it has to follow what was dequeued.
            self._last_received_event = events[-1]
            await self._on_event_received(events if batch else events[0])
            self._last_callback_called_time = now_time
        else:
            if max_wait_time and (now_time - self._last_callback_called_time) > max_wait_time:
                # no events received, and need to callback
                await self._on_event_received([] if batch else None)
                self._last_callback_called_time = now_time
            # backoff a bit to avoid throttling CPU when no events are coming
            await asyncio.sleep(0.05)
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved

async def _receive_coroutine(self):
max_retries = (
self._client._config.max_retries # pylint:disable=protected-access
)
has_not_fetched_once = True # ensure one trip when max_wait_time is very small
deadline = time.time() + (max_wait_time or 0) # max_wait_time can be None
while len(self._message_buffer) < max_batch_size and \
(time.time() < deadline or has_not_fetched_once):
retried_times = 0
has_not_fetched_once = False
while retried_times <= max_retries:
try:
await self._open()
await cast(ReceiveClientAsync, self._handler).do_work_async(batch=self._prefetch)
break
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as exception: # pylint: disable=broad-except
if (
retried_times = 0
while retried_times <= max_retries:
annatisch marked this conversation as resolved.
Show resolved Hide resolved
try:
await self._open()
await cast(ReceiveClientAsync, self._handler).do_work_async(batch=self._prefetch)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as exception: # pylint: disable=broad-except
if (
isinstance(exception, error.AMQPLinkError)
and exception.condition == error.ErrorCondition.LinkStolen # pylint: disable=no-member
):
raise await self._handle_exception(exception)
if not self.running: # exit by close
return
if self._last_received_event:
self._offset = self._last_received_event.offset
last_exception = await self._handle_exception(exception)
retried_times += 1
if retried_times > max_retries:
_LOGGER.info(
"%r operation has exhausted retry. Last exception: %r.",
self._name,
last_exception,
)
raise last_exception

if self._message_buffer:
while self._message_buffer:
if batch:
events_for_callback = [] # type: List[EventData]
for _ in range(min(max_batch_size, len(self._message_buffer))):
events_for_callback.append(self._next_message_in_buffer())
await self._on_event_received(events_for_callback)
else:
await self._on_event_received(self._next_message_in_buffer())
elif max_wait_time:
if batch:
await self._on_event_received([])
else:
await self._on_event_received(None)
):
raise await self._handle_exception(exception)
if not self.running: # exit by close
return
if self._last_received_event:
self._offset = self._last_received_event.offset
last_exception = await self._handle_exception(exception)
annatisch marked this conversation as resolved.
Show resolved Hide resolved
retried_times += 1
if retried_times > max_retries:
_LOGGER.info(
"%r operation has exhausted retry. Last exception: %r.",
self._name,
last_exception,
)
raise last_exception

async def receive(self, batch=False, max_batch_size=300, max_wait_time=None) -> None:
    """Run the receive loop and the callback loop side by side.

    Starts ``self._receive_coroutine()`` and
    ``self._callback_coroutine(batch, max_batch_size, max_wait_time)`` as
    tasks and returns or raises once receiving is over.

    If the callback task fails while the receive task is still running, the
    receive task is cancelled and the callback's exception is raised right
    away, so messages are not buffered with nothing draining them. An
    exception from the receive task takes precedence over one raised by the
    callback task during shutdown.

    :param batch: Passed through to ``_callback_coroutine``.
    :param max_batch_size: Passed through to ``_callback_coroutine``.
    :param max_wait_time: Passed through to ``_callback_coroutine``.
    """
    self._callback_task_run = True
    self._last_callback_called_time = time.time()
    callback_task = asyncio.ensure_future(
        self._callback_coroutine(batch, max_batch_size, max_wait_time)
    )
    receive_task = asyncio.ensure_future(self._receive_coroutine())

    try:
        await asyncio.wait(
            [callback_task, receive_task], return_when=asyncio.FIRST_COMPLETED
        )
        callback_crashed = (
            callback_task.done()
            and not callback_task.cancelled()
            and callback_task.exception() is not None
        )
        if not callback_crashed:
            # Normal path: receiving decides when this call ends.
            await receive_task
    finally:
        self._callback_task_run = False
        # asyncio.wait does not cancel its awaitables when this coroutine is
        # cancelled, so the receive task has to be stopped explicitly.
        if not receive_task.done():
            receive_task.cancel()
        # Let both tasks wind down without replacing an in-flight exception.
        await asyncio.gather(receive_task, callback_task, return_exceptions=True)
    # Surface a callback failure (early crash or an error during shutdown).
    callback_task.result()
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,9 @@ async def on_event_received(event):
await consumer._handler.do_work_async()
assert consumer._handler._connection.state == constants.ConnectionState.END

duration = 10
now_time = time.time()
end_time = now_time + duration

while now_time < end_time:
await consumer.receive()
await asyncio.sleep(0.01)
now_time = time.time()
try:
await asyncio.wait_for(consumer.receive(), timeout=10)
except asyncio.TimeoutError:
pass

assert on_event_received.event.body_as_str() == "Event"