Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix StreamConsumer wakeup races #666

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,21 @@ impl NativeClient {
}
}

pub(crate) enum EventPollResult<T> {
None,
EventConsumed,
Event(T),
}

impl<T> From<EventPollResult<T>> for Option<T> {
fn from(val: EventPollResult<T>) -> Self {
match val {
EventPollResult::None | EventPollResult::EventConsumed => None,
EventPollResult::Event(evt) => Some(evt),
}
}
}

/// A low-level rdkafka client.
///
/// This type is the basis of the consumers and producers in the [`consumer`]
Expand Down Expand Up @@ -278,31 +293,42 @@ impl<C: ClientContext> Client<C> {
&self.context
}

pub(crate) fn poll_event(&self, queue: &NativeQueue, timeout: Timeout) -> Option<NativeEvent> {
pub(crate) fn poll_event(
&self,
queue: &NativeQueue,
timeout: Timeout,
) -> EventPollResult<NativeEvent> {
let event = unsafe { NativeEvent::from_ptr(queue.poll(timeout)) };
if let Some(ev) = event {
let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_LOG => self.handle_log_event(ev.ptr()),
rdsys::RD_KAFKA_EVENT_STATS => self.handle_stats_event(ev.ptr()),
rdsys::RD_KAFKA_EVENT_LOG => {
self.handle_log_event(ev.ptr());
return EventPollResult::EventConsumed;
}
rdsys::RD_KAFKA_EVENT_STATS => {
self.handle_stats_event(ev.ptr());
return EventPollResult::EventConsumed;
}
rdsys::RD_KAFKA_EVENT_ERROR => {
// rdkafka reports consumer errors via RD_KAFKA_EVENT_ERROR but producer errors gets
// embedded on the ack returned via RD_KAFKA_EVENT_DR. Hence we need to return this event
// for the consumer case in order to return the error to the user.
self.handle_error_event(ev.ptr());
return Some(ev);
return EventPollResult::Event(ev);
}
rdsys::RD_KAFKA_EVENT_OAUTHBEARER_TOKEN_REFRESH => {
if C::ENABLE_REFRESH_OAUTH_TOKEN {
self.handle_oauth_refresh_event(ev.ptr());
}
return EventPollResult::EventConsumed;
}
_ => {
return Some(ev);
return EventPollResult::Event(ev);
}
}
}
None
EventPollResult::None
}

fn handle_log_event(&self, event: *mut RDKafkaEvent) {
Expand Down
84 changes: 47 additions & 37 deletions src/consumer/base_consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use log::{error, warn};
use rdkafka_sys as rdsys;
use rdkafka_sys::types::*;

use crate::client::{Client, NativeClient, NativeQueue};
use crate::client::{Client, EventPollResult, NativeClient, NativeQueue};
use crate::config::{
ClientConfig, FromClientConfig, FromClientConfigAndContext, NativeClientConfig,
};
Expand Down Expand Up @@ -115,59 +115,69 @@ where
///
/// The returned message lives in the memory of the consumer and cannot outlive it.
pub fn poll<T: Into<Timeout>>(&self, timeout: T) -> Option<KafkaResult<BorrowedMessage<'_>>> {
self.poll_queue(self.get_queue(), timeout)
self.poll_queue(self.get_queue(), timeout).into()
}

