Skip to content

Commit

Permalink
feat: variable vnode count support in batch/streaming scheduler (#18407)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Sep 9, 2024
1 parent 9a03718 commit 8d5b62b
Show file tree
Hide file tree
Showing 16 changed files with 123 additions and 84 deletions.
6 changes: 4 additions & 2 deletions src/batch/src/worker_manager/worker_node_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::time::Duration;
use rand::seq::SliceRandom;
use risingwave_common::bail;
use risingwave_common::catalog::OBJECT_ID_PLACEHOLDER;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::hash::{VirtualNode, WorkerSlotId, WorkerSlotMapping};
use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
use risingwave_pb::common::{WorkerNode, WorkerType};

Expand Down Expand Up @@ -374,7 +374,9 @@ impl WorkerNodeSelector {
};
// 2. Temporary mapping that filters out unavailable workers.
let new_workers = self.apply_worker_node_mask(self.manager.list_serving_worker_nodes());
let masked_mapping = place_vnode(hint.as_ref(), &new_workers, parallelism);
// TODO(var-vnode): use vnode count from config
let masked_mapping =
place_vnode(hint.as_ref(), &new_workers, parallelism, VirtualNode::COUNT);
masked_mapping.ok_or_else(|| BatchError::EmptyWorkerNodes)
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/common/src/bitmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,12 @@ impl From<&PbBuffer> for Bitmap {
}
}

impl From<PbBuffer> for Bitmap {
fn from(buf: PbBuffer) -> Self {
Self::from(&buf)
}
}

/// Bitmap iterator.
pub struct BitmapIter<'a> {
bits: Option<&'a [usize]>,
Expand Down
16 changes: 6 additions & 10 deletions src/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use serde_default::DefaultFromSerde;
use serde_json::Value;

use crate::for_all_params;
use crate::hash::VirtualNode;

