refactor(barrier): not cache actor failure in local barrier worker (#…
wenym1 authored Nov 13, 2024
1 parent c21a771 commit 0ddb6eb
Showing 1 changed file with 12 additions and 52 deletions.
64 changes: 12 additions & 52 deletions src/stream/src/task/barrier_manager.rs
@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use std::collections::{BTreeSet, HashMap};
+use std::collections::BTreeSet;
 use std::fmt::Display;
 use std::future::pending;
+use std::iter::once;
 use std::sync::Arc;
 use std::time::Duration;

@@ -260,9 +261,6 @@ pub(super) struct LocalBarrierWorker {
     /// Current barrier collection state.
     pub(super) state: ManagedBarrierState,

-    /// Record all unexpected exited actors.
-    failure_actors: HashMap<ActorId, StreamError>,
-
     control_stream_handle: ControlStreamHandle,

     pub(super) actor_manager: Arc<StreamActorManager>,
@@ -272,9 +270,6 @@ pub(super) struct LocalBarrierWorker {
     barrier_event_rx: UnboundedReceiver<LocalBarrierEvent>,

     actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>,
-
-    /// Cached result of [`Self::try_find_root_failure`].
-    cached_root_failure: Option<ScoredStreamError>,
 }

 impl LocalBarrierWorker {
@@ -289,14 +284,12 @@ impl LocalBarrierWorker {
             },
         ));
         Self {
-            failure_actors: HashMap::default(),
             state: ManagedBarrierState::new(actor_manager.clone(), shared_context.clone()),
             control_stream_handle: ControlStreamHandle::empty(),
             actor_manager,
             current_shared_context: shared_context,
             barrier_event_rx: event_rx,
             actor_failure_rx: failure_rx,
-            cached_root_failure: None,
         }
     }

@@ -543,19 +536,6 @@ impl LocalBarrierWorker {
             request.actor_ids_to_collect
         );

-        for actor_id in &request.actor_ids_to_collect {
-            if self.failure_actors.contains_key(actor_id) {
-                // The failure actors could exit before the barrier is issued, while their
-                // up-downstream actors could be stuck somehow. Return error directly to trigger the
-                // recovery.
-                return Err(StreamError::barrier_send(
-                    barrier.clone(),
-                    *actor_id,
-                    "actor has already failed",
-                ));
-            }
-        }
-
         self.state.transform_to_issued(barrier, request)?;
         Ok(())
     }
@@ -596,8 +576,7 @@ impl LocalBarrierWorker {
         err: StreamError,
         err_context: &'static str,
     ) {
-        self.add_failure(actor_id, err.clone());
-        let root_err = self.try_find_root_failure().await.unwrap(); // always `Some` because we just added one
+        let root_err = self.try_find_root_failure(err).await;

         if let Some(actor_state) = self.state.actor_states.get(&actor_id)
             && (!actor_state.inflight_barriers.is_empty() || actor_state.is_running())
@@ -616,10 +595,7 @@
     /// This is similar to [`Self::notify_actor_failure`], but since there's not always an actor failure,
     /// the given `err` will be used if there's no root failure found.
     async fn notify_other_failure(&mut self, err: StreamError, message: impl Into<String>) {
-        let root_err = self
-            .try_find_root_failure()
-            .await
-            .unwrap_or_else(|| ScoredStreamError::new(err));
+        let root_err = self.try_find_root_failure(err).await;

         self.control_stream_handle.reset_stream_with_err(
             anyhow!(root_err)
@@ -628,40 +604,24 @@
         );
     }

-    fn add_failure(&mut self, actor_id: ActorId, err: StreamError) {
-        if let Some(prev_err) = self.failure_actors.insert(actor_id, err) {
-            warn!(
-                actor_id,
-                prev_err = %prev_err.as_report(),
-                "actor error overwritten"
-            );
-        }
-    }
-
     /// Collect actor errors for a while and find the one that might be the root cause.
     ///
-    /// Returns `None` if there's no actor error received.
-    async fn try_find_root_failure(&mut self) -> Option<ScoredStreamError> {
-        if self.cached_root_failure.is_some() {
-            return self.cached_root_failure.clone();
-        }
-
+    async fn try_find_root_failure(&mut self, first_err: StreamError) -> ScoredStreamError {
+        let mut later_errs = vec![];
         // fetch more actor errors within a timeout
         let _ = tokio::time::timeout(Duration::from_secs(3), async {
-            while let Some((actor_id, error)) = self.actor_failure_rx.recv().await {
-                self.add_failure(actor_id, error);
+            while let Some((_, error)) = self.actor_failure_rx.recv().await {
+                later_errs.push(error);
             }
         })
         .await;

         // Find the error with highest score.
-        self.cached_root_failure = self
-            .failure_actors
-            .values()
+        once(first_err)
+            .chain(later_errs.into_iter())
             .map(|e| ScoredStreamError::new(e.clone()))
-            .max_by_key(|e| e.score);
-
-        self.cached_root_failure.clone()
+            .max_by_key(|e| e.score)
+            .expect("non-empty")
     }
 }


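Below is a minimal, self-contained sketch of the drain-then-score pattern that the refactored try_find_root_failure follows: take the error that triggered the report, gather any further actor errors that arrive within a bounded window, and return the highest-scored one. Everything in the sketch is an illustrative stand-in, not the worker's actual implementation: ScoredError, its score field, and find_root_failure are hypothetical names approximating StreamError/ScoredStreamError, and std::sync::mpsc with recv_timeout replaces the tokio unbounded channel and tokio::time::timeout used in barrier_manager.rs.

use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};

/// Hypothetical stand-in for `StreamError` / `ScoredStreamError`.
#[derive(Debug)]
struct ScoredError {
    message: String,
    /// Higher score = more likely to be the root cause.
    score: i32,
}

/// Drain errors that arrive within `window`, then return the highest-scored one.
/// `first_err` guarantees the result is non-empty, so no caching or `Option` is needed.
fn find_root_failure(
    first_err: ScoredError,
    rx: &mpsc::Receiver<ScoredError>,
    window: Duration,
) -> ScoredError {
    let deadline = Instant::now() + window;
    let mut errors = vec![first_err];

    // std stand-in for `tokio::time::timeout` wrapped around a receive loop.
    while let Some(remaining) = deadline.checked_duration_since(Instant::now()) {
        match rx.recv_timeout(remaining) {
            Ok(err) => errors.push(err),
            Err(_) => break, // window elapsed or all senders dropped
        }
    }

    errors
        .into_iter()
        .max_by_key(|e| e.score)
        .expect("non-empty: first_err is always included")
}

fn main() {
    let (tx, rx) = mpsc::channel();

    // Simulate follow-up actor failures arriving shortly after the first one.
    thread::spawn(move || {
        tx.send(ScoredError { message: "exchange channel closed".into(), score: 1 }).unwrap();
        tx.send(ScoredError { message: "source reader panicked".into(), score: 10 }).unwrap();
    });

    let first = ScoredError { message: "actor exited unexpectedly".into(), score: 5 };
    let root = find_root_failure(first, &rx, Duration::from_millis(200));
    println!("root cause: {} (score {})", root.message, root.score);
}

Because the caller always passes in the triggering error, the result is guaranteed to be non-empty. That is what lets the diff above drop the failure_actors map, the cached_root_failure field, the Option/unwrap handling around the old return value, and the pre-issue check that rejected barriers for actors that had already failed.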