pub(crate) fn poll_queue<T: Into<Timeout>>(
&self,
queue: &NativeQueue,
timeout: T,
) -> Option<KafkaResult<BorrowedMessage<'_>>> {
) -> EventPollResult<KafkaResult<BorrowedMessage<'_>>> {
let now = Instant::now();
let mut timeout = timeout.into();
let initial_timeout = timeout.into();
let mut timeout = initial_timeout;
let min_poll_interval = self.context().main_queue_min_poll_interval();
loop {
let op_timeout = std::cmp::min(timeout, min_poll_interval);
let maybe_event = self.client().poll_event(queue, op_timeout);
if let Some(event) = maybe_event {
let evtype = unsafe { rdsys::rd_kafka_event_type(event.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_FETCH => {
if let Some(result) = self.handle_fetch_event(event) {
return Some(result);
match maybe_event {
EventPollResult::Event(event) => {
let evtype = unsafe { rdsys::rd_kafka_event_type(event.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_FETCH => {
if let Some(result) = self.handle_fetch_event(event) {
return EventPollResult::Event(result);
}
}
}
rdsys::RD_KAFKA_EVENT_ERROR => {
if let Some(err) = self.handle_error_event(event) {
return Some(Err(err));
rdsys::RD_KAFKA_EVENT_ERROR => {
if let Some(err) = self.handle_error_event(event) {
return EventPollResult::Event(Err(err));
}
}
}
rdsys::RD_KAFKA_EVENT_REBALANCE => {
self.handle_rebalance_event(event);
if timeout != Timeout::Never {
return None;
rdsys::RD_KAFKA_EVENT_REBALANCE => {
self.handle_rebalance_event(event);
if timeout != Timeout::Never {
return EventPollResult::EventConsumed;
}
}
}
rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT => {
self.handle_offset_commit_event(event);
if timeout != Timeout::Never {
return None;
rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT => {
self.handle_offset_commit_event(event);
if timeout != Timeout::Never {
return EventPollResult::EventConsumed;
}
}
_ => {
let evname = unsafe {
let evname = rdsys::rd_kafka_event_name(event.ptr());
CStr::from_ptr(evname).to_string_lossy()
};
warn!("Ignored event '{evname}' on consumer poll");
}
}
_ => {
let evname = unsafe {
let evname = rdsys::rd_kafka_event_name(event.ptr());
CStr::from_ptr(evname).to_string_lossy()
};
warn!("Ignored event '{evname}' on consumer poll");
}
EventPollResult::None => {
timeout = initial_timeout.saturating_sub(now.elapsed());
if timeout.is_zero() {
return EventPollResult::None;
}
}
}

timeout = timeout.saturating_sub(now.elapsed());
if timeout.is_zero() {
return None;
}
EventPollResult::EventConsumed => {
timeout = initial_timeout.saturating_sub(now.elapsed());
if timeout.is_zero() {
return EventPollResult::EventConsumed;
}
}
};
}
}

Expand Down Expand Up @@ -802,7 +812,7 @@ where
/// associated consumer regularly, even if no messages are expected, to
/// serve events.
pub fn poll<T: Into<Timeout>>(&self, timeout: T) -> Option<KafkaResult<BorrowedMessage<'_>>> {
self.consumer.poll_queue(&self.queue, timeout)
self.consumer.poll_queue(&self.queue, timeout).into()
}

/// Sets a callback that will be invoked whenever the queue becomes
Expand Down
56 changes: 35 additions & 21 deletions src/consumer/stream_consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use slab::Slab;
use rdkafka_sys as rdsys;
use rdkafka_sys::types::*;

use crate::client::{Client, NativeQueue};
use crate::client::{Client, EventPollResult, NativeQueue};
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
use crate::consumer::base_consumer::{BaseConsumer, PartitionQueue};
use crate::consumer::{
Expand Down Expand Up @@ -122,11 +122,12 @@ impl<'a, C: ConsumerContext> MessageStream<'a, C> {
}
}

fn poll(&self) -> Option<KafkaResult<BorrowedMessage<'a>>> {
fn poll(&self) -> EventPollResult<KafkaResult<BorrowedMessage<'a>>> {
if let Some(queue) = self.partition_queue {
self.consumer.poll_queue(queue, Duration::ZERO)
} else {
self.consumer.poll(Duration::ZERO)
self.consumer
.poll_queue(self.consumer.get_queue(), Duration::ZERO)
}
}
}
Expand All @@ -135,25 +136,38 @@ impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> {
type Item = KafkaResult<BorrowedMessage<'a>>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// If there is a message ready, yield it immediately to avoid the
// taking the lock in `self.set_waker`.
if let Some(message) = self.poll() {
return Poll::Ready(Some(message));
}

// Otherwise, we need to wait for a message to become available. Store
// the waker so that we are woken up if the queue flips from non-empty
// to empty. We have to store the waker repatedly in case this future
// migrates between tasks.
self.wakers.set_waker(self.slot, cx.waker().clone());

// Check whether a new message became available after we installed the
// waker. This avoids a race where `poll` returns None to indicate that
// the queue is empty, but the queue becomes non-empty before we've
// installed the waker.
match self.poll() {
None => Poll::Pending,
Some(message) => Poll::Ready(Some(message)),
EventPollResult::Event(message) => {
// If there is a message ready, yield it immediately to avoid the
// taking the lock in `self.set_waker`.
Poll::Ready(Some(message))
}
EventPollResult::EventConsumed => {
// Event was consumed, yield to runtime
cx.waker().wake_by_ref();
Poll::Pending
}
EventPollResult::None => {
// Otherwise, we need to wait for a message to become available. Store
// the waker so that we are woken up if the queue flips from non-empty
// to empty. We have to store the waker repatedly in case this future
// migrates between tasks.
self.wakers.set_waker(self.slot, cx.waker().clone());

// Check whether a new message became available after we installed the
// waker. This avoids a race where `poll` returns None to indicate that
// the queue is empty, but the queue becomes non-empty before we've
// installed the waker.
match self.poll() {
EventPollResult::Event(message) => Poll::Ready(Some(message)),
EventPollResult::EventConsumed => {
// Event was consumed, yield to runtime
cx.waker().wake_by_ref();
Poll::Pending
}
EventPollResult::None => Poll::Pending,
}
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/producer/base_producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ use rdkafka_sys as rdsys;
use rdkafka_sys::rd_kafka_vtype_t::*;
use rdkafka_sys::types::*;

use crate::client::{Client, NativeQueue};
use crate::client::{Client, EventPollResult, NativeQueue};
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
use crate::consumer::ConsumerGroupMetadata;
use crate::error::{IsError, KafkaError, KafkaResult, RDKafkaError};
Expand Down Expand Up @@ -363,7 +363,7 @@ where
/// the message delivery callbacks.
pub fn poll<T: Into<Timeout>>(&self, timeout: T) {
let event = self.client().poll_event(&self.queue, timeout.into());
if let Some(ev) = event {
if let EventPollResult::Event(ev) = event {
let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_DR => self.handle_delivery_report_event(ev),
Expand Down
Loading