/// Use the maximum value for HTTP/2 connection window size to avoid deadlock among multiplexed
/// streams on the same connection.
Expand Down Expand Up @@ -427,16 +426,13 @@ impl<'de> Deserialize<'de> for DefaultParallelism {
)))
}
}
Parallelism::Int(i) => Ok(DefaultParallelism::Default(if i > VirtualNode::COUNT {
Err(serde::de::Error::custom(format!(
"default parallelism should be not great than {}",
VirtualNode::COUNT
)))?
} else {
Parallelism::Int(i) => Ok(DefaultParallelism::Default(
// Note: we won't check whether this exceeds the maximum parallelism (i.e., vnode count)
// here because it requires extra context. The check will be done when scheduling jobs.
NonZeroUsize::new(i).ok_or_else(|| {
serde::de::Error::custom("default parallelism should be greater than 0")
})?
})),
serde::de::Error::custom("default parallelism should not be 0")
})?,
)),
}
}
}
Expand Down
39 changes: 29 additions & 10 deletions src/common/src/vnode_mapping/vnode_placement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ pub fn place_vnode(
hint_worker_slot_mapping: Option<&WorkerSlotMapping>,
workers: &[WorkerNode],
max_parallelism: Option<usize>,
vnode_count: usize,
) -> Option<WorkerSlotMapping> {
if let Some(mapping) = hint_worker_slot_mapping {
assert_eq!(mapping.len(), vnode_count);
}

// Get all serving worker slots from all available workers, grouped by worker id and ordered
// by worker slot id in each group.
let mut worker_slots: LinkedList<_> = workers
Expand All @@ -44,7 +49,7 @@ pub fn place_vnode(
// `max_parallelism` and total number of virtual nodes.
let serving_parallelism = std::cmp::min(
worker_slots.iter().map(|slots| slots.len()).sum(),
std::cmp::min(max_parallelism.unwrap_or(usize::MAX), VirtualNode::COUNT),
std::cmp::min(max_parallelism.unwrap_or(usize::MAX), vnode_count),
);

// Select `serving_parallelism` worker slots in a round-robin fashion, to distribute workload
Expand Down Expand Up @@ -79,14 +84,14 @@ pub fn place_vnode(
is_temp: bool,
}

let (expected, mut remain) = VirtualNode::COUNT.div_rem(&selected_slots.len());
let (expected, mut remain) = vnode_count.div_rem(&selected_slots.len());
let mut balances: HashMap<WorkerSlotId, Balance> = HashMap::default();

for slot in &selected_slots {
let mut balance = Balance {
slot: *slot,
balance: -(expected as i32),
builder: BitmapBuilder::zeroed(VirtualNode::COUNT),
builder: BitmapBuilder::zeroed(vnode_count),
is_temp: false,
};

Expand All @@ -102,7 +107,7 @@ pub fn place_vnode(
let mut temp_slot = Balance {
slot: WorkerSlotId::new(0u32, usize::MAX), /* This id doesn't matter for `temp_slot`. It's distinguishable via `is_temp`. */
balance: 0,
builder: BitmapBuilder::zeroed(VirtualNode::COUNT),
builder: BitmapBuilder::zeroed(vnode_count),
is_temp: true,
};
match hint_worker_slot_mapping {
Expand All @@ -123,7 +128,7 @@ pub fn place_vnode(
}
None => {
// No hint is provided, assign all vnodes to `temp_pu`.
for vnode in VirtualNode::all(VirtualNode::COUNT) {
for vnode in VirtualNode::all(vnode_count) {
temp_slot.balance += 1;
temp_slot.builder.set(vnode.to_index(), true);
}
Expand Down Expand Up @@ -158,7 +163,7 @@ pub fn place_vnode(
let mut dst = balances.pop_back().unwrap();
let n = std::cmp::min(src.balance.abs(), dst.balance.abs());
let mut moved = 0;
for idx in 0..VirtualNode::COUNT {
for idx in 0..vnode_count {
if moved >= n {
break;
}
Expand Down Expand Up @@ -189,7 +194,7 @@ pub fn place_vnode(
for (worker_slot, bitmap) in results {
worker_result
.entry(worker_slot)
.or_insert(BitmapBuilder::zeroed(VirtualNode::COUNT).finish())
.or_insert(Bitmap::zeros(vnode_count))
.bitor_assign(&bitmap);
}

Expand All @@ -204,10 +209,24 @@ mod tests {
use risingwave_pb::common::WorkerNode;

use crate::hash::VirtualNode;
use crate::vnode_mapping::vnode_placement::place_vnode;

/// [`super::place_vnode`] with [`VirtualNode::COUNT_FOR_TEST`] as the vnode count.
fn place_vnode(
hint_worker_slot_mapping: Option<&WorkerSlotMapping>,
workers: &[WorkerNode],
max_parallelism: Option<usize>,
) -> Option<WorkerSlotMapping> {
super::place_vnode(
hint_worker_slot_mapping,
workers,
max_parallelism,
VirtualNode::COUNT_FOR_TEST,
)
}

#[test]
fn test_place_vnode() {
assert_eq!(VirtualNode::COUNT, 256);
assert_eq!(VirtualNode::COUNT_FOR_TEST, 256);

let serving_property = Property {
is_unschedulable: false,
Expand All @@ -220,7 +239,7 @@ mod tests {
assert_eq!(wm1.len(), 256);
assert_eq!(wm2.len(), 256);
let mut count: usize = 0;
for idx in 0..VirtualNode::COUNT {
for idx in 0..VirtualNode::COUNT_FOR_TEST {
let vnode = VirtualNode::from_index(idx);
if wm1.get(vnode) == wm2.get(vnode) {
count += 1;
Expand Down
12 changes: 7 additions & 5 deletions src/frontend/src/handler/alter_parallelism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,23 @@ pub async fn handle_alter_parallelism(
.filter(|w| w.is_streaming_schedulable())
.map(|w| w.parallelism)
.sum::<u32>();
// TODO(var-vnode): use vnode count from config
let max_parallelism = VirtualNode::COUNT;

let mut builder = RwPgResponse::builder(stmt_type);

match &target_parallelism.parallelism {
Some(Parallelism::Adaptive(_)) | Some(Parallelism::Auto(_)) => {
if available_parallelism > VirtualNode::COUNT as u32 {
builder = builder.notice(format!("Available parallelism exceeds the maximum parallelism limit, the actual parallelism will be limited to {}", VirtualNode::COUNT));
if available_parallelism > max_parallelism as u32 {
builder = builder.notice(format!("Available parallelism exceeds the maximum parallelism limit, the actual parallelism will be limited to {max_parallelism}"));
}
}
Some(Parallelism::Fixed(FixedParallelism { parallelism })) => {
if *parallelism > VirtualNode::COUNT as u32 {
builder = builder.notice(format!("Provided parallelism exceeds the maximum parallelism limit, resetting to FIXED({})", VirtualNode::COUNT));
if *parallelism > max_parallelism as u32 {
builder = builder.notice(format!("Provided parallelism exceeds the maximum parallelism limit, resetting to FIXED({max_parallelism})"));
target_parallelism = PbTableParallelism {
parallelism: Some(PbParallelism::Fixed(FixedParallelism {
parallelism: VirtualNode::COUNT as u32,
parallelism: max_parallelism as u32,
})),
};
}
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/scheduler/distributed/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ impl StageRunner {
.expect("no partition info for seq scan")
.into_table()
.expect("PartitionInfo should be TablePartitionInfo");
scan_node.vnode_bitmap = Some(partition.vnode_bitmap);
scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
scan_node.scan_ranges = partition.scan_ranges;
PbPlanNode {
children: vec![],
Expand All @@ -1045,7 +1045,7 @@ impl StageRunner {
.expect("no partition info for seq scan")
.into_table()
.expect("PartitionInfo should be TablePartitionInfo");
scan_node.vnode_bitmap = Some(partition.vnode_bitmap);
scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
PbPlanNode {
children: vec![],
identity,
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/scheduler/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ impl LocalQueryExecution {
let partition = partition
.into_table()
.expect("PartitionInfo should be TablePartitionInfo here");
scan_node.vnode_bitmap = Some(partition.vnode_bitmap);
scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
scan_node.scan_ranges = partition.scan_ranges;
}
}
Expand All @@ -522,7 +522,7 @@ impl LocalQueryExecution {
let partition = partition
.into_table()
.expect("PartitionInfo should be TablePartitionInfo here");
scan_node.vnode_bitmap = Some(partition.vnode_bitmap);
scan_node.vnode_bitmap = Some(partition.vnode_bitmap.to_protobuf());
}
}
_ => unreachable!(),
Expand Down
21 changes: 9 additions & 12 deletions src/frontend/src/scheduler/plan_fragmenter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use risingwave_common::bail;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::{Schema, TableDesc};
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{VirtualNode, WorkerSlotId, WorkerSlotMapping};
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
Expand All @@ -44,7 +44,6 @@ use risingwave_connector::source::{
};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::common::Buffer;
use risingwave_pb::plan_common::Field as PbField;
use risingwave_sqlparser::ast::AsOf;
use serde::ser::SerializeStruct;
Expand Down Expand Up @@ -437,7 +436,7 @@ impl TableScanInfo {

#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
pub vnode_bitmap: Buffer,
pub vnode_bitmap: Bitmap,
pub scan_ranges: Vec<ScanRangeProto>,
}

Expand Down Expand Up @@ -922,8 +921,7 @@ impl BatchPlanFragmenter {
.drain()
.take(1)
.update(|(_, info)| {
info.vnode_bitmap =
Bitmap::ones(VirtualNode::COUNT).to_protobuf();
info.vnode_bitmap = Bitmap::ones(info.vnode_bitmap.len());
})
.collect();
}
Expand Down Expand Up @@ -1230,7 +1228,7 @@ fn derive_partitions(
table_desc: &TableDesc,
vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
let num_vnodes = vnode_mapping.len();
let vnode_count = vnode_mapping.len();
let mut partitions: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

if scan_ranges.is_empty() {
Expand All @@ -1241,7 +1239,7 @@ fn derive_partitions(
(
k,
TablePartitionInfo {
vnode_bitmap: vnode_bitmap.to_protobuf(),
vnode_bitmap,
scan_ranges: vec![],
},
)
Expand All @@ -1250,8 +1248,7 @@ fn derive_partitions(
}

let table_distribution = TableDistribution::new_from_storage_table_desc(
// TODO(var-vnode): use vnode count from table desc
Some(Bitmap::ones(VirtualNode::COUNT).into()),
Some(Bitmap::ones(vnode_count).into()),
&table_desc.try_to_protobuf()?,
);

Expand All @@ -1264,7 +1261,7 @@ fn derive_partitions(
|(worker_slot_id, vnode_bitmap)| {
let (bitmap, scan_ranges) = partitions
.entry(worker_slot_id)
.or_insert_with(|| (BitmapBuilder::zeroed(num_vnodes), vec![]));
.or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
vnode_bitmap
.iter()
.enumerate()
Expand All @@ -1278,7 +1275,7 @@ fn derive_partitions(
let worker_slot_id = vnode_mapping[vnode];
let (bitmap, scan_ranges) = partitions
.entry(worker_slot_id)
.or_insert_with(|| (BitmapBuilder::zeroed(num_vnodes), vec![]));
.or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
bitmap.set(vnode.to_index(), true);
scan_ranges.push(scan_range.to_protobuf());
}
Expand All @@ -1291,7 +1288,7 @@ fn derive_partitions(
(
k,
TablePartitionInfo {
vnode_bitmap: bitmap.finish().to_protobuf(),
vnode_bitmap: bitmap.finish(),
scan_ranges,
},
)
Expand Down
1 change: 1 addition & 0 deletions src/jni_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ impl<'a> Deref for JavaBindingIterator<'a> {

#[no_mangle]
extern "system" fn Java_com_risingwave_java_binding_Binding_vnodeCount(_env: EnvParam<'_>) -> jint {
// TODO(var-vnode): use vnode count from config
VirtualNode::COUNT as jint
}

Expand Down
3 changes: 2 additions & 1 deletion src/meta/src/rpc/ddl_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,7 @@ impl DdlController {

let parallelism = self.resolve_stream_parallelism(specified_parallelism, &cluster_info)?;

// TODO(var-vnode): use vnode count from config
const MAX_PARALLELISM: NonZeroUsize = NonZeroUsize::new(VirtualNode::COUNT).unwrap();

let parallelism_limited = parallelism > MAX_PARALLELISM;
Expand Down Expand Up @@ -1645,7 +1646,7 @@ impl DdlController {
// Otherwise, it defaults to FIXED based on deduction.
let table_parallelism = match (specified_parallelism, &self.env.opts.default_parallelism) {
(None, DefaultParallelism::Full) if parallelism_limited => {
tracing::warn!("Parallelism limited to 256 in ADAPTIVE mode");
tracing::warn!("Parallelism limited to {MAX_PARALLELISM} in ADAPTIVE mode");
TableParallelism::Adaptive
}
(None, DefaultParallelism::Full) => TableParallelism::Adaptive,
Expand Down
5 changes: 3 additions & 2 deletions src/meta/src/serving/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use parking_lot::RwLock;
use risingwave_common::hash::WorkerSlotMapping;
use risingwave_common::hash::{VirtualNode, WorkerSlotMapping};
use risingwave_common::vnode_mapping::vnode_placement::place_vnode;
use risingwave_pb::common::{WorkerNode, WorkerType};
use risingwave_pb::meta::subscribe_response::{Info, Operation};
Expand Down Expand Up @@ -57,7 +57,8 @@ impl ServingVnodeMapping {
} else {
None
};
place_vnode(old_mapping, workers, max_parallelism)
// TODO(var-vnode): use vnode count from config
place_vnode(old_mapping, workers, max_parallelism, VirtualNode::COUNT)
};
match new_mapping {
None => {
Expand Down
Loading

0 comments on commit 8d5b62b

Please sign in to comment.