[EventHub] Fix race condition when buffered mode is enabled #34712

Merged 3 commits on Mar 21, 2024
2 changes: 2 additions & 0 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
@@ -8,6 +8,8 @@
 
 ### Bugs Fixed
 
+- Fixed a bug where using `EventHubProducerClient` in buffered mode could potentially drop a buffered message without actually sending it. ([#34712](https://github.com/Azure/azure-sdk-for-python/pull/34712))
+
 ### Other Changes
 
 - Updated network trace logging to replace `None` values in AMQP connection info with empty strings as per the OpenTelemetry specification.
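For context, buffered mode is the client option that routes sends through the buffered producer touched by this PR. Below is a minimal usage sketch, not taken from this PR; the connection string and hub name are placeholders, and the on_success/on_error callbacks are the ones buffered mode expects:

from azure.eventhub import EventHubProducerClient, EventData

def on_success(events, partition_id):
    # invoked after buffered events have actually been sent
    print(f"sent {len(events)} events to partition {partition_id}")

def on_error(events, partition_id, error):
    # invoked if buffered events could not be sent
    print(f"failed to send {len(events)} events: {error!r}")

producer = EventHubProducerClient.from_connection_string(
    "<CONNECTION_STRING>",             # placeholder
    eventhub_name="<EVENT_HUB_NAME>",  # placeholder
    buffered_mode=True,                # events are buffered per partition and sent in the background
    on_success=on_success,
    on_error=on_error,
)
with producer:
    producer.send_event(EventData("hello"))  # enqueued via the buffered producer's put_events
# closing the client flushes the buffer; before this fix a buffered event could be dropped on this path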
Changes to the synchronous buffered producer:
@@ -105,24 +105,22 @@ def put_events(self, events, timeout_time=None):
             raise OperationTimeoutError(
                 "Failed to enqueue events into buffer due to timeout."
             )
-        try:
-            # add single event into current batch
-            self._cur_batch.add(events)
-        except AttributeError:  # if the input events is a EventDataBatch, put the whole into the buffer
-            # if there are events in cur_batch, enqueue cur_batch to the buffer
-            with self._lock:
+        with self._lock:
+            try:
+                # add single event into current batch
+                self._cur_batch.add(events)
+            except AttributeError:  # if the input events is a EventDataBatch, put the whole into the buffer
+                # if there are events in cur_batch, enqueue cur_batch to the buffer
                 if self._cur_batch:
                     self._buffered_queue.put(self._cur_batch)
                 self._buffered_queue.put(events)
-            # create a new batch for incoming events
-            self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
-        except ValueError:
-            # add single event exceeds the cur batch size, create new batch
-            with self._lock:
+                # create a new batch for incoming events
+                self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
+            except ValueError:
+                # add single event exceeds the cur batch size, create new batch
                 self._buffered_queue.put(self._cur_batch)
-            self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
-            self._cur_batch.add(events)
-        with self._lock:
+                self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
+                self._cur_batch.add(events)
             self._cur_buffered_len += new_events_len

def failsafe_callback(self, callback):
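Why the lock had to widen: before this change, parts of put_events touched self._cur_batch outside the lock (the plain add on the fast path and the batch rollover after an enqueue), while flush and the max-wait-time worker enqueue and replace that same batch under the lock. An event added in that window can land in a batch object that has already been queued or sent and is then discarded, which matches the "dropped without sending" symptom in the changelog. A minimal, generic sketch of the two shapes, with illustrative names rather than the SDK's types:

import threading

class Buffer:
    # simplified stand-in for the per-partition buffered producer (illustrative, not SDK code)
    def __init__(self):
        self.lock = threading.Lock()
        self.queue = []      # stands in for self._buffered_queue
        self.cur_batch = []  # stands in for self._cur_batch

    def put_event_pre_fix(self, event):
        # old shape: the fast path appends to the current batch without holding the lock
        self.cur_batch.append(event)

    def put_event_post_fix(self, event):
        # new shape: every access to cur_batch happens under the same lock flush uses
        with self.lock:
            self.cur_batch.append(event)

    def flush(self):
        # flush enqueues the current batch and rolls it over; with the pre-fix shape,
        # an append running concurrently outside the lock can land in a batch that has
        # already been handed off (or already sent) and is about to be replaced, so the
        # event is never delivered
        with self.lock:
            if self.cur_batch:
                self.queue.append(self.cur_batch)
            self.cur_batch = []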
@@ -146,6 +144,7 @@ def flush(self, timeout_time=None, raise_error=True):
         _LOGGER.info("Partition: %r started flushing.", self.partition_id)
         if self._cur_batch:  # if there is batch, enqueue it to the buffer first
             self._buffered_queue.put(self._cur_batch)
+            self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
         while self._buffered_queue.qsize() > 0:
             remaining_time = timeout_time - time.time() if timeout_time else None
             if (remaining_time and remaining_time > 0) or remaining_time is None:
@@ -197,9 +196,6 @@
                 break
         # after finishing flushing, reset cur batch and put it into the buffer
         self._last_send_time = time.time()
-        #reset buffered count
-        self._cur_buffered_len = 0
-        self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
         _LOGGER.info("Partition %r finished flushing.", self.partition_id)

def check_max_wait_time_worker(self):
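The flush-side change follows the same reasoning: the current batch is now replaced immediately after it is enqueued, instead of only after the whole flush loop has finished, and the blanket reset of the buffered count at the end is dropped. Events that arrive while a long flush is still draining the queue therefore stay in a live _cur_batch rather than being wiped by an end-of-flush reset. A rough sketch of the before/after shapes, again with illustrative names rather than SDK code:

class FlushDemo:
    # simplified model of the partition buffer during flush
    def __init__(self):
        self.queue = []
        self.cur_batch = []

    def flush_pre_fix(self):
        # old shape: cur_batch is only swapped out once the loop is done,
        # so anything added to it during the loop is thrown away below
        if self.cur_batch:
            self.queue.append(self.cur_batch)
        while self.queue:
            self.queue.pop(0)   # hand batches to the sender
        self.cur_batch = []     # end-of-flush reset discards concurrently added events

    def flush_post_fix(self):
        # new shape: a fresh batch is swapped in as soon as the old one is enqueued;
        # there is no end-of-flush reset to clobber events added in the meantime
        if self.cur_batch:
            self.queue.append(self.cur_batch)
            self.cur_batch = []
        while self.queue:
            self.queue.pop(0)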
Changes to the asynchronous buffered producer:
@@ -105,24 +105,22 @@ async def put_events(self, events, timeout_time=None):
             raise OperationTimeoutError(
                 "Failed to enqueue events into buffer due to timeout."
             )
-        try:
-            # add single event into current batch
-            self._cur_batch.add(events)
-        except AttributeError:  # if the input events is a EventDataBatch, put the whole into the buffer
-            # if there are events in cur_batch, enqueue cur_batch to the buffer
-            async with self._lock:
+        async with self._lock:
+            try:
+                # add single event into current batch
+                self._cur_batch.add(events)
+            except AttributeError:  # if the input events is a EventDataBatch, put the whole into the buffer
+                # if there are events in cur_batch, enqueue cur_batch to the buffer
                 if self._cur_batch:
                     self._buffered_queue.put(self._cur_batch)
                 self._buffered_queue.put(events)
-            # create a new batch for incoming events
-            self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
-        except ValueError:
-            # add single event exceeds the cur batch size, create new batch
-            async with self._lock:
+                # create a new batch for incoming events
+                self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
+            except ValueError:
+                # add single event exceeds the cur batch size, create new batch
                 self._buffered_queue.put(self._cur_batch)
-            self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
-            self._cur_batch.add(events)
-        async with self._lock:
+                self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
+                self._cur_batch.add(events)
             self._cur_buffered_len += new_events_len

def failsafe_callback(self, callback):
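The asynchronous producer gets the identical treatment, with the single coarse lock taken via async with. A minimal sketch of the resulting shape, assuming an asyncio.Lock plays the role that threading.Lock plays in the synchronous sketch above (illustrative names, not SDK code):

import asyncio

class AsyncBuffer:
    def __init__(self):
        self.lock = asyncio.Lock()
        self.queue = []
        self.cur_batch = []

    async def put_event(self, event, max_batch=100):
        # post-fix shape: the add and the overflow rollover both run under one lock,
        # so a concurrent flush task cannot interleave between enqueue and reset
        async with self.lock:
            if len(self.cur_batch) >= max_batch:
                self.queue.append(self.cur_batch)  # enqueue the full batch
                self.cur_batch = []                # and start a fresh one
            self.cur_batch.append(event)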
@@ -200,9 +198,6 @@ async def _flush(self, timeout_time=None, raise_error=True):
                 break
         # after finishing flushing, reset cur batch and put it into the buffer
         self._last_send_time = time.time()
-        #reset curr_buffered
-        self._cur_buffered_len = 0
-        self._cur_batch = EventDataBatch(self._max_message_size_on_link, amqp_transport=self._amqp_transport)
         _LOGGER.info("Partition %r finished flushing.", self.partition_id)

async def check_max_wait_time_worker(self):