diff --git a/proto/stream_service.proto b/proto/stream_service.proto
index 62cc8746aeca1..070f029cb1d31 100644
--- a/proto/stream_service.proto
+++ b/proto/stream_service.proto
@@ -12,9 +12,10 @@ option optimize_for = SPEED;
 message InjectBarrierRequest {
   string request_id = 1;
   stream_plan.Barrier barrier = 2;
+  uint32 database_id = 3;
   repeated uint32 actor_ids_to_collect = 4;
   repeated uint32 table_ids_to_sync = 5;
-  uint64 partial_graph_id = 6;
+  uint32 partial_graph_id = 6;
 
   repeated common.ActorInfo broadcast_info = 8;
   repeated stream_plan.StreamActor actors_to_build = 9;
@@ -48,9 +49,10 @@ message BarrierCompleteResponse {
   uint32 worker_id = 5;
   map<uint32, hummock.TableWatermarks> table_watermarks = 6;
   repeated hummock.SstableInfo old_value_sstables = 7;
-  uint64 partial_graph_id = 8;
+  uint32 partial_graph_id = 8;
   // prev_epoch of barrier
   uint64 epoch = 9;
+  uint32 database_id = 10;
 }
 
 message WaitEpochCommitRequest {
@@ -64,20 +66,27 @@ message WaitEpochCommitResponse {
 
 message StreamingControlStreamRequest {
   message InitialPartialGraph {
-    uint64 partial_graph_id = 1;
+    uint32 partial_graph_id = 1;
     repeated stream_plan.SubscriptionUpstreamInfo subscriptions = 2;
   }
 
+  message DatabaseInitialPartialGraph {
+    uint32 database_id = 1;
+    repeated InitialPartialGraph graphs = 2;
+  }
+
   message InitRequest {
-    repeated InitialPartialGraph graphs = 1;
+    repeated DatabaseInitialPartialGraph databases = 1;
   }
 
   message CreatePartialGraphRequest {
-    uint64 partial_graph_id = 1;
+    uint32 partial_graph_id = 1;
+    uint32 database_id = 2;
   }
 
   message RemovePartialGraphRequest {
-    repeated uint64 partial_graph_ids = 1;
+    repeated uint32 partial_graph_ids = 1;
+    uint32 database_id = 2;
   }
 
   oneof request {
diff --git a/proto/task_service.proto b/proto/task_service.proto
index 121d189c923df..cb14ee809d943 100644
--- a/proto/task_service.proto
+++ b/proto/task_service.proto
@@ -95,6 +95,7 @@ message GetStreamRequest {
     uint32 down_actor_id = 2;
     uint32 up_fragment_id = 3;
     uint32 down_fragment_id = 4;
+    uint32 database_id = 5;
   }
 
   oneof value {
diff --git a/src/common/src/catalog/mod.rs b/src/common/src/catalog/mod.rs
index 1fbabdfe57771..d64881e22c41a 100644
--- a/src/common/src/catalog/mod.rs
+++ b/src/common/src/catalog/mod.rs
@@ -177,7 +177,7 @@ pub struct DatabaseId {
 }
 
 impl DatabaseId {
-    pub fn new(database_id: u32) -> Self {
+    pub const fn new(database_id: u32) -> Self {
         DatabaseId { database_id }
     }
diff --git a/src/compute/src/rpc/service/exchange_service.rs b/src/compute/src/rpc/service/exchange_service.rs
index e4082a88ea9e6..7e76099edc3ab 100644
--- a/src/compute/src/rpc/service/exchange_service.rs
+++ b/src/compute/src/rpc/service/exchange_service.rs
@@ -19,6 +19,7 @@ use either::Either;
 use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
 use futures_async_stream::try_stream;
 use risingwave_batch::task::BatchManager;
+use risingwave_common::catalog::DatabaseId;
 use risingwave_pb::task_service::exchange_service_server::ExchangeService;
 use risingwave_pb::task_service::{
     permits, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits,
@@ -93,6 +94,7 @@ impl ExchangeService for ExchangeServiceImpl {
             down_actor_id,
             up_fragment_id,
             down_fragment_id,
+            database_id,
         } = {
             let req = request_stream
                 .next()
@@ -106,7 +108,7 @@ impl ExchangeService for ExchangeServiceImpl {
 
         let receiver = self
             .stream_mgr
-            .take_receiver((up_actor_id, down_actor_id))
+            .take_receiver(DatabaseId::new(database_id), (up_actor_id, down_actor_id))
             .await?;
 
         // Map the remaining stream to add-permits.
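Editor's note: the proto hunks above stop encoding the database inside `partial_graph_id`. Below is a minimal sketch, not part of the patch, of the legacy u64 packing that is being retired (`pack_legacy` is a hypothetical name). After the change, the same information travels as two independent `uint32` proto fields.

```rust
// Hypothetical helper demonstrating the retired encoding: database id in
// the high 32 bits, job id in the low 32 bits of a single u64.
fn pack_legacy(database_id: u32, job_id: u32) -> u64 {
    ((database_id as u64) << u32::BITS) | job_id as u64
}

fn main() {
    let packed = pack_legacy(7, 42);
    // After the change, the two halves are simply two plain u32 fields
    // (`database_id` and `partial_graph_id`) instead of one packed u64.
    let (database_id, partial_graph_id) = ((packed >> u32::BITS) as u32, packed as u32);
    assert_eq!((database_id, partial_graph_id), (7, 42));
}
```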
diff --git a/src/meta/src/barrier/checkpoint/control.rs b/src/meta/src/barrier/checkpoint/control.rs
index 1d7eef6f81b5e..62375e0fac1af 100644
--- a/src/meta/src/barrier/checkpoint/control.rs
+++ b/src/meta/src/barrier/checkpoint/control.rs
@@ -89,7 +89,7 @@ impl CheckpointControl {
         resp: BarrierCompleteResponse,
         control_stream_manager: &mut ControlStreamManager,
     ) -> MetaResult<()> {
-        let database_id = from_partial_graph_id(resp.partial_graph_id).0;
+        let database_id = DatabaseId::new(resp.database_id);
         self.databases
             .get_mut(&database_id)
             .expect("should exist")
@@ -435,8 +435,7 @@ impl DatabaseCheckpointControl {
             partial_graph_id = resp.partial_graph_id,
             "barrier collected"
         );
-        let (database_id, creating_job_id) = from_partial_graph_id(resp.partial_graph_id);
-        assert_eq!(database_id, self.database_id);
+        let creating_job_id = from_partial_graph_id(resp.partial_graph_id);
         match creating_job_id {
             None => {
                 if let Some(node) = self.command_ctx_queue.get_mut(&prev_epoch) {
diff --git a/src/meta/src/barrier/rpc.rs b/src/meta/src/barrier/rpc.rs
index dfb9f1cc13d37..ff349156d2ed7 100644
--- a/src/meta/src/barrier/rpc.rs
+++ b/src/meta/src/barrier/rpc.rs
@@ -31,7 +31,8 @@ use risingwave_pb::meta::PausedReason;
 use risingwave_pb::stream_plan::barrier_mutation::Mutation;
 use risingwave_pb::stream_plan::{Barrier, BarrierMutation, StreamActor, SubscriptionUpstreamInfo};
 use risingwave_pb::stream_service::streaming_control_stream_request::{
-    CreatePartialGraphRequest, PbInitRequest, PbInitialPartialGraph, RemovePartialGraphRequest,
+    CreatePartialGraphRequest, PbDatabaseInitialPartialGraph, PbInitRequest, PbInitialPartialGraph,
+    RemovePartialGraphRequest,
 };
 use risingwave_pb::stream_service::{
     streaming_control_stream_request, streaming_control_stream_response, BarrierCompleteResponse,
@@ -54,25 +55,21 @@ use crate::{MetaError, MetaResult};
 
 const COLLECT_ERROR_TIMEOUT: Duration = Duration::from_secs(3);
 
-fn to_partial_graph_id(database_id: DatabaseId, job_id: Option<TableId>) -> u64 {
-    ((database_id.database_id as u64) << u32::BITS)
-        | (job_id
-            .map(|table| {
-                assert_ne!(table.table_id, u32::MAX);
-                table.table_id
-            })
-            .unwrap_or(u32::MAX) as u64)
+fn to_partial_graph_id(job_id: Option<TableId>) -> u32 {
+    job_id
+        .map(|table| {
+            assert_ne!(table.table_id, u32::MAX);
+            table.table_id
+        })
+        .unwrap_or(u32::MAX)
 }
 
-pub(super) fn from_partial_graph_id(partial_graph_id: u64) -> (DatabaseId, Option<TableId>) {
-    let database_id = DatabaseId::new((partial_graph_id >> u32::BITS) as u32);
-    let job_id = (partial_graph_id & (u32::MAX as u64)) as u32;
-    let job_id = if job_id == u32::MAX {
+pub(super) fn from_partial_graph_id(partial_graph_id: u32) -> Option<TableId> {
+    if partial_graph_id == u32::MAX {
         None
     } else {
-        Some(TableId::new(job_id))
-    };
-    (database_id, job_id)
+        Some(TableId::new(partial_graph_id))
+    }
 }
 
 struct ControlStreamNode {
@@ -272,10 +269,13 @@ impl ControlStreamManager {
         initial_subscriptions: impl Iterator,
     ) -> PbInitRequest {
         PbInitRequest {
-            graphs: initial_subscriptions
-                .map(|(database_id, info)| PbInitialPartialGraph {
-                    partial_graph_id: to_partial_graph_id(database_id, None),
-                    subscriptions: info.into_iter().collect_vec(),
+            databases: initial_subscriptions
+                .map(|(database_id, info)| PbDatabaseInitialPartialGraph {
+                    database_id: database_id.database_id,
+                    graphs: vec![PbInitialPartialGraph {
+                        partial_graph_id: to_partial_graph_id(None),
+                        subscriptions: info.into_iter().collect_vec(),
+                    }],
                 })
                 .collect(),
         }
@@ -335,7 +335,7 @@ impl ControlStreamManager {
             "inject_barrier_err"
         ));
 
-        let partial_graph_id = to_partial_graph_id(database_id, creating_table_id);
+        let partial_graph_id = to_partial_graph_id(creating_table_id);
 
         let node_actors = InflightFragmentInfo::actor_ids_to_collect(pre_applied_graph_info);
@@ -399,6 +399,7 @@ impl ControlStreamManager {
                     InjectBarrierRequest {
                         request_id: Uuid::new_v4().to_string(),
                         barrier: Some(barrier),
+                        database_id: database_id.database_id,
                         actor_ids_to_collect,
                         table_ids_to_sync: table_ids_to_sync
                             .iter()
@@ -451,14 +452,17 @@ impl ControlStreamManager {
         database_id: DatabaseId,
         creating_job_id: Option<TableId>,
     ) -> MetaResult<()> {
-        let partial_graph_id = to_partial_graph_id(database_id, creating_job_id);
+        let partial_graph_id = to_partial_graph_id(creating_job_id);
         self.nodes.iter().try_for_each(|(_, node)| {
             node.handle
                 .request_sender
                 .send(StreamingControlStreamRequest {
                     request: Some(
                         streaming_control_stream_request::Request::CreatePartialGraph(
-                            CreatePartialGraphRequest { partial_graph_id },
+                            CreatePartialGraphRequest {
+                                database_id: database_id.database_id,
+                                partial_graph_id,
+                            },
                         ),
                     ),
                 })
@@ -477,7 +481,7 @@ impl ControlStreamManager {
         }
         let partial_graph_ids = creating_job_ids
             .into_iter()
-            .map(|job_id| to_partial_graph_id(database_id, Some(job_id)))
+            .map(|job_id| to_partial_graph_id(Some(job_id)))
             .collect_vec();
         self.nodes.iter().for_each(|(_, node)| {
             if node.handle
@@ -487,6 +491,7 @@ impl ControlStreamManager {
                         streaming_control_stream_request::Request::RemovePartialGraph(
                             RemovePartialGraphRequest {
                                 partial_graph_ids: partial_graph_ids.clone(),
+                                database_id: database_id.database_id,
                             },
                         ),
                     ),
@@ -567,34 +572,3 @@ pub(super) fn merge_node_rpc_errors(
     });
     anyhow!(concat).into()
 }
-
-#[cfg(test)]
-mod tests {
-    use risingwave_common::catalog::{DatabaseId, TableId};
-
-    use crate::barrier::rpc::{from_partial_graph_id, to_partial_graph_id};
-
-    #[test]
-    fn test_partial_graph_id_convert() {
-        fn test_convert(database_id: u32, job_id: Option<u32>) {
-            let database_id = DatabaseId::new(database_id);
-            let job_id = job_id.map(TableId::new);
-            assert_eq!(
-                (database_id, job_id),
-                from_partial_graph_id(to_partial_graph_id(database_id, job_id))
-            );
-        }
-        for database_id in [0, 1, 2, u32::MAX - 1, u32::MAX >> 1] {
-            for job_id in [
-                Some(0),
-                Some(1),
-                Some(2),
-                None,
-                Some(u32::MAX >> 1),
-                Some(u32::MAX - 1),
-            ] {
-                test_convert(database_id, job_id);
-            }
-        }
-    }
-}
diff --git a/src/rpc_client/src/compute_client.rs b/src/rpc_client/src/compute_client.rs
index 5ff5671cb3320..b92afe9035f1a 100644
--- a/src/rpc_client/src/compute_client.rs
+++ b/src/rpc_client/src/compute_client.rs
@@ -17,6 +17,7 @@ use std::time::Duration;
 
 use async_trait::async_trait;
 use futures::StreamExt;
+use risingwave_common::catalog::DatabaseId;
 use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, STREAM_WINDOW_SIZE};
 use risingwave_common::monitor::{EndpointExt, TcpConfig};
 use risingwave_common::util::addr::HostAddr;
@@ -115,6 +116,7 @@ impl ComputeClient {
         down_actor_id: u32,
         up_fragment_id: u32,
         down_fragment_id: u32,
+        database_id: DatabaseId,
     ) -> Result<(
         Streaming<GetStreamResponse>,
         mpsc::UnboundedSender<permits::Permits>,
     )> {
@@ -132,6 +134,7 @@ impl ComputeClient {
                     down_actor_id,
                     up_fragment_id,
                     down_fragment_id,
+                    database_id: database_id.database_id,
                 })),
             },
         ))
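Editor's note: with the database id gone from the packing, the graph-id mapping in rpc.rs reduces to a sentinel scheme, which is why the round-trip unit test above could be deleted. A minimal sketch (bare `u32` in place of `TableId`) that also checks the round-trip property the removed test used to cover:

```rust
// Mirrors to/from_partial_graph_id after the patch: a partial graph id is
// just the creating job's table id, with u32::MAX reserved as the sentinel
// for a database's root graph ("no creating job").
fn to_partial_graph_id(job_id: Option<u32>) -> u32 {
    job_id
        .inspect(|id| assert_ne!(*id, u32::MAX)) // MAX is reserved
        .unwrap_or(u32::MAX)
}

fn from_partial_graph_id(partial_graph_id: u32) -> Option<u32> {
    (partial_graph_id != u32::MAX).then_some(partial_graph_id)
}

fn main() {
    for job_id in [None, Some(0), Some(42), Some(u32::MAX - 1)] {
        assert_eq!(from_partial_graph_id(to_partial_graph_id(job_id)), job_id);
    }
}
```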
diff --git a/src/stream/src/executor/dispatch.rs b/src/stream/src/executor/dispatch.rs
index 609fed1be038f..90e6ef9592194 100644
--- a/src/stream/src/executor/dispatch.rs
+++ b/src/stream/src/executor/dispatch.rs
@@ -1254,16 +1254,17 @@ mod tests {
             },
         ));
         barrier_test_env.inject_barrier(&b1, [actor_id]);
-        barrier_test_env
-            .shared_context
-            .local_barrier_manager
-            .flush_all_events()
-            .await;
+        barrier_test_env.flush_all_events().await;
 
         let input = Executor::new(
             Default::default(),
-            ReceiverExecutor::for_test(actor_id, rx, barrier_test_env.shared_context.clone())
-                .boxed(),
+            ReceiverExecutor::for_test(
+                actor_id,
+                rx,
+                barrier_test_env.shared_context.clone(),
+                barrier_test_env.local_barrier_manager.clone(),
+            )
+            .boxed(),
         );
         let executor = Box::new(DispatchExecutor::new(
             input,
diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs
index 5437a4ae977ce..1c25eab15256f 100644
--- a/src/stream/src/executor/exchange/input.rs
+++ b/src/stream/src/executor/exchange/input.rs
@@ -161,6 +161,7 @@ pub struct RemoteInput {
 }
 
 use remote_input::RemoteInputStreamInner;
+use risingwave_common::catalog::DatabaseId;
 
 impl RemoteInput {
     /// Create a remote input from compute client and related info. Should provide the corresponding
@@ -170,6 +171,7 @@ impl RemoteInput {
         upstream_addr: HostAddr,
         up_down_ids: UpDownActorIds,
         up_down_frag: UpDownFragmentIds,
+        database_id: DatabaseId,
         metrics: Arc<StreamingMetrics>,
         batched_permits: usize,
     ) -> Self {
@@ -182,6 +184,7 @@ impl RemoteInput {
                 upstream_addr,
                 up_down_ids,
                 up_down_frag,
+                database_id,
                 metrics,
                 batched_permits,
             ),
@@ -194,6 +197,7 @@ mod remote_input {
     use anyhow::Context;
     use await_tree::InstrumentAwait;
+    use risingwave_common::catalog::DatabaseId;
     use risingwave_common::util::addr::HostAddr;
     use risingwave_pb::task_service::{permits, GetStreamResponse};
     use risingwave_rpc_client::ComputeClientPool;
@@ -211,6 +215,7 @@ mod remote_input {
         upstream_addr: HostAddr,
         up_down_ids: UpDownActorIds,
         up_down_frag: UpDownFragmentIds,
+        database_id: DatabaseId,
         metrics: Arc<StreamingMetrics>,
         batched_permits_limit: usize,
     ) -> RemoteInputStreamInner {
@@ -219,6 +224,7 @@ mod remote_input {
             upstream_addr,
             up_down_ids,
             up_down_frag,
+            database_id,
             metrics,
             batched_permits_limit,
         )
@@ -230,12 +236,19 @@ mod remote_input {
         upstream_addr: HostAddr,
         up_down_ids: UpDownActorIds,
         up_down_frag: UpDownFragmentIds,
+        database_id: DatabaseId,
         metrics: Arc<StreamingMetrics>,
         batched_permits_limit: usize,
     ) {
         let client = client_pool.get_by_addr(upstream_addr).await?;
         let (stream, permits_tx) = client
-            .get_stream(up_down_ids.0, up_down_ids.1, up_down_frag.0, up_down_frag.1)
+            .get_stream(
+                up_down_ids.0,
+                up_down_ids.1,
+                up_down_frag.0,
+                up_down_frag.1,
+                database_id,
+            )
             .await?;
 
         let up_actor_id = up_down_ids.0.to_string();
@@ -336,6 +349,7 @@ pub(crate) fn new_input(
             upstream_addr,
             (upstream_actor_id, actor_id),
             (upstream_fragment_id, fragment_id),
+            context.database_id,
            metrics,
             context.config.developer.exchange_batched_permits,
         )
diff --git a/src/stream/src/executor/integration_tests.rs b/src/stream/src/executor/integration_tests.rs
index 01d4ced06805c..c28c72ea70e61 100644
--- a/src/stream/src/executor/integration_tests.rs
+++ b/src/stream/src/executor/integration_tests.rs
@@ -66,6 +66,7 @@ async fn test_merger_sum_aggr() {
             fields: vec![Field::unnamed(DataType::Int64)],
         };
         let shared_context = barrier_test_env.shared_context.clone();
+        let local_barrier_manager = barrier_test_env.local_barrier_manager.clone();
         let expr_context = expr_context.clone();
         let (tx, rx) = channel_for_test();
         let actor_future = async move {
@@ -75,7 +76,13 @@ async fn test_merger_sum_aggr() {
                     pk_indices: PkIndices::new(),
                     identity: "ReceiverExecutor".to_string(),
                 },
-                ReceiverExecutor::for_test(actor_id, input_rx, shared_context.clone()).boxed(),
+                ReceiverExecutor::for_test(
+                    actor_id,
+                    input_rx,
+                    shared_context.clone(),
+                    local_barrier_manager.clone(),
+                )
+                .boxed(),
             );
             let agg_calls = vec![
                 AggCall::from_pretty("(count:int8)"),
@@ -97,7 +104,7 @@ async fn test_merger_sum_aggr() {
                 StreamingMetrics::unused().into(),
                 actor_ctx,
                 expr_context,
-                shared_context.local_barrier_manager.clone(),
+                local_barrier_manager.clone(),
             );
 
             actor.run().await
@@ -130,6 +137,7 @@ async fn test_merger_sum_aggr() {
     let (input, rx) = channel_for_test();
     let actor_future = {
         let shared_context = barrier_test_env.shared_context.clone();
+        let local_barrier_manager = barrier_test_env.local_barrier_manager.clone();
         let expr_context = expr_context.clone();
         async move {
             let receiver_op = Executor::new(
@@ -139,7 +147,13 @@ async fn test_merger_sum_aggr() {
                     pk_indices: PkIndices::new(),
                     identity: "ReceiverExecutor".to_string(),
                 },
-                ReceiverExecutor::for_test(actor_id, rx, shared_context.clone()).boxed(),
+                ReceiverExecutor::for_test(
+                    actor_id,
+                    rx,
+                    shared_context.clone(),
+                    local_barrier_manager.clone(),
+                )
+                .boxed(),
             );
             let dispatcher = DispatchExecutor::new(
                 receiver_op,
@@ -160,7 +174,7 @@ async fn test_merger_sum_aggr() {
                 StreamingMetrics::unused().into(),
                 ActorContext::for_test(actor_id),
                 expr_context,
-                shared_context.local_barrier_manager.clone(),
+                local_barrier_manager.clone(),
             );
             actor.run().await
         }
@@ -173,6 +187,7 @@ async fn test_merger_sum_aggr() {
     let items = Arc::new(Mutex::new(vec![]));
     let actor_future = {
         let shared_context = barrier_test_env.shared_context.clone();
+        let local_barrier_manager = barrier_test_env.local_barrier_manager.clone();
         let expr_context = expr_context.clone();
         let items = items.clone();
         async move {
@@ -188,8 +203,14 @@ async fn test_merger_sum_aggr() {
                     pk_indices: PkIndices::new(),
                     identity: "MergeExecutor".to_string(),
                 },
-                MergeExecutor::for_test(actor_ctx.id, outputs, shared_context.clone(), schema)
-                    .boxed(),
+                MergeExecutor::for_test(
+                    actor_ctx.id,
+                    outputs,
+                    shared_context.clone(),
+                    local_barrier_manager.clone(),
+                    schema,
+                )
+                .boxed(),
             );
 
             // for global aggregator, we need to sum data and sum row count
@@ -233,7 +254,7 @@ async fn test_merger_sum_aggr() {
                 StreamingMetrics::unused().into(),
                 actor_ctx.clone(),
                 expr_context,
-                shared_context.local_barrier_manager.clone(),
+                local_barrier_manager.clone(),
             );
             actor.run().await
         }
@@ -244,11 +265,7 @@ async fn test_merger_sum_aggr() {
     let mut epoch = test_epoch(1);
     let b1 = Barrier::new_test_barrier(epoch);
     barrier_test_env.inject_barrier(&b1, actors.clone());
-    barrier_test_env
-        .shared_context
-        .local_barrier_manager
-        .flush_all_events()
-        .await;
+    barrier_test_env.flush_all_events().await;
     let handles = actor_futures
         .into_iter()
         .map(|actor_future| tokio::spawn(actor_future))
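Editor's note: throughout these test hunks, the `LocalBarrierManager` handle is threaded separately instead of being read off `SharedContext`. The tests fetch it from the worker with a request-reply round-trip. A self-contained toy of that actor-op pattern, with std channels and strings standing in for the real tokio channels and handle types:

```rust
// Toy of the GetCurrentSharedContext request-reply: the worker owns the
// per-database state, so a test asks it for both handles over a channel.
use std::sync::mpsc;
use std::thread;

enum ActorOp {
    // The real variant replies with (Arc<SharedContext>, LocalBarrierManager).
    GetCurrentSharedContext(mpsc::Sender<(&'static str, &'static str)>),
}

fn worker(ops: mpsc::Receiver<ActorOp>) {
    for op in ops {
        match op {
            ActorOp::GetCurrentSharedContext(reply) => {
                let _ = reply.send(("shared_context", "local_barrier_manager"));
            }
        }
    }
}

fn main() {
    let (op_tx, op_rx) = mpsc::channel();
    let handle = thread::spawn(move || worker(op_rx));
    let (reply_tx, reply_rx) = mpsc::channel();
    op_tx.send(ActorOp::GetCurrentSharedContext(reply_tx)).unwrap();
    assert_eq!(reply_rx.recv().unwrap().1, "local_barrier_manager");
    drop(op_tx); // close the op channel so the worker loop exits
    handle.join().unwrap();
}
```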
diff --git a/src/stream/src/executor/merge.rs b/src/stream/src/executor/merge.rs
index 8136662b5eaf9..fb8364dc54ee5 100644
--- a/src/stream/src/executor/merge.rs
+++ b/src/stream/src/executor/merge.rs
@@ -172,14 +172,13 @@ impl MergeExecutor {
         actor_id: ActorId,
         inputs: Vec<super::exchange::permit::Receiver>,
         shared_context: Arc<SharedContext>,
+        local_barrier_manager: crate::task::LocalBarrierManager,
         schema: Schema,
     ) -> Self {
         use super::exchange::input::LocalInput;
         use crate::executor::exchange::input::Input;
 
-        let barrier_rx = shared_context
-            .local_barrier_manager
-            .subscribe_barrier(actor_id);
+        let barrier_rx = local_barrier_manager.subscribe_barrier(actor_id);
         let metrics = StreamingMetrics::unused();
         let actor_ctx = ActorContext::for_test(actor_id);
@@ -806,11 +805,7 @@ mod tests {
         let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch)
             .with_mutation(Mutation::Stop(HashSet::default()));
         barrier_test_env.inject_barrier(&b2, [actor_id]);
-        barrier_test_env
-            .shared_context
-            .local_barrier_manager
-            .flush_all_events()
-            .await;
+        barrier_test_env.flush_all_events().await;
 
         for (tx_id, tx) in txs.into_iter().enumerate() {
             let epochs = epochs.clone();
@@ -845,6 +840,7 @@ mod tests {
             actor_id,
             rxs,
             barrier_test_env.shared_context.clone(),
+            barrier_test_env.local_barrier_manager.clone(),
             Schema::new(vec![]),
         );
         let mut merger = merger.boxed().execute();
@@ -942,13 +938,11 @@ mod tests {
             },
         ));
         barrier_test_env.inject_barrier(&b1, [actor_id]);
-        barrier_test_env
-            .shared_context
-            .local_barrier_manager
-            .flush_all_events()
-            .await;
+        barrier_test_env.flush_all_events().await;
 
-        let barrier_rx = ctx.local_barrier_manager.subscribe_barrier(actor_id);
+        let barrier_rx = barrier_test_env
+            .local_barrier_manager
+            .subscribe_barrier(actor_id);
 
         let actor_ctx = ActorContext::for_test(actor_id);
         let upstream = MergeExecutor::new_select_receiver(inputs, &metrics, &actor_ctx);
@@ -1110,6 +1104,7 @@ mod tests {
             addr.into(),
             (0, 0),
             (0, 0),
+            test_env.shared_context.database_id,
             Arc::new(StreamingMetrics::unused()),
             BATCHED_PERMITS,
         )
diff --git a/src/stream/src/executor/receiver.rs b/src/stream/src/executor/receiver.rs
index c3fd4f9f7e7e2..1fd5e04eb6804 100644
--- a/src/stream/src/executor/receiver.rs
+++ b/src/stream/src/executor/receiver.rs
@@ -83,13 +83,12 @@ impl ReceiverExecutor {
         actor_id: ActorId,
         input: super::exchange::permit::Receiver,
         shared_context: Arc<SharedContext>,
+        local_barrier_manager: crate::task::LocalBarrierManager,
     ) -> Self {
         use super::exchange::input::LocalInput;
         use crate::executor::exchange::input::Input;
 
-        let barrier_rx = shared_context
-            .local_barrier_manager
-            .subscribe_barrier(actor_id);
+        let barrier_rx = local_barrier_manager.subscribe_barrier(actor_id);
 
         Self::new(
             ActorContext::for_test(actor_id),
@@ -260,11 +259,7 @@ mod tests {
         ));
 
         barrier_test_env.inject_barrier(&b1, [actor_id]);
-        barrier_test_env
-            .shared_context
-            .local_barrier_manager
-            .flush_all_events()
-            .await;
+        barrier_test_env.flush_all_events().await;
 
         let input = new_input(
             &ctx,
@@ -283,7 +278,9 @@ mod tests {
             input,
             ctx.clone(),
             metrics.clone(),
-            ctx.local_barrier_manager.subscribe_barrier(actor_id),
+            barrier_test_env
+                .local_barrier_manager
+                .subscribe_barrier(actor_id),
         )
         .boxed()
         .execute();
diff --git a/src/stream/src/from_proto/barrier_recv.rs b/src/stream/src/from_proto/barrier_recv.rs
index 21bbdece8008e..f5f7de6e9d805 100644
--- a/src/stream/src/from_proto/barrier_recv.rs
+++ b/src/stream/src/from_proto/barrier_recv.rs
@@ -33,7 +33,6 @@ impl ExecutorBuilder for BarrierRecvExecutorBuilder {
         );
 
         let barrier_receiver = params
-            .shared_context
             .local_barrier_manager
             .subscribe_barrier(params.actor_context.id);
diff --git a/src/stream/src/from_proto/merge.rs b/src/stream/src/from_proto/merge.rs
index f56090f359eef..68ab992e0fe93 100644
--- a/src/stream/src/from_proto/merge.rs
+++ b/src/stream/src/from_proto/merge.rs
@@ -91,7 +91,6 @@ impl ExecutorBuilder for MergeExecutorBuilder {
         _store: impl StateStore,
     ) -> StreamResult<Executor> {
         let barrier_rx = params
-            .shared_context
             .local_barrier_manager
             .subscribe_barrier(params.actor_context.id);
         Ok(Self::new_input(
diff --git a/src/stream/src/from_proto/now.rs b/src/stream/src/from_proto/now.rs
index d3cc352150292..544385436ed8e 100644
--- a/src/stream/src/from_proto/now.rs
+++ b/src/stream/src/from_proto/now.rs
@@ -36,7 +36,6 @@ impl ExecutorBuilder for NowExecutorBuilder {
         store: impl StateStore,
     ) -> StreamResult<Executor> {
         let barrier_receiver = params
-            .shared_context
             .local_barrier_manager
             .subscribe_barrier(params.actor_context.id);
diff --git a/src/stream/src/from_proto/source/trad_source.rs b/src/stream/src/from_proto/source/trad_source.rs
index 4d4786eea3bfa..d3efeef6dee79 100644
--- a/src/stream/src/from_proto/source/trad_source.rs
+++ b/src/stream/src/from_proto/source/trad_source.rs
@@ -141,7 +141,6 @@ impl ExecutorBuilder for SourceExecutorBuilder {
         store: impl StateStore,
     ) -> StreamResult<Executor> {
         let barrier_receiver = params
-            .shared_context
             .local_barrier_manager
             .subscribe_barrier(params.actor_context.id);
         let system_params = params.env.system_params_manager_ref().get_params();
diff --git a/src/stream/src/from_proto/values.rs b/src/stream/src/from_proto/values.rs
index 10654c5f75b6a..2f803d942b263 100644
--- a/src/stream/src/from_proto/values.rs
+++ b/src/stream/src/from_proto/values.rs
@@ -35,7 +35,6 @@ impl ExecutorBuilder for ValuesExecutorBuilder {
         _store: impl StateStore,
     ) -> StreamResult<Executor> {
         let barrier_receiver = params
-            .shared_context
             .local_barrier_manager
             .subscribe_barrier(params.actor_context.id);
         let progress = params
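Editor's note: the from_proto hunks above all make the same one-line change. A sketch of the resulting builder-side pattern, with stand-in types (none of these definitions are the real ones): `ExecutorParams` now carries the `LocalBarrierManager` directly, so `.shared_context` drops out of the subscription chain.

```rust
struct BarrierReceiver(u32); // stand-in for the real receiver type

struct LocalBarrierManager;

impl LocalBarrierManager {
    fn subscribe_barrier(&self, actor_id: u32) -> BarrierReceiver {
        BarrierReceiver(actor_id)
    }
}

struct ActorContext { id: u32 }

struct ExecutorParams {
    actor_context: ActorContext,
    local_barrier_manager: LocalBarrierManager,
}

fn build_barrier_rx(params: &ExecutorParams) -> BarrierReceiver {
    // before: params.shared_context.local_barrier_manager.subscribe_barrier(..)
    params.local_barrier_manager.subscribe_barrier(params.actor_context.id)
}

fn main() {
    let params = ExecutorParams {
        actor_context: ActorContext { id: 5 },
        local_barrier_manager: LocalBarrierManager,
    };
    assert_eq!(build_barrier_rx(&params).0, 5);
}
```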
diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs
index fec0d74ab6d5f..002abbf98cc67 100644
--- a/src/stream/src/task/barrier_manager.rs
+++ b/src/stream/src/task/barrier_manager.rs
@@ -12,11 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::{BTreeSet, HashSet};
+use std::collections::{HashMap, HashSet};
 use std::fmt::Display;
-use std::future::pending;
+use std::future::{pending, poll_fn};
 use std::iter::once;
 use std::sync::Arc;
+use std::task::Poll;
 use std::time::Duration;
 
 use anyhow::anyhow;
@@ -55,7 +56,7 @@ use risingwave_common::util::runtime::BackgroundShutdownRuntime;
 use risingwave_hummock_sdk::table_stats::to_prost_table_stats_map;
 use risingwave_hummock_sdk::{LocalSstableInfo, SyncResult};
 use risingwave_pb::stream_service::streaming_control_stream_request::{
-    InitRequest, InitialPartialGraph, Request,
+    DatabaseInitialPartialGraph, InitRequest, Request,
 };
 use risingwave_pb::stream_service::streaming_control_stream_response::{
     InitResponse, ShutdownResponse,
@@ -69,7 +70,8 @@ use crate::executor::exchange::permit::Receiver;
 use crate::executor::monitor::StreamingMetrics;
 use crate::executor::{Barrier, BarrierInner, StreamExecutorError};
 use crate::task::barrier_manager::managed_state::{
-    ManagedBarrierStateDebugInfo, PartialGraphManagedBarrierState,
+    DatabaseManagedBarrierState, ManagedBarrierStateDebugInfo, ManagedBarrierStateEvent,
+    PartialGraphManagedBarrierState,
 };
 use crate::task::barrier_manager::progress::BackfillState;
@@ -155,9 +157,14 @@ impl ControlStreamHandle {
         }
     }
 
-    fn send_response(&mut self, response: StreamingControlStreamResponse) {
+    fn send_response(&mut self, response: streaming_control_stream_response::Response) {
         if let Some((sender, _)) = self.pair.as_ref() {
-            if sender.send(Ok(response)).is_err() {
+            if sender
+                .send(Ok(StreamingControlStreamResponse {
+                    response: Some(response),
+                }))
+                .is_err()
+            {
                 self.pair = None;
                 warn!("fail to send response. control stream reset");
             }
@@ -198,8 +205,6 @@ pub(super) enum LocalBarrierEvent {
         actor_id: ActorId,
         barrier_sender: mpsc::UnboundedSender<Barrier>,
     },
-    #[cfg(test)]
-    Flush(oneshot::Sender<()>),
 }
 
 #[derive(strum_macros::Display)]
@@ -209,11 +214,14 @@ pub(super) enum LocalActorOperation {
         init_request: InitRequest,
     },
     TakeReceiver {
+        database_id: DatabaseId,
         ids: UpDownActorIds,
         result_sender: oneshot::Sender<StreamResult<Receiver>>,
     },
     #[cfg(test)]
-    GetCurrentSharedContext(oneshot::Sender<Arc<SharedContext>>),
+    GetCurrentSharedContext(oneshot::Sender<(Arc<SharedContext>, LocalBarrierManager)>),
+    #[cfg(test)]
+    Flush(oneshot::Sender<()>),
     InspectState {
         result_sender: oneshot::Sender<String>,
     },
@@ -237,25 +245,25 @@ pub(crate) struct StreamActorManager {
 }
 
 pub(super) struct LocalBarrierWorkerDebugInfo<'a> {
-    running_actors: BTreeSet<ActorId>,
-    managed_barrier_state: ManagedBarrierStateDebugInfo<'a>,
+    managed_barrier_state: HashMap<DatabaseId, ManagedBarrierStateDebugInfo<'a>>,
     has_control_stream_connected: bool,
 }
 
 impl Display for LocalBarrierWorkerDebugInfo<'_> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "running_actors: ")?;
-        for actor_id in &self.running_actors {
-            write!(f, "{}, ", actor_id)?;
-        }
-
         writeln!(
             f,
             "\nhas_control_stream_connected: {}",
             self.has_control_stream_connected
         )?;
 
-        writeln!(f, "managed_barrier_state:\n{}", self.managed_barrier_state)?;
+        for (database_id, managed_barrier_state) in &self.managed_barrier_state {
+            writeln!(
+                f,
+                "database {} managed_barrier_state:\n{}",
+                database_id.database_id, managed_barrier_state
+            )?;
+        }
         Ok(())
     }
 }
@@ -268,97 +276,111 @@ pub(super) struct LocalBarrierWorker {
     pub(super) state: ManagedBarrierState,
 
     /// Futures will be finished in the order of epoch in ascending order.
-    await_epoch_completed_futures: FuturesOrdered<AwaitEpochCompletedFuture>,
+    await_epoch_completed_futures: HashMap<DatabaseId, FuturesOrdered<AwaitEpochCompletedFuture>>,
 
     control_stream_handle: ControlStreamHandle,
 
     pub(super) actor_manager: Arc<StreamActorManager>,
-
-    pub(super) current_shared_context: Arc<SharedContext>,
-
-    barrier_event_rx: UnboundedReceiver<LocalBarrierEvent>,
-
-    actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>,
 }
 
 impl LocalBarrierWorker {
     pub(super) fn new(
         actor_manager: Arc<StreamActorManager>,
-        initial_partial_graphs: Vec<InitialPartialGraph>,
+        initial_partial_graphs: Vec<DatabaseInitialPartialGraph>,
     ) -> Self {
-        let (event_tx, event_rx) = unbounded_channel();
-        let (failure_tx, failure_rx) = unbounded_channel();
-        let shared_context = Arc::new(SharedContext::new(
-            &actor_manager.env,
-            LocalBarrierManager {
-                barrier_event_sender: event_tx,
-                actor_failure_sender: failure_tx,
-            },
-        ));
+        let state = ManagedBarrierState::new(actor_manager.clone(), initial_partial_graphs);
         Self {
-            state: ManagedBarrierState::new(
-                actor_manager.clone(),
-                shared_context.clone(),
-                initial_partial_graphs,
-            ),
+            state,
             await_epoch_completed_futures: Default::default(),
             control_stream_handle: ControlStreamHandle::empty(),
             actor_manager,
-            current_shared_context: shared_context,
-            barrier_event_rx: event_rx,
-            actor_failure_rx: failure_rx,
         }
     }
 
     fn to_debug_info(&self) -> LocalBarrierWorkerDebugInfo<'_> {
         LocalBarrierWorkerDebugInfo {
-            running_actors: self.state.actor_states.keys().cloned().collect(),
-            managed_barrier_state: self.state.to_debug_info(),
+            managed_barrier_state: self
+                .state
+                .databases
+                .iter()
+                .map(|(database_id, state)| (*database_id, state.to_debug_info()))
+                .collect(),
             has_control_stream_connected: self.control_stream_handle.connected(),
         }
     }
 
+    pub(crate) fn get_or_insert_database_shared_context<'a>(
+        current_shared_context: &'a mut HashMap<DatabaseId, Arc<SharedContext>>,
+        database_id: DatabaseId,
+        actor_manager: &StreamActorManager,
+    ) -> &'a Arc<SharedContext> {
+        current_shared_context
+            .entry(database_id)
+            .or_insert_with(|| Arc::new(SharedContext::new(database_id, &actor_manager.env)))
+    }
+
+    async fn next_completed_epoch(
+        futures: &mut HashMap<DatabaseId, FuturesOrdered<AwaitEpochCompletedFuture>>,
+    ) -> (
+        DatabaseId,
+        PartialGraphId,
+        Barrier,
+        StreamResult<BarrierCompleteResult>,
+    ) {
+        poll_fn(|cx| {
+            for (database_id, futures) in &mut *futures {
+                if let Poll::Ready(Some((partial_graph_id, barrier, result))) =
+                    futures.poll_next_unpin(cx)
+                {
+                    return Poll::Ready((*database_id, partial_graph_id, barrier, result));
+                }
+            }
+            Poll::Pending
+        })
+        .await
+    }
+
     async fn run(mut self, mut actor_op_rx: UnboundedReceiver<LocalActorOperation>) {
         loop {
             select! {
                 biased;
-                (partial_graph_id, barrier) = self.state.next_collected_epoch() => {
-                    self.complete_barrier(partial_graph_id, barrier.epoch.prev);
+                (database_id, event) = self.state.next_event() => {
+                    match event {
+                        ManagedBarrierStateEvent::BarrierCollected {
+                            partial_graph_id,
+                            barrier,
+                        } => {
+                            self.complete_barrier(database_id, partial_graph_id, barrier.epoch.prev);
+                        }
+                        ManagedBarrierStateEvent::ActorError {
+                            actor_id,
+                            err,
+                        } => {
+                            self.notify_actor_failure(database_id, actor_id, err, "recv actor failure").await;
+                        }
+                    }
                 }
-                (partial_graph_id, barrier, result) = rw_futures_util::pending_on_none(self.await_epoch_completed_futures.next()) => {
+                (database_id, partial_graph_id, barrier, result) = Self::next_completed_epoch(&mut self.await_epoch_completed_futures) => {
                     match result {
                         Ok(result) => {
-                            self.on_epoch_completed(partial_graph_id, barrier.epoch.prev, result);
+                            self.on_epoch_completed(database_id, partial_graph_id, barrier.epoch.prev, result);
                         }
                         Err(err) => {
                             self.notify_other_failure(err, "failed to complete epoch").await;
                         }
                     }
                 },
-                event = self.barrier_event_rx.recv() => {
-                    // event should not be None because the LocalBarrierManager holds a copy of tx
-                    let result = self.handle_barrier_event(event.expect("should not be none"));
-                    if let Err((actor_id, err)) = result {
-                        self.notify_actor_failure(actor_id, err, "failed to handle barrier event").await;
-                    }
-                },
-                failure = self.actor_failure_rx.recv() => {
-                    let (actor_id, err) = failure.unwrap();
-                    self.notify_actor_failure(actor_id, err, "recv actor failure").await;
-                },
                 actor_op = actor_op_rx.recv() => {
                     if let Some(actor_op) = actor_op {
                         match actor_op {
                             LocalActorOperation::NewControlStream { handle, init_request } => {
                                 self.control_stream_handle.reset_stream_with_err(Status::internal("control stream has been reset to a new one"));
-                                self.reset(init_request.graphs).await;
+                                self.reset(init_request.databases).await;
                                 self.control_stream_handle = handle;
-                                self.control_stream_handle.send_response(StreamingControlStreamResponse {
-                                    response: Some(streaming_control_stream_response::Response::Init(InitResponse {}))
-                                });
+                                self.control_stream_handle.send_response(streaming_control_stream_response::Response::Init(InitResponse {}));
                             }
                             LocalActorOperation::Shutdown { result_sender } => {
-                                if !self.state.actor_states.is_empty() {
+                                if self.state.databases.values().any(|database| !database.actor_states.is_empty()) {
                                     tracing::warn!(
                                         "shutdown with running actors, scaling or migration will be triggered"
                                     );
@@ -376,7 +398,7 @@ impl LocalBarrierWorker {
                     }
                 },
                 request = self.control_stream_handle.next_request() => {
-                    let result = self.handle_streaming_control_request(request);
+                    let result = self.handle_streaming_control_request(request.request.expect("non empty"));
                     if let Err(err) = result {
                         self.notify_other_failure(err, "failed to inject barrier").await;
                     }
@@ -385,25 +407,29 @@ impl LocalBarrierWorker {
         }
     }
 
-    fn handle_streaming_control_request(
-        &mut self,
-        request: StreamingControlStreamRequest,
-    ) -> StreamResult<()> {
-        match request.request.expect("should not be empty") {
+    fn handle_streaming_control_request(&mut self, request: Request) -> StreamResult<()> {
+        match request {
             Request::InjectBarrier(req) => {
                 let barrier = Barrier::from_protobuf(req.get_barrier().unwrap())?;
-                self.update_actor_info(req.broadcast_info.iter().cloned())?;
+                self.update_actor_info(
+                    DatabaseId::new(req.database_id),
+                    req.broadcast_info.iter().cloned(),
+                )?;
                 self.send_barrier(&barrier, req)?;
                 Ok(())
             }
             Request::RemovePartialGraph(req) => {
                 self.remove_partial_graphs(
+                    DatabaseId::new(req.database_id),
                    req.partial_graph_ids.into_iter().map(PartialGraphId::new),
                 );
                 Ok(())
             }
             Request::CreatePartialGraph(req) => {
-                self.add_partial_graph(PartialGraphId::new(req.partial_graph_id));
+                self.add_partial_graph(
+                    DatabaseId::new(req.database_id),
+                    PartialGraphId::new(req.partial_graph_id),
+                );
                 Ok(())
             }
             Request::Init(_) => {
@@ -412,53 +438,65 @@ impl LocalBarrierWorker {
         }
     }
 
-    fn handle_barrier_event(
-        &mut self,
-        event: LocalBarrierEvent,
-    ) -> Result<(), (ActorId, StreamError)> {
-        match event {
-            LocalBarrierEvent::ReportActorCollected { actor_id, epoch } => {
-                self.collect(actor_id, epoch)
+    fn handle_actor_op(&mut self, actor_op: LocalActorOperation) {
+        match actor_op {
+            LocalActorOperation::NewControlStream { .. } | LocalActorOperation::Shutdown { .. } => {
+                unreachable!("event {actor_op} should be handled separately in async context")
             }
-            LocalBarrierEvent::ReportCreateProgress {
-                epoch,
-                actor,
-                state,
+            LocalActorOperation::TakeReceiver {
+                database_id,
+                ids,
+                result_sender,
             } => {
-                self.update_create_mview_progress(epoch, actor, state);
+                let _ = result_sender.send(
+                    LocalBarrierWorker::get_or_insert_database_shared_context(
+                        &mut self.state.current_shared_context,
+                        database_id,
+                        &self.actor_manager,
+                    )
+                    .take_receiver(ids),
+                );
             }
-            LocalBarrierEvent::RegisterBarrierSender {
-                actor_id,
-                barrier_sender,
-            } => {
-                self.state
-                    .register_barrier_sender(actor_id, barrier_sender)
-                    .map_err(|e| (actor_id, e))?;
+            #[cfg(test)]
+            LocalActorOperation::GetCurrentSharedContext(sender) => {
+                let database_state = self
+                    .state
+                    .databases
+                    .get(&crate::task::TEST_DATABASE_ID)
+                    .unwrap();
+                let _ = sender.send((
+                    database_state.current_shared_context.clone(),
+                    database_state.local_barrier_manager.clone(),
+                ));
            }
             #[cfg(test)]
-            LocalBarrierEvent::Flush(sender) => {
+            LocalActorOperation::Flush(sender) => {
                 use futures::FutureExt;
                 while let Some(request) = self.control_stream_handle.next_request().now_or_never() {
-                    self.handle_streaming_control_request(request).unwrap();
+                    self.handle_streaming_control_request(
+                        request.request.expect("should not be empty"),
+                    )
+                    .unwrap();
+                }
+                while let Some((database_id, event)) = self.state.next_event().now_or_never() {
+                    match event {
+                        ManagedBarrierStateEvent::BarrierCollected {
+                            partial_graph_id,
+                            barrier,
+                        } => {
+                            self.complete_barrier(
+                                database_id,
+                                partial_graph_id,
+                                barrier.epoch.prev,
+                            );
+                        }
+                        ManagedBarrierStateEvent::ActorError { .. } => {
+                            unreachable!()
+                        }
+                    }
                 }
                 sender.send(()).unwrap()
             }
-        }
-        Ok(())
-    }
-
-    fn handle_actor_op(&mut self, actor_op: LocalActorOperation) {
-        match actor_op {
-            LocalActorOperation::NewControlStream { .. } | LocalActorOperation::Shutdown { .. } => {
-                unreachable!("event {actor_op} should be handled separately in async context")
-            }
-            LocalActorOperation::TakeReceiver { ids, result_sender } => {
-                let _ = result_sender.send(self.current_shared_context.take_receiver(ids));
-            }
-            #[cfg(test)]
-            LocalActorOperation::GetCurrentSharedContext(sender) => {
-                let _ = sender.send(self.current_shared_context.clone());
-            }
             LocalActorOperation::InspectState { result_sender } => {
                 let debug_info = self.to_debug_info();
                 let _ = result_sender.send(debug_info.to_string());
@@ -522,7 +560,7 @@ mod await_epoch_completed_future {
 }
 
 use await_epoch_completed_future::*;
-use risingwave_common::catalog::TableId;
+use risingwave_common::catalog::{DatabaseId, TableId};
 use risingwave_storage::StateStoreImpl;
 
 fn sync_epoch(
@@ -557,10 +595,18 @@ fn sync_epoch(
 }
 
 impl LocalBarrierWorker {
-    fn complete_barrier(&mut self, partial_graph_id: PartialGraphId, prev_epoch: u64) {
+    fn complete_barrier(
+        &mut self,
+        database_id: DatabaseId,
+        partial_graph_id: PartialGraphId,
+        prev_epoch: u64,
+    ) {
         {
             let (barrier, table_ids, create_mview_progress) = self
                 .state
+                .databases
+                .get_mut(&database_id)
+                .expect("should exist")
                 .pop_barrier_to_complete(partial_graph_id, prev_epoch);
 
             let complete_barrier_future = match &barrier.kind {
@@ -582,20 +628,24 @@ impl LocalBarrierWorker {
                 )),
             };
 
-            self.await_epoch_completed_futures.push_back({
-                instrument_complete_barrier_future(
-                    partial_graph_id,
-                    complete_barrier_future,
-                    barrier,
-                    self.actor_manager.await_tree_reg.as_ref(),
-                    create_mview_progress,
-                )
-            });
+            self.await_epoch_completed_futures
+                .entry(database_id)
+                .or_default()
+                .push_back({
+                    instrument_complete_barrier_future(
+                        partial_graph_id,
+                        complete_barrier_future,
+                        barrier,
+                        self.actor_manager.await_tree_reg.as_ref(),
+                        create_mview_progress,
+                    )
+                });
         }
     }
 
     fn on_epoch_completed(
         &mut self,
+        database_id: DatabaseId,
         partial_graph_id: PartialGraphId,
         epoch: u64,
         result: BarrierCompleteResult,
@@ -615,8 +665,8 @@ impl LocalBarrierWorker {
             })
             .unwrap_or_default();
 
-        let result = StreamingControlStreamResponse {
-            response: Some(
+        let result = {
+            {
                 streaming_control_stream_response::Response::CompleteBarrier(
                     BarrierCompleteResponse {
                         request_id: "todo".to_string(),
@@ -647,9 +697,10 @@ impl LocalBarrierWorker {
                             .into_iter()
                             .map(|sst| sst.sst_info.into())
                             .collect(),
+                        database_id: database_id.database_id,
                     },
-                ),
-            ),
+                )
+            }
         };
 
         self.control_stream_handle.send_response(result);
@@ -673,13 +724,28 @@ impl LocalBarrierWorker {
             request.actor_ids_to_collect
        );
 
-        self.state.transform_to_issued(barrier, request)?;
+        self.state
+            .databases
+            .get_mut(&DatabaseId::new(request.database_id))
+            .expect("should exist")
+            .transform_to_issued(barrier, request)?;
         Ok(())
     }
 
-    fn remove_partial_graphs(&mut self, partial_graph_ids: impl Iterator<Item = PartialGraphId>) {
+    fn remove_partial_graphs(
+        &mut self,
+        database_id: DatabaseId,
+        partial_graph_ids: impl Iterator<Item = PartialGraphId>,
+    ) {
+        let Some(database_state) = self.state.databases.get_mut(&database_id) else {
+            warn!(
+                database_id = database_id.database_id,
+                "database to remove partial graph not exist"
+            );
+            return;
+        };
         for partial_graph_id in partial_graph_ids {
-            if let Some(graph) = self.state.graph_states.remove(&partial_graph_id) {
+            if let Some(graph) = database_state.graph_states.remove(&partial_graph_id) {
                 assert!(
                     graph.is_empty(),
                     "non empty graph to be removed: {}",
@@ -694,9 +760,27 @@ impl LocalBarrierWorker {
         }
     }
 
-    pub(super) fn add_partial_graph(&mut self, partial_graph_id: PartialGraphId) {
+    pub(super) fn add_partial_graph(
+        &mut self,
+        database_id: DatabaseId,
+        partial_graph_id: PartialGraphId,
+    ) {
         assert!(self
             .state
+            .databases
+            .entry(database_id)
+            .or_insert_with(|| {
+                DatabaseManagedBarrierState::new(
+                    self.actor_manager.clone(),
+                    LocalBarrierWorker::get_or_insert_database_shared_context(
+                        &mut self.state.current_shared_context,
+                        database_id,
+                        &self.actor_manager,
+                    )
+                    .clone(),
+                    vec![],
+                )
+            })
            .graph_states
            .insert(
                partial_graph_id,
@@ -706,27 +790,23 @@ impl LocalBarrierWorker {
     }
 
     /// Reset all internal states.
-    pub(super) fn reset_state(&mut self, initial_partial_graphs: Vec<InitialPartialGraph>) {
+    pub(super) fn reset_state(&mut self, initial_partial_graphs: Vec<DatabaseInitialPartialGraph>) {
         *self = Self::new(self.actor_manager.clone(), initial_partial_graphs);
     }
 
-    /// When a [`crate::executor::StreamConsumer`] (typically [`crate::executor::DispatchExecutor`]) get a barrier, it should report
-    /// and collect this barrier with its own `actor_id` using this function.
-    fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) {
-        self.state.collect(actor_id, epoch)
-    }
-
     /// When a actor exit unexpectedly, the error is reported using this function. The control stream
     /// will be reset and the meta service will then trigger recovery.
     async fn notify_actor_failure(
         &mut self,
+        database_id: DatabaseId,
         actor_id: ActorId,
         err: StreamError,
         err_context: &'static str,
     ) {
         let root_err = self.try_find_root_failure(err).await;
 
-        if let Some(actor_state) = self.state.actor_states.get(&actor_id)
+        if let Some(database_state) = self.state.databases.get(&database_id)
+            && let Some(actor_state) = database_state.actor_states.get(&actor_id)
             && (!actor_state.inflight_barriers.is_empty() || actor_state.is_running())
         {
             self.control_stream_handle.reset_stream_with_err(
@@ -759,7 +839,18 @@ impl LocalBarrierWorker {
         let mut later_errs = vec![];
         // fetch more actor errors within a timeout
         let _ = tokio::time::timeout(Duration::from_secs(3), async {
-            while let Some((_, error)) = self.actor_failure_rx.recv().await {
+            loop {
+                let error = poll_fn(|cx| {
+                    for database in self.state.databases.values_mut() {
+                        if let Poll::Ready(option) = database.actor_failure_rx.poll_recv(cx) {
+                            let (_, err) = option
+                                .expect("should not be none when tx in local_barrier_manager");
+                            return Poll::Ready(err);
+                        }
+                    }
+                    Poll::Pending
+                })
+                .await;
                 later_errs.push(error);
             }
         })
@@ -838,6 +929,23 @@ impl EventSender {
 }
 
 impl LocalBarrierManager {
+    pub(super) fn new() -> (
+        Self,
+        UnboundedReceiver<LocalBarrierEvent>,
+        UnboundedReceiver<(ActorId, StreamError)>,
+    ) {
+        let (event_tx, event_rx) = unbounded_channel();
+        let (err_tx, err_rx) = unbounded_channel();
+        (
+            Self {
+                barrier_event_sender: event_tx,
+                actor_failure_sender: err_tx,
+            },
+            event_rx,
+            err_rx,
+        )
+    }
+
     fn send_event(&self, event: LocalBarrierEvent) {
         // ignore error, because the current barrier manager maybe a stale one
         let _ = self.barrier_event_sender.send(event);
@@ -976,12 +1084,6 @@ impl LocalBarrierManager {
             actor_failure_sender: failure_tx,
         }
     }
-
-    pub async fn flush_all_events(&self) {
-        let (tx, rx) = oneshot::channel();
-        self.send_event(LocalBarrierEvent::Flush(tx));
-        rx.await.unwrap()
-    }
 }
 
 #[cfg(test)]
@@ -991,23 +1093,26 @@ pub(crate) mod barrier_test_utils {
     use assert_matches::assert_matches;
     use futures::StreamExt;
     use risingwave_pb::stream_service::streaming_control_stream_request::{
-        InitRequest, PbInitialPartialGraph,
+        InitRequest, PbDatabaseInitialPartialGraph, PbInitialPartialGraph,
     };
     use risingwave_pb::stream_service::{
         streaming_control_stream_request, streaming_control_stream_response, InjectBarrierRequest,
         StreamingControlStreamRequest, StreamingControlStreamResponse,
     };
     use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+    use tokio::sync::oneshot;
     use tokio_stream::wrappers::UnboundedReceiverStream;
     use tonic::Status;
 
     use crate::executor::Barrier;
     use crate::task::barrier_manager::{ControlStreamHandle, EventSender, LocalActorOperation};
-    use crate::task::{ActorId, LocalBarrierManager, SharedContext};
+    use crate::task::{
+        ActorId, LocalBarrierManager, SharedContext, TEST_DATABASE_ID, TEST_PARTIAL_GRAPH_ID,
+    };
 
     pub(crate) struct LocalBarrierTestEnv {
         pub shared_context: Arc<SharedContext>,
-        #[expect(dead_code)]
+        pub local_barrier_manager: LocalBarrierManager,
         pub(super) actor_op_tx: EventSender<LocalActorOperation>,
         pub request_tx: UnboundedSender<Result<StreamingControlStreamRequest, Status>>,
         pub response_rx: UnboundedReceiver<Result<StreamingControlStreamResponse, Status>>,
@@ -1026,9 +1131,12 @@ pub(crate) mod barrier_test_utils {
                     UnboundedReceiverStream::new(request_rx).boxed(),
                 ),
                 init_request: InitRequest {
-                    graphs: vec![PbInitialPartialGraph {
-                        partial_graph_id: u64::MAX,
-                        subscriptions: vec![],
+                    databases: vec![PbDatabaseInitialPartialGraph {
+                        database_id: TEST_DATABASE_ID.database_id,
+                        graphs: vec![PbInitialPartialGraph {
+                            partial_graph_id: TEST_PARTIAL_GRAPH_ID.into(),
+                            subscriptions: vec![],
+                        }],
                     }],
                 },
             });
@@ -1038,13 +1146,14 @@ pub(crate) mod barrier_test_utils {
                 streaming_control_stream_response::Response::Init(_)
             );
 
-            let shared_context = actor_op_tx
+            let (shared_context, local_barrier_manager) = actor_op_tx
                 .send_and_await(LocalActorOperation::GetCurrentSharedContext)
                 .await
                 .unwrap();
 
             Self {
                 shared_context,
+                local_barrier_manager,
                 actor_op_tx,
                 request_tx,
                 response_rx,
@@ -1062,9 +1171,10 @@ pub(crate) mod barrier_test_utils {
                     InjectBarrierRequest {
                         request_id: "".to_string(),
                         barrier: Some(barrier.to_protobuf()),
+                        database_id: TEST_DATABASE_ID.database_id,
                         actor_ids_to_collect: actor_to_collect.into_iter().collect(),
                         table_ids_to_sync: vec![],
-                        partial_graph_id: u64::MAX,
+                        partial_graph_id: TEST_PARTIAL_GRAPH_ID.into(),
                         broadcast_info: vec![],
                         actors_to_build: vec![],
                         subscriptions_to_add: vec![],
@@ -1074,5 +1184,15 @@ pub(crate) mod barrier_test_utils {
             }))
             .unwrap();
         }
+
+        pub(crate) async fn flush_all_events(&self) {
+            Self::flush_all_events_impl(&self.actor_op_tx).await
+        }
+
+        pub(super) async fn flush_all_events_impl(actor_op_tx: &EventSender<LocalActorOperation>) {
+            let (tx, rx) = oneshot::channel();
+            actor_op_tx.send_event(LocalActorOperation::Flush(tx));
+            rx.await.unwrap()
+        }
     }
 }
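Editor's note: the worker's `run` loop above no longer owns a single pair of event channels; it fans in over per-database sources via `poll_fn` (`next_event`, `next_completed_epoch`). A minimal sketch of that fan-in idiom, with plain queues standing in for `FuturesOrdered` and receivers; the real sources register wakers when pending, which this toy omits:

```rust
use std::collections::HashMap;
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::task::{Context, Poll, Waker};

// Poll every database's source and yield the first ready item, tagged with
// its database id, as next_event/next_completed_epoch do in the patch.
async fn next_ready(sources: &mut HashMap<u32, Vec<u64>>) -> (u32, u64) {
    poll_fn(|_cx| {
        for (database_id, queue) in sources.iter_mut() {
            if let Some(item) = queue.pop() {
                return Poll::Ready((*database_id, item));
            }
        }
        Poll::Pending
    })
    .await
}

fn main() {
    let mut sources = HashMap::from([(1_u32, vec![10_u64]), (2, vec![])]);
    // Drive the future by hand; Waker::noop() is stable since Rust 1.85.
    let mut fut = pin!(next_ready(&mut sources));
    let mut cx = Context::from_waker(Waker::noop());
    assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Ready((1, 10))));
}
```

Note that `HashMap` iteration order makes the fan-in unordered across databases; ordering is only preserved within one database's `FuturesOrdered`.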
diff --git a/src/stream/src/task/barrier_manager/managed_state.rs b/src/stream/src/task/barrier_manager/managed_state.rs
index bd5c92570f13d..2c57f8012195d 100644
--- a/src/stream/src/task/barrier_manager/managed_state.rs
+++ b/src/stream/src/task/barrier_manager/managed_state.rs
@@ -18,21 +18,24 @@ use std::fmt::{Debug, Display, Formatter};
 use std::future::{pending, poll_fn, Future};
 use std::mem::replace;
 use std::sync::Arc;
-use std::task::Poll;
+use std::task::{Context, Poll};
 
 use prometheus::HistogramTimer;
-use risingwave_common::catalog::TableId;
+use risingwave_common::catalog::{DatabaseId, TableId};
 use risingwave_common::util::epoch::EpochPair;
 use risingwave_pb::stream_plan::barrier::BarrierKind;
 use risingwave_storage::StateStoreImpl;
 use tokio::sync::mpsc;
+use tokio::sync::mpsc::UnboundedReceiver;
 use tokio::task::JoinHandle;
 
 use super::progress::BackfillState;
 use crate::error::{StreamError, StreamResult};
 use crate::executor::monitor::StreamingMetrics;
 use crate::executor::Barrier;
-use crate::task::{ActorId, PartialGraphId, SharedContext, StreamActorManager};
+use crate::task::{
+    ActorId, LocalBarrierManager, PartialGraphId, SharedContext, StreamActorManager,
+};
 
 struct IssuedState {
     /// Actor ids remaining to be collected.
@@ -70,15 +73,24 @@ struct BarrierState {
 use risingwave_common::must_match;
 use risingwave_pb::stream_plan::SubscriptionUpstreamInfo;
 use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress;
-use risingwave_pb::stream_service::streaming_control_stream_request::InitialPartialGraph;
+use risingwave_pb::stream_service::streaming_control_stream_request::{
+    DatabaseInitialPartialGraph, InitialPartialGraph,
+};
 use risingwave_pb::stream_service::InjectBarrierRequest;
 
+use crate::task::barrier_manager::LocalBarrierEvent;
+
 pub(super) struct ManagedBarrierStateDebugInfo<'a> {
+    running_actors: BTreeSet<ActorId>,
     graph_states: &'a HashMap<PartialGraphId, PartialGraphManagedBarrierState>,
 }
 
 impl Display for ManagedBarrierStateDebugInfo<'_> {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "running_actors: ")?;
+        for actor_id in &self.running_actors {
+            write!(f, "{}, ", actor_id)?;
+        }
         for (partial_graph_id, graph_states) in self.graph_states {
             writeln!(f, "--- Partial Group {}", partial_graph_id.0)?;
             write!(f, "{}", graph_states)?;
@@ -324,22 +336,84 @@ impl PartialGraphManagedBarrierState {
 }
 
 pub(crate) struct ManagedBarrierState {
+    pub(crate) databases: HashMap<DatabaseId, DatabaseManagedBarrierState>,
+    pub(crate) current_shared_context: HashMap<DatabaseId, Arc<SharedContext>>,
+}
+
+pub(super) enum ManagedBarrierStateEvent {
+    BarrierCollected {
+        partial_graph_id: PartialGraphId,
+        barrier: Barrier,
+    },
+    ActorError {
+        actor_id: ActorId,
+        err: StreamError,
+    },
+}
+
+impl ManagedBarrierState {
+    pub(super) fn new(
+        actor_manager: Arc<StreamActorManager>,
+        initial_partial_graphs: Vec<DatabaseInitialPartialGraph>,
+    ) -> Self {
+        let mut databases = HashMap::new();
+        let mut current_shared_context = HashMap::new();
+        for database in initial_partial_graphs {
+            let database_id = DatabaseId::new(database.database_id);
+            assert!(!databases.contains_key(&database_id));
+            let shared_context = Arc::new(SharedContext::new(database_id, &actor_manager.env));
+            let state = DatabaseManagedBarrierState::new(
+                actor_manager.clone(),
+                shared_context.clone(),
+                database.graphs,
+            );
+            databases.insert(database_id, state);
+            current_shared_context.insert(database_id, shared_context);
+        }
+
+        Self {
+            databases,
+            current_shared_context,
+        }
+    }
+
+    pub(super) fn next_event(
+        &mut self,
+    ) -> impl Future<Output = (DatabaseId, ManagedBarrierStateEvent)> + '_ {
+        poll_fn(|cx| {
+            for (database_id, database) in &mut self.databases {
+                if let Poll::Ready(event) = database.poll_next_event(cx) {
+                    return Poll::Ready((*database_id, event));
+                }
+            }
+            Poll::Pending
+        })
+    }
+}
+
+pub(crate) struct DatabaseManagedBarrierState {
     pub(super) actor_states: HashMap<ActorId, InflightActorState>,
 
     pub(super) graph_states: HashMap<PartialGraphId, PartialGraphManagedBarrierState>,
 
     actor_manager: Arc<StreamActorManager>,
 
-    current_shared_context: Arc<SharedContext>,
+    pub(super) current_shared_context: Arc<SharedContext>,
+    pub(super) local_barrier_manager: LocalBarrierManager,
+
+    barrier_event_rx: UnboundedReceiver<LocalBarrierEvent>,
+    pub(super) actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>,
 }
 
-impl ManagedBarrierState {
+impl DatabaseManagedBarrierState {
     /// Create a barrier manager state. This will be called only once.
     pub(super) fn new(
         actor_manager: Arc<StreamActorManager>,
         current_shared_context: Arc<SharedContext>,
         initial_partial_graphs: Vec<InitialPartialGraph>,
     ) -> Self {
+        let (local_barrier_manager, barrier_event_rx, actor_failure_rx) =
+            LocalBarrierManager::new();
         Self {
             actor_states: Default::default(),
             graph_states: initial_partial_graphs
@@ -352,11 +426,15 @@
                 .collect(),
             actor_manager,
             current_shared_context,
+            local_barrier_manager,
+            barrier_event_rx,
+            actor_failure_rx,
         }
     }
 
     pub(super) fn to_debug_info(&self) -> ManagedBarrierStateDebugInfo<'_> {
         ManagedBarrierStateDebugInfo {
+            running_actors: self.actor_states.keys().cloned().collect(),
             graph_states: &self.graph_states,
         }
     }
@@ -403,7 +481,7 @@ impl InflightActorState {
     }
 }
 
-impl ManagedBarrierState {
+impl DatabaseManagedBarrierState {
     pub(super) fn register_barrier_sender(
         &mut self,
         actor_id: ActorId,
@@ -469,7 +547,7 @@ impl PartialGraphManagedBarrierState {
     }
 }
 
-impl ManagedBarrierState {
+impl DatabaseManagedBarrierState {
     pub(super) fn transform_to_issued(
         &mut self,
         barrier: &Barrier,
@@ -508,6 +586,7 @@ impl DatabaseManagedBarrierState {
                     actor,
                     (*subscriptions).clone(),
                     self.current_shared_context.clone(),
+                    self.local_barrier_manager.clone(),
                 );
                 assert!(self
                     .actor_states
@@ -567,10 +646,47 @@ impl DatabaseManagedBarrierState {
         Ok(())
     }
 
-    pub(super) fn next_collected_epoch(
+    pub(super) fn poll_next_event(
         &mut self,
-    ) -> impl Future<Output = (PartialGraphId, Barrier)> + '_ {
-        poll_fn(|_| {
+        cx: &mut Context<'_>,
+    ) -> Poll<ManagedBarrierStateEvent> {
+        if let Poll::Ready(option) = self.actor_failure_rx.poll_recv(cx) {
+            let (actor_id, err) = option.expect("non-empty when tx in local_barrier_manager");
+            return Poll::Ready(ManagedBarrierStateEvent::ActorError { actor_id, err });
+        }
+        while let Poll::Ready(event) = self.barrier_event_rx.poll_recv(cx) {
+            match event.expect("non-empty when tx in local_barrier_manager") {
+                LocalBarrierEvent::ReportActorCollected { actor_id, epoch } => {
+                    self.collect(actor_id, epoch);
+                }
+                LocalBarrierEvent::ReportCreateProgress {
+                    epoch,
+                    actor,
+                    state,
+                } => {
+                    self.update_create_mview_progress(epoch, actor, state);
+                }
+                LocalBarrierEvent::RegisterBarrierSender {
+                    actor_id,
+                    barrier_sender,
+                } => {
+                    if let Err(err) = self.register_barrier_sender(actor_id, barrier_sender) {
+                        return Poll::Ready(ManagedBarrierStateEvent::ActorError { actor_id, err });
+                    }
+                }
+            }
+        }
+        if let Some((partial_graph_id, barrier)) = self.next_collected_epoch() {
+            return Poll::Ready(ManagedBarrierStateEvent::BarrierCollected {
+                partial_graph_id,
+                barrier,
+            });
+        }
+        Poll::Pending
+    }
+
+    pub(super) fn next_collected_epoch(&mut self) -> Option<(PartialGraphId, Barrier)> {
+        {
             let mut output = None;
             for (partial_graph_id, graph_state) in &mut self.graph_states {
                 if let Some(barrier) = graph_state.may_have_collected_all() {
@@ -581,12 +697,12 @@ impl DatabaseManagedBarrierState {
                     break;
                 }
             }
-            output.map(Poll::Ready).unwrap_or(Poll::Pending)
-        })
+            output
+        }
     }
 }
 
-impl ManagedBarrierState {
+impl DatabaseManagedBarrierState {
     pub(super) fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) {
         let (prev_partial_graph_id, is_finished) = self
             .actor_states
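Editor's note: `poll_next_event` above encodes a deliberate priority: surface a queued actor failure first, then drain all pending control events, and only then report a fully collected barrier. A simplified, synchronous toy of that ordering (all types are stand-ins):

```rust
enum Event {
    ActorError(u32),
    BarrierCollected(u64),
}

struct DatabaseState {
    failures: Vec<u32>,
    pending_events: Vec<u32>, // stand-in for LocalBarrierEvent
    collected_epochs: Vec<u64>,
}

impl DatabaseState {
    fn poll_next_event(&mut self) -> Option<Event> {
        // (1) failures preempt everything else
        if let Some(actor_id) = self.failures.pop() {
            return Some(Event::ActorError(actor_id));
        }
        // (2) apply all queued events (collect/progress/register in the real code)
        while let Some(_event) = self.pending_events.pop() {}
        // (3) only now report a collected barrier, so it reflects the events above
        self.collected_epochs.pop().map(Event::BarrierCollected)
    }
}

fn main() {
    let mut state = DatabaseState {
        failures: vec![3],
        pending_events: vec![1],
        collected_epochs: vec![100],
    };
    assert!(matches!(state.poll_next_event(), Some(Event::ActorError(3))));
    assert!(matches!(state.poll_next_event(), Some(Event::BarrierCollected(100))));
    assert!(state.poll_next_event().is_none());
}
```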
diff --git a/src/stream/src/task/barrier_manager/progress.rs b/src/stream/src/task/barrier_manager/progress.rs
index c860b8f430fa1..1a449c6e811d8 100644
--- a/src/stream/src/task/barrier_manager/progress.rs
+++ b/src/stream/src/task/barrier_manager/progress.rs
@@ -19,8 +19,8 @@ use risingwave_common::util::epoch::EpochPair;
 use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress;
 
 use super::LocalBarrierManager;
+use crate::task::barrier_manager::managed_state::DatabaseManagedBarrierState;
 use crate::task::barrier_manager::LocalBarrierEvent::ReportCreateProgress;
-use crate::task::barrier_manager::LocalBarrierWorker;
 use crate::task::ActorId;
 
 type ConsumedEpoch = u64;
@@ -86,16 +86,16 @@ impl Display for BackfillState {
     }
 }
 
-impl LocalBarrierWorker {
+impl DatabaseManagedBarrierState {
     pub(crate) fn update_create_mview_progress(
         &mut self,
         epoch: EpochPair,
         actor: ActorId,
         state: BackfillState,
     ) {
-        if let Some(actor_state) = self.state.actor_states.get(&actor)
+        if let Some(actor_state) = self.actor_states.get(&actor)
             && let Some(partial_graph_id) = actor_state.inflight_barriers.get(&epoch.prev)
-            && let Some(graph_state) = self.state.graph_states.get_mut(partial_graph_id)
+            && let Some(graph_state) = self.graph_states.get_mut(partial_graph_id)
         {
             graph_state
                 .create_mview_progress
diff --git a/src/stream/src/task/barrier_manager/tests.rs b/src/stream/src/task/barrier_manager/tests.rs
index a9ba0b4b7ed01..d1f873a6d3ebb 100644
--- a/src/stream/src/task/barrier_manager/tests.rs
+++ b/src/stream/src/task/barrier_manager/tests.rs
@@ -28,13 +28,10 @@ use crate::task::barrier_test_utils::LocalBarrierTestEnv;
 async fn test_managed_barrier_collection() -> StreamResult<()> {
     let mut test_env = LocalBarrierTestEnv::for_test().await;
 
-    let manager = &test_env.shared_context.local_barrier_manager;
+    let manager = &test_env.local_barrier_manager;
 
     let register_sender = |actor_id: u32| {
-        let barrier_rx = test_env
-            .shared_context
-            .local_barrier_manager
-            .subscribe_barrier(actor_id);
+        let barrier_rx = test_env.local_barrier_manager.subscribe_barrier(actor_id);
         (actor_id, barrier_rx)
     };
 
@@ -48,7 +45,7 @@ async fn test_managed_barrier_collection() -> StreamResult<()> {
 
     test_env.inject_barrier(&barrier, actor_ids.clone());
 
-    manager.flush_all_events().await;
+    test_env.flush_all_events().await;
 
     let count = actor_ids.len();
     let mut rxs = actor_ids
@@ -77,7 +74,7 @@ async fn test_managed_barrier_collection() -> StreamResult<()> {
     // Report to local barrier manager
     for (i, (actor_id, barrier)) in collected_barriers.into_iter().enumerate() {
         manager.collect(actor_id, &barrier);
-        manager.flush_all_events().await;
+        LocalBarrierTestEnv::flush_all_events_impl(&test_env.actor_op_tx).await;
         let notified =
             poll_fn(|cx| Poll::Ready(await_epoch_future.as_mut().poll(cx).is_ready())).await;
         assert_eq!(notified, i == count - 1);
@@ -90,13 +87,10 @@ async fn test_managed_barrier_collection() -> StreamResult<()> {
 async fn test_managed_barrier_collection_separately() -> StreamResult<()> {
     let mut test_env = LocalBarrierTestEnv::for_test().await;
 
-    let manager = &test_env.shared_context.local_barrier_manager;
+    let manager = &test_env.local_barrier_manager;
 
     let register_sender = |actor_id: u32| {
-        let barrier_rx = test_env
-            .shared_context
-            .local_barrier_manager
-            .subscribe_barrier(actor_id);
+        let barrier_rx = test_env.local_barrier_manager.subscribe_barrier(actor_id);
         (actor_id, barrier_rx)
     };
 
@@ -114,7 +108,7 @@ async fn test_managed_barrier_collection_separately() -> StreamResult<()> {
 
     test_env.inject_barrier(&barrier, actor_ids_to_collect.clone());
 
-    manager.flush_all_events().await;
+    test_env.flush_all_events().await;
 
     // Register actors
     let count = actor_ids_to_send.len();
@@ -159,7 +153,7 @@ async fn test_managed_barrier_collection_separately() -> StreamResult<()> {
     // Report to local barrier manager
     for (i, (actor_id, barrier)) in collected_barriers.into_iter().enumerate() {
         manager.collect(actor_id, &barrier);
-        manager.flush_all_events().await;
+        LocalBarrierTestEnv::flush_all_events_impl(&test_env.actor_op_tx).await;
         let notified =
             poll_fn(|cx| Poll::Ready(await_epoch_future.as_mut().poll(cx).is_ready())).await;
         assert_eq!(notified, i == count - 1);
@@ -172,13 +166,10 @@ async fn test_managed_barrier_collection_separately() -> StreamResult<()> {
 async fn test_late_register_barrier_sender() -> StreamResult<()> {
     let mut test_env = LocalBarrierTestEnv::for_test().await;
 
-    let manager = &test_env.shared_context.local_barrier_manager;
+    let manager = &test_env.local_barrier_manager;
 
     let register_sender = |actor_id: u32| {
-        let barrier_rx = test_env
-            .shared_context
-            .local_barrier_manager
-            .subscribe_barrier(actor_id);
+        let barrier_rx = test_env.local_barrier_manager.subscribe_barrier(actor_id);
         (actor_id, barrier_rx)
     };
 
@@ -203,7 +194,7 @@ async fn test_late_register_barrier_sender() -> StreamResult<()> {
     test_env.inject_barrier(&barrier1, actor_ids_to_collect.clone());
     test_env.inject_barrier(&barrier2, actor_ids_to_collect.clone());
 
-    manager.flush_all_events().await;
+    test_env.flush_all_events().await;
 
     // register sender after inject barrier
     let mut rxs = actor_ids_to_send
@@ -250,7 +241,7 @@ async fn test_late_register_barrier_sender() -> StreamResult<()> {
     // Report to local barrier manager
     for (i, (actor_id, barrier)) in collected_barriers.into_iter().enumerate() {
         manager.collect(actor_id, &barrier);
-        manager.flush_all_events().await;
+        LocalBarrierTestEnv::flush_all_events_impl(&test_env.actor_op_tx).await;
         let notified =
             poll_fn(|cx| Poll::Ready(await_epoch_future.as_mut().poll(cx).is_ready())).await;
         assert_eq!(notified, i == count - 1);
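Editor's note: the task/mod.rs hunk that follows introduces compile-time test constants (`TEST_DATABASE_ID`, `TEST_PARTIAL_GRAPH_ID`). The former only works because `DatabaseId::new` was made `const fn` in the catalog hunk near the top of this patch. A minimal standalone reproduction with a local copy of the type:

```rust
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct DatabaseId {
    pub database_id: u32,
}

impl DatabaseId {
    // Mirrors the one-word change in src/common/src/catalog/mod.rs.
    pub const fn new(database_id: u32) -> Self {
        DatabaseId { database_id }
    }
}

// Without `const fn new`, this const item would fail to compile.
pub const TEST_DATABASE_ID: DatabaseId = DatabaseId::new(u32::MAX);

fn main() {
    assert_eq!(TEST_DATABASE_ID.database_id, u32::MAX);
}
```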
diff --git a/src/stream/src/task/mod.rs b/src/stream/src/task/mod.rs
index 39da5e0b4ed93..3f0609bc3f830 100644
--- a/src/stream/src/task/mod.rs
+++ b/src/stream/src/task/mod.rs
@@ -30,6 +30,7 @@ mod stream_manager;
 pub use barrier_manager::*;
 pub use env::*;
+use risingwave_common::catalog::DatabaseId;
 pub use stream_manager::*;
 
 pub type ConsumableChannelPair = (Option<Sender>, Option<Receiver>);
@@ -40,16 +41,23 @@ pub type UpDownActorIds = (ActorId, ActorId);
 pub type UpDownFragmentIds = (FragmentId, FragmentId);
 
 #[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
-struct PartialGraphId(u64);
+pub(crate) struct PartialGraphId(u32);
+
+#[cfg(test)]
+pub(crate) const TEST_DATABASE_ID: risingwave_common::catalog::DatabaseId =
+    risingwave_common::catalog::DatabaseId::new(u32::MAX);
+
+#[cfg(test)]
+pub(crate) const TEST_PARTIAL_GRAPH_ID: PartialGraphId = PartialGraphId(u32::MAX);
 
 impl PartialGraphId {
-    fn new(id: u64) -> Self {
+    fn new(id: u32) -> Self {
         Self(id)
     }
 }
 
-impl From<PartialGraphId> for u64 {
-    fn from(val: PartialGraphId) -> u64 {
+impl From<PartialGraphId> for u32 {
+    fn from(val: PartialGraphId) -> u32 {
         val.0
     }
 }
@@ -61,6 +69,8 @@ impl From<PartialGraphId> for u64 {
 /// `SharedContext`, and the original one becomes stale. The new one is shared by actors created after
 /// recovery.
 pub struct SharedContext {
+    pub(crate) database_id: DatabaseId,
+
     /// Stores the senders and receivers for later `Processor`'s usage.
     ///
     /// Each actor has several senders and several receivers. Senders and receivers are created
@@ -96,8 +106,6 @@ pub struct SharedContext {
     pub(crate) compute_client_pool: ComputeClientPoolRef,
 
     pub(crate) config: StreamingConfig,
-
-    pub(super) local_barrier_manager: LocalBarrierManager,
 }
 
 impl std::fmt::Debug for SharedContext {
@@ -109,14 +117,14 @@ impl std::fmt::Debug for SharedContext {
 }
 
 impl SharedContext {
-    pub fn new(env: &StreamEnvironment, local_barrier_manager: LocalBarrierManager) -> Self {
+    pub fn new(database_id: DatabaseId, env: &StreamEnvironment) -> Self {
         Self {
+            database_id,
             channel_map: Default::default(),
             actor_infos: Default::default(),
             addr: env.server_address().clone(),
             config: env.config().as_ref().to_owned(),
             compute_client_pool: env.client_pool(),
-            local_barrier_manager,
         }
     }
 
@@ -128,6 +136,7 @@ impl SharedContext {
         use risingwave_rpc_client::ComputeClientPool;
 
         Self {
+            database_id: TEST_DATABASE_ID,
             channel_map: Default::default(),
             actor_infos: Default::default(),
             addr: LOCAL_TEST_ADDR.clone(),
@@ -141,7 +150,6 @@ impl SharedContext {
                 ..Default::default()
             },
             compute_client_pool: Arc::new(ComputeClientPool::for_test()),
-            local_barrier_manager: LocalBarrierManager::for_test(),
         }
     }
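With `database_id` stored on `SharedContext` and the `LocalBarrierManager` removed from it, a worker ends up holding one context per database, and `PartialGraphId` can shrink to a `u32` because the database component travels as its own field rather than inside the id. The registry sketched below is illustrative only; the real lookup is `get_or_insert_database_shared_context`, which appears in the `update_actor_info` hunk further down:

```rust
use std::collections::HashMap;
use std::sync::Arc;

#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
struct DatabaseId(u32);

/// Stand-in for the real SharedContext; only the keying scheme is the point.
struct SharedContext {
    database_id: DatabaseId,
    // channel_map, actor_infos, addr, config, compute_client_pool ... elided
}

#[derive(Default)]
struct BarrierWorkerState {
    /// One context per database instead of a single worker-global one, so a
    /// failing database can be torn down without touching its neighbors.
    current_shared_context: HashMap<DatabaseId, Arc<SharedContext>>,
}

impl BarrierWorkerState {
    fn get_or_insert_shared_context(&mut self, database_id: DatabaseId) -> &Arc<SharedContext> {
        self.current_shared_context
            .entry(database_id)
            .or_insert_with(|| Arc::new(SharedContext { database_id }))
    }
}

fn main() {
    let mut state = BarrierWorkerState::default();
    let ctx = state.get_or_insert_shared_context(DatabaseId(1));
    assert_eq!(ctx.database_id, DatabaseId(1));
}
```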
diff --git a/src/stream/src/task/stream_manager.rs b/src/stream/src/task/stream_manager.rs
index 648afb81ebc8d..16fe1f4c90f96 100644
--- a/src/stream/src/task/stream_manager.rs
+++ b/src/stream/src/task/stream_manager.rs
@@ -21,11 +21,12 @@ use std::time::Instant;
 
 use async_recursion::async_recursion;
 use await_tree::InstrumentAwait;
+use futures::future::join_all;
 use futures::stream::BoxStream;
 use futures::{FutureExt, TryFutureExt};
 use itertools::Itertools;
 use risingwave_common::bitmap::Bitmap;
-use risingwave_common::catalog::{ColumnId, Field, Schema, TableId};
+use risingwave_common::catalog::{ColumnId, DatabaseId, Field, Schema, TableId};
 use risingwave_common::config::MetricLevel;
 use risingwave_common::{bail, must_match};
 use risingwave_pb::common::ActorInfo;
@@ -34,7 +35,7 @@ use risingwave_pb::stream_plan;
 use risingwave_pb::stream_plan::stream_node::NodeBody;
 use risingwave_pb::stream_plan::{StreamActor, StreamNode, StreamScanNode, StreamScanType};
 use risingwave_pb::stream_service::streaming_control_stream_request::{
-    InitRequest, InitialPartialGraph,
+    DatabaseInitialPartialGraph, InitRequest,
 };
 use risingwave_pb::stream_service::{
     StreamingControlStreamRequest, StreamingControlStreamResponse,
@@ -217,9 +218,14 @@ impl LocalStreamManager {
         })
     }
 
-    pub async fn take_receiver(&self, ids: UpDownActorIds) -> StreamResult<Receiver> {
+    pub async fn take_receiver(
+        &self,
+        database_id: DatabaseId,
+        ids: UpDownActorIds,
+    ) -> StreamResult<Receiver> {
         self.actor_op_tx
             .send_and_await(|result_sender| LocalActorOperation::TakeReceiver {
+                database_id,
                 ids,
                 result_sender,
             })
@@ -250,8 +256,14 @@ impl LocalStreamManager {
 
 impl LocalBarrierWorker {
     /// Force stop all actors on this worker, and then drop their resources.
-    pub(super) async fn reset(&mut self, initial_partial_graphs: Vec<InitialPartialGraph>) {
-        self.state.abort_actors().await;
+    pub(super) async fn reset(&mut self, initial_partial_graphs: Vec<DatabaseInitialPartialGraph>) {
+        join_all(
+            self.state
+                .databases
+                .values_mut()
+                .map(|database| database.abort_actors()),
+        )
+        .await;
         if let Some(m) = self.actor_manager.await_tree_reg.as_ref() {
             m.clear();
         }
@@ -352,6 +364,7 @@ impl StreamActorManager {
         vnode_bitmap: Option<Bitmap>,
         shared_context: &Arc<SharedContext>,
         env: StreamEnvironment,
+        local_barrier_manager: &LocalBarrierManager,
         state_store: impl StateStore,
     ) -> StreamResult<Executor> {
         let [upstream_node, _]: &[_; 2] = stream_node.input.as_slice().try_into().unwrap();
@@ -377,14 +390,10 @@ impl StreamActorManager {
             .map(ColumnId::from)
             .collect_vec();
 
-        let progress = shared_context
-            .local_barrier_manager
-            .register_create_mview_progress(actor_context.id);
+        let progress = local_barrier_manager.register_create_mview_progress(actor_context.id);
 
         let vnodes = vnode_bitmap.map(Arc::new);
-        let barrier_rx = shared_context
-            .local_barrier_manager
-            .subscribe_barrier(actor_context.id);
+        let barrier_rx = local_barrier_manager.subscribe_barrier(actor_context.id);
 
         let upstream_table =
             StorageTable::new_partial(state_store.clone(), column_ids, vnodes, table_desc);
@@ -434,6 +443,7 @@ impl StreamActorManager {
         has_stateful: bool,
         subtasks: &mut Vec<SubtaskHandle>,
         shared_context: &Arc<SharedContext>,
+        local_barrier_manager: &LocalBarrierManager,
     ) -> StreamResult<Executor> {
         if let NodeBody::StreamScan(stream_scan) = node.get_node_body().unwrap()
             && let Ok(StreamScanType::SnapshotBackfill) = stream_scan.get_stream_scan_type()
@@ -446,6 +456,7 @@ impl StreamActorManager {
                     vnode_bitmap,
                     shared_context,
                     env,
+                    local_barrier_manager,
                     store,
                 )
             });
@@ -483,6 +494,7 @@ impl StreamActorManager {
                     has_stateful || is_stateful,
                     subtasks,
                     shared_context,
+                    local_barrier_manager,
                 )
                 .await?,
             );
@@ -518,7 +530,7 @@ impl StreamActorManager {
             eval_error_report,
             watermark_epoch: self.watermark_epoch.clone(),
             shared_context: shared_context.clone(),
-            local_barrier_manager: shared_context.local_barrier_manager.clone(),
+            local_barrier_manager: local_barrier_manager.clone(),
         };
 
         let executor = create_executor(executor_params, node, store).await?;
@@ -548,6 +560,7 @@ impl StreamActorManager {
     }
 
     /// Create a chain(tree) of nodes and return the head executor.
+    #[expect(clippy::too_many_arguments)]
     async fn create_nodes(
         &self,
         fragment_id: FragmentId,
@@ -556,6 +569,7 @@ impl StreamActorManager {
         actor_context: &ActorContextRef,
         vnode_bitmap: Option<Bitmap>,
         shared_context: &Arc<SharedContext>,
+        local_barrier_manager: &LocalBarrierManager,
     ) -> StreamResult<(Executor, Vec<SubtaskHandle>)> {
         let mut subtasks = vec![];
@@ -570,6 +584,7 @@ impl StreamActorManager {
                 false,
                 &mut subtasks,
                 shared_context,
+                local_barrier_manager,
             )
             .await
         })?;
@@ -582,6 +597,7 @@ impl StreamActorManager {
         actor: StreamActor,
         shared_context: Arc<SharedContext>,
         related_subscriptions: Arc<HashMap<TableId, HashSet<u32>>>,
+        local_barrier_manager: LocalBarrierManager,
     ) -> StreamResult<Actor<DispatchExecutor>> {
         {
             let actor_id = actor.actor_id;
@@ -606,6 +622,7 @@ impl StreamActorManager {
                 &actor_context,
                 vnode_bitmap,
                 &shared_context,
+                &local_barrier_manager,
             )
             // If hummock tracing is not enabled, it directly returns wrapped future.
             .may_trace_hummock()
@@ -625,7 +642,7 @@ impl StreamActorManager {
             self.streaming_metrics.clone(),
             actor_context.clone(),
             expr_context,
-            shared_context.local_barrier_manager.clone(),
+            local_barrier_manager,
         );
         Ok(actor)
     }
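The `reset` hunk above replaces the single global `abort_actors` call with a `join_all` over every database's state, so databases shut down concurrently during recovery. A minimal model of that pattern, with stand-in types (the real per-database state holds actor join handles and more):

```rust
use std::collections::HashMap;

use futures::future::join_all;

#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
struct DatabaseId(u32);

struct DatabaseManagedBarrierState;

impl DatabaseManagedBarrierState {
    async fn abort_actors(&mut self) {
        // await the join handles of this database's actors (elided)
    }
}

struct ManagedBarrierState {
    databases: HashMap<DatabaseId, DatabaseManagedBarrierState>,
}

/// Abort every database concurrently: one slow database no longer blocks
/// the reset of the others.
async fn reset(state: &mut ManagedBarrierState) {
    join_all(
        state
            .databases
            .values_mut()
            .map(|database| database.abort_actors()),
    )
    .await;
}

fn main() {
    let mut state = ManagedBarrierState {
        databases: HashMap::from([(DatabaseId(1), DatabaseManagedBarrierState)]),
    };
    futures::executor::block_on(reset(&mut state));
}
```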
@@ -638,6 +655,7 @@ impl StreamActorManager {
         actor: StreamActor,
         related_subscriptions: Arc<HashMap<TableId, HashSet<u32>>>,
         current_shared_context: Arc<SharedContext>,
+        local_barrier_manager: LocalBarrierManager,
     ) -> (JoinHandle<()>, Option<JoinHandle<()>>) {
         {
             let monitor = tokio_metrics::TaskMonitor::new();
@@ -646,9 +664,9 @@ impl StreamActorManager {
         let handle = {
             let trace_span = format!("Actor {actor_id}: `{}`", stream_actor_ref.mview_definition);
-            let barrier_manager = current_shared_context.local_barrier_manager.clone();
+            let barrier_manager = local_barrier_manager.clone();
             // wrap the future of `create_actor` with `boxed` to avoid stack overflow
-            let actor = self.clone().create_actor(actor, current_shared_context, related_subscriptions).boxed().and_then(|actor| actor.run()).map(move |result| {
+            let actor = self.clone().create_actor(actor, current_shared_context, related_subscriptions, barrier_manager.clone()).boxed().and_then(|actor| actor.run()).map(move |result| {
                 if let Err(err) = result {
                     // TODO: check error type and panic if it's unexpected.
                     // Intentionally use `?` on the report to also include the backtrace.
@@ -727,10 +745,17 @@ impl LocalBarrierWorker {
     /// This function could only be called once during the lifecycle of `LocalStreamManager` for
     /// now.
     pub fn update_actor_info(
-        &self,
+        &mut self,
+        database_id: DatabaseId,
         new_actor_infos: impl Iterator<Item = ActorInfo>,
     ) -> StreamResult<()> {
-        let mut actor_infos = self.current_shared_context.actor_infos.write();
+        let mut actor_infos = Self::get_or_insert_database_shared_context(
+            &mut self.state.current_shared_context,
+            database_id,
+            &self.actor_manager,
+        )
+        .actor_infos
+        .write();
         for actor in new_actor_infos {
             if let Some(prev_actor) = actor_infos.get(&actor.get_actor_id())
                 && &actor != prev_actor
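The final hunk (the diff is truncated mid-condition) keys actor-info registration by database: `update_actor_info` lazily obtains the database's `SharedContext` through `get_or_insert_database_shared_context` and writes into that context's `actor_infos`. A simplified model follows; the conflict-handling branch is reconstructed as an assumption, since the original cuts off right after the `prev_actor` comparison:

```rust
use std::collections::HashMap;

#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
struct DatabaseId(u32);

#[derive(PartialEq, Clone, Debug)]
struct ActorInfo {
    actor_id: u32,
    address: String,
}

#[derive(Default)]
struct Worker {
    /// One actor-info table per database, mirroring the per-database SharedContext.
    actor_infos: HashMap<DatabaseId, HashMap<u32, ActorInfo>>,
}

impl Worker {
    fn update_actor_info(
        &mut self,
        database_id: DatabaseId,
        new_actor_infos: impl Iterator<Item = ActorInfo>,
    ) -> Result<(), String> {
        // Lazily create the database entry, analogous to
        // get_or_insert_database_shared_context in the diff.
        let actor_infos = self.actor_infos.entry(database_id).or_default();
        for actor in new_actor_infos {
            // Assumed behavior: registering a conflicting info for the same
            // actor id is rejected (the diff truncates before this branch).
            if let Some(prev_actor) = actor_infos.get(&actor.actor_id) {
                if prev_actor != &actor {
                    return Err(format!("conflicting info for actor {}", actor.actor_id));
                }
            }
            actor_infos.insert(actor.actor_id, actor);
        }
        Ok(())
    }
}

fn main() {
    let mut worker = Worker::default();
    let info = ActorInfo { actor_id: 1, address: "127.0.0.1:5688".into() };
    worker.update_actor_info(DatabaseId(1), std::iter::once(info)).unwrap();
}
```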