From aa438cae8c552a18f6701d6237bcbcc982bed25a Mon Sep 17 00:00:00 2001 From: Draco Date: Fri, 27 Sep 2024 15:40:42 +0800 Subject: [PATCH] feat: use multithreading to optimize WAL replay (#1572) ## Rationale 1. Currently, the WAL replayer uses coroutines to replay the WAL logs of multiple tables in parallel. However, coroutines utilize at most one CPU. By switching to a multithreaded approach, we can fully leverage multiple CPUs. 2. We observed that during the replay phase, decoding the WAL log is a CPU-intensive operation, so parallelize it. ## Detailed Changes 1. Modify both `TableBasedReplay` and `RegionBasedReplay` to use the `spawn task` approach for parallelism, with a maximum of 20 tasks running concurrently. 2. Preload next segment in WAL based on local storage. 4. In `BatchLogIteratorAdapter::simulated_async_next`, we first retrieve all the payloads in a batch and then decode them in parallel. ## Test Plan Manual testing. --- Cargo.lock | 22 ++- src/analytic_engine/Cargo.toml | 1 + .../src/instance/wal_replayer.rs | 116 +++++++------ src/wal/Cargo.toml | 2 + .../src/local_storage_impl/record_encoding.rs | 16 +- src/wal/src/local_storage_impl/segment.rs | 164 +++++++++++++----- src/wal/src/manager.rs | 39 +++-- 7 files changed, 238 insertions(+), 122 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 43d74e7e66..584f2e98b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,7 @@ dependencies = [ "arc-swap 1.6.0", "arena", "arrow 49.0.0", + "async-scoped", "async-stream", "async-trait", "atomic_enum", @@ -764,6 +765,17 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "async-scoped" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4042078ea593edffc452eef14e99fdb2b120caa4ad9618bcdeabc4a023b98740" +dependencies = [ + "futures 0.3.28", + "pin-project", + "tokio", +] + [[package]] name = "async-stream" version = "0.3.4" @@ -5981,9 +5993,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ "either", "rayon-core", @@ -5991,9 +6003,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -8220,6 +8232,7 @@ name = "wal" version = "2.0.0" dependencies = [ "anyhow", + "async-scoped", "async-trait", "bytes_ext", "chrono", @@ -8237,6 +8250,7 @@ dependencies = [ "prometheus 0.12.0", "prost 0.11.8", "rand 0.8.5", + "rayon", "rocksdb", "runtime", "serde", diff --git a/src/analytic_engine/Cargo.toml b/src/analytic_engine/Cargo.toml index 8197b4eeb1..09ff47af21 100644 --- a/src/analytic_engine/Cargo.toml +++ b/src/analytic_engine/Cargo.toml @@ -43,6 +43,7 @@ anyhow = { workspace = true } arc-swap = "1.4.0" arena = { workspace = true } arrow = { workspace = true } +async-scoped = { version = "0.9.0", features = ["use-tokio"] } async-stream = { workspace = true } async-trait = { workspace = true } atomic_enum = { workspace = true } diff --git a/src/analytic_engine/src/instance/wal_replayer.rs b/src/analytic_engine/src/instance/wal_replayer.rs index f782895145..6c67414037 100644 --- a/src/analytic_engine/src/instance/wal_replayer.rs +++ b/src/analytic_engine/src/instance/wal_replayer.rs @@ -30,14 +30,13 @@ use common_types::{ schema::{IndexInWriterSchema, Schema}, table::ShardId, }; -use futures::StreamExt; use generic_error::BoxError; use lazy_static::lazy_static; use logger::{debug, error, info, trace, warn}; use prometheus::{exponential_buckets, register_histogram, Histogram}; use snafu::ResultExt; use table_engine::table::TableId; -use tokio::sync::{Mutex, MutexGuard}; +use tokio::sync::{Mutex, MutexGuard, Semaphore}; use wal::{ log_batch::LogEntry, manager::{ @@ -74,6 +73,8 @@ lazy_static! { .unwrap(); } +const MAX_REPLAY_TASK_NUM: usize = 20; + /// Wal replayer supporting both table based and region based // TODO: limit the memory usage in `RegionBased` mode. pub struct WalReplayer<'a> { @@ -189,22 +190,23 @@ impl Replay for TableBasedReplay { ..Default::default() }; - let mut tasks = futures::stream::iter( - table_datas - .iter() - .map(|table_data| { - let table_id = table_data.id; - let read_ctx = &read_ctx; - async move { - let ret = Self::recover_table_logs(context, table_data, read_ctx).await; - (table_id, ret) - } - }) - .collect::>(), - ) - .buffer_unordered(20); - while let Some((table_id, ret)) = tasks.next().await { - if let Err(e) = ret { + let ((), results) = async_scoped::TokioScope::scope_and_block(|scope| { + // Limit the maximum number of concurrent tasks. + let semaphore = Arc::new(Semaphore::new(MAX_REPLAY_TASK_NUM)); + for table_data in table_datas { + let table_id = table_data.id; + let read_ctx = &read_ctx; + let semaphore = semaphore.clone(); + scope.spawn(async move { + let _permit = semaphore.acquire().await.unwrap(); + let ret = Self::recover_table_logs(context, table_data, read_ctx).await; + (table_id, ret) + }); + } + }); + + for result in results.into_iter().flatten() { + if let (table_id, Err(e)) = result { // If occur error, mark this table as failed and store the cause. failed_tables.insert(table_id, e); } @@ -345,7 +347,7 @@ impl RegionBasedReplay { table_data: table_data.clone(), serial_exec, }; - serial_exec_ctxs.insert(table_data.id, serial_exec_ctx); + serial_exec_ctxs.insert(table_data.id, Mutex::new(serial_exec_ctx)); table_datas_by_id.insert(table_data.id.as_u64(), table_data.clone()); } @@ -353,7 +355,7 @@ impl RegionBasedReplay { let schema_provider = TableSchemaProviderAdapter { table_datas: table_datas_by_id.clone(), }; - let serial_exec_ctxs = Arc::new(Mutex::new(serial_exec_ctxs)); + let serial_exec_ctxs = serial_exec_ctxs; // Split and replay logs. loop { let _timer = PULL_LOGS_DURATION_HISTOGRAM.start_timer(); @@ -381,49 +383,53 @@ impl RegionBasedReplay { async fn replay_single_batch( context: &ReplayContext, log_batch: &VecDeque>, - serial_exec_ctxs: &Arc>>>, + serial_exec_ctxs: &HashMap>>, failed_tables: &mut FailedTables, ) -> Result<()> { let mut table_batches = Vec::new(); // TODO: No `group_by` method in `VecDeque`, so implement it manually here... Self::split_log_batch_by_table(log_batch, &mut table_batches); - // TODO: Replay logs of different tables in parallel. - let mut replay_tasks = Vec::with_capacity(table_batches.len()); - for table_batch in table_batches { - // Some tables may have failed in previous replay, ignore them. - if failed_tables.contains_key(&table_batch.table_id) { - continue; - } - let log_entries: Vec<_> = table_batch - .ranges - .iter() - .flat_map(|range| log_batch.range(range.clone())) - .collect(); - - let serial_exec_ctxs = serial_exec_ctxs.clone(); - replay_tasks.push(async move { - // Some tables may have been moved to other shards or dropped, ignore such logs. - if let Some(ctx) = serial_exec_ctxs.lock().await.get_mut(&table_batch.table_id) { - let result = replay_table_log_entries( - &context.flusher, - context.max_retry_flush_limit, - &mut ctx.serial_exec, - &ctx.table_data, - log_entries.into_iter(), - ) - .await; - (table_batch.table_id, Some(result)) - } else { - (table_batch.table_id, None) + let ((), results) = async_scoped::TokioScope::scope_and_block(|scope| { + // Limit the maximum number of concurrent tasks. + let semaphore = Arc::new(Semaphore::new(MAX_REPLAY_TASK_NUM)); + + for table_batch in table_batches { + // Some tables may have failed in previous replay, ignore them. + if failed_tables.contains_key(&table_batch.table_id) { + continue; } - }); - } + let log_entries: Vec<_> = table_batch + .ranges + .iter() + .flat_map(|range| log_batch.range(range.clone())) + .collect(); + let semaphore = semaphore.clone(); + + scope.spawn(async move { + let _permit = semaphore.acquire().await.unwrap(); + // Some tables may have been moved to other shards or dropped, ignore such logs. + if let Some(ctx) = serial_exec_ctxs.get(&table_batch.table_id) { + let mut ctx = ctx.lock().await; + let table_data = ctx.table_data.clone(); + let result = replay_table_log_entries( + &context.flusher, + context.max_retry_flush_limit, + &mut ctx.serial_exec, + &table_data, + log_entries.into_iter(), + ) + .await; + (table_batch.table_id, Some(result)) + } else { + (table_batch.table_id, None) + } + }); + } + }); - // Run at most 20 tasks in parallel - let mut replay_tasks = futures::stream::iter(replay_tasks).buffer_unordered(20); - while let Some((table_id, ret)) = replay_tasks.next().await { - if let Some(Err(e)) = ret { + for result in results.into_iter().flatten() { + if let (table_id, Some(Err(e))) = result { // If occur error, mark this table as failed and store the cause. failed_tables.insert(table_id, e); } diff --git a/src/wal/Cargo.toml b/src/wal/Cargo.toml index 30a5b00461..1464016453 100644 --- a/src/wal/Cargo.toml +++ b/src/wal/Cargo.toml @@ -48,6 +48,7 @@ required-features = ["wal-message-queue", "wal-table-kv", "wal-rocksdb", "wal-lo [dependencies] anyhow = { workspace = true } +async-scoped = { version = "0.9.0", features = ["use-tokio"] } async-trait = { workspace = true } bytes_ext = { workspace = true } chrono = { workspace = true } @@ -64,6 +65,7 @@ memmap2 = { version = "0.9.4", optional = true } message_queue = { workspace = true, optional = true } prometheus = { workspace = true } prost = { workspace = true } +rayon = "1.10.0" runtime = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/src/wal/src/local_storage_impl/record_encoding.rs b/src/wal/src/local_storage_impl/record_encoding.rs index 6e8011c2e1..e91d3f5ab2 100644 --- a/src/wal/src/local_storage_impl/record_encoding.rs +++ b/src/wal/src/local_storage_impl/record_encoding.rs @@ -63,7 +63,7 @@ define_result!(Error); /// +---------+--------+--------+------------+--------------+--------------+-------+ /// ``` #[derive(Debug)] -pub struct Record<'a> { +pub struct Record { /// The version number of the record. pub version: u8, @@ -83,11 +83,11 @@ pub struct Record<'a> { pub value_length: u32, /// Common log value. - pub value: &'a [u8], + pub value: Vec, } -impl<'a> Record<'a> { - pub fn new(table_id: u64, sequence_num: u64, value: &'a [u8]) -> Result { +impl Record { + pub fn new(table_id: u64, sequence_num: u64, value: &[u8]) -> Result { let mut record = Record { version: NEWEST_RECORD_ENCODING_VERSION, crc: 0, @@ -95,7 +95,7 @@ impl<'a> Record<'a> { table_id, sequence_num, value_length: value.len() as u32, - value, + value: value.to_vec(), }; // Calculate CRC @@ -128,7 +128,7 @@ impl RecordEncoding { } } -impl Encoder> for RecordEncoding { +impl Encoder for RecordEncoding { type Error = Error; fn encode(&self, buf: &mut B, record: &Record) -> Result<()> { @@ -147,7 +147,7 @@ impl Encoder> for RecordEncoding { buf.try_put_u64(record.table_id).context(Encoding)?; buf.try_put_u64(record.sequence_num).context(Encoding)?; buf.try_put_u32(record.value_length).context(Encoding)?; - buf.try_put(record.value).context(Encoding)?; + buf.try_put(record.value.as_slice()).context(Encoding)?; Ok(()) } @@ -222,7 +222,7 @@ impl RecordEncoding { let value_length = buf.try_get_u32().context(Decoding)?; // Read value - let value = &buf[0..value_length as usize]; + let value = buf[0..value_length as usize].to_vec(); buf.advance(value_length as usize); Ok(Record { diff --git a/src/wal/src/local_storage_impl/segment.rs b/src/wal/src/local_storage_impl/segment.rs index 02bd1d13e0..b66701b209 100644 --- a/src/wal/src/local_storage_impl/segment.rs +++ b/src/wal/src/local_storage_impl/segment.rs @@ -33,7 +33,7 @@ use common_types::{table::TableId, SequenceNumber, MAX_SEQUENCE_NUMBER, MIN_SEQU use generic_error::{BoxError, GenericError}; use macros::define_result; use memmap2::{MmapMut, MmapOptions}; -use runtime::Runtime; +use runtime::{JoinHandle, Runtime}; use snafu::{ensure, Backtrace, ResultExt, Snafu}; use crate::{ @@ -832,6 +832,7 @@ impl Region { Some(req.location.table_id), start, end, + self.runtime.clone(), )?; Ok(BatchLogIteratorAdapter::new_with_sync( @@ -849,6 +850,7 @@ impl Region { None, MIN_SEQUENCE_NUMBER, MAX_SEQUENCE_NUMBER, + self.runtime.clone(), )?; Ok(BatchLogIteratorAdapter::new_with_sync( Box::new(iter), @@ -1006,19 +1008,37 @@ impl RegionManager { } } +fn decode_segment_content( + segment_content: &[u8], + record_positions: &[Position], + record_encoding: &RecordEncoding, +) -> Result> { + let mut records = Vec::with_capacity(record_positions.len()); + + for pos in record_positions { + // Extract the record data from the segment content + let record_data = &segment_content[pos.start..pos.end]; + + // Decode the record + let record = record_encoding + .decode(record_data) + .box_err() + .context(InvalidRecord)?; + records.push(record); + } + Ok(records) +} + #[derive(Debug)] struct SegmentLogIterator { /// Encoding method for common log. log_encoding: CommonLogEncoding, /// Encoding method for records. - record_encoding: RecordEncoding, - - /// Raw content of the segment. - segment_content: Vec, + _record_encoding: RecordEncoding, - /// Positions of records within the segment content. - record_positions: Vec, + /// Decoded log records in the segment. + records: Vec, /// Optional identifier for the table, which is used to filter logs. table_id: Option, @@ -1040,27 +1060,19 @@ struct SegmentLogIterator { } impl SegmentLogIterator { - pub fn new( + pub fn new_with_records( log_encoding: CommonLogEncoding, record_encoding: RecordEncoding, - segment: Arc>, - segment_manager: Arc, + records: Vec, + table_ranges: HashMap, table_id: Option, start: SequenceNumber, end: SequenceNumber, ) -> Result { - let mut guard = segment.lock().unwrap(); - // Open the segment if it is not open - segment_manager.open_segment(&mut guard, segment.clone())?; - let segment_content = guard.read(0, guard.current_size)?; - let record_positions = guard.record_position.clone(); - let table_ranges = guard.table_ranges.clone(); - Ok(Self { log_encoding, - record_encoding, - segment_content, - record_positions, + _record_encoding: record_encoding, + records, table_id, table_ranges, start, @@ -1076,24 +1088,14 @@ impl SegmentLogIterator { } loop { - // Get the next record position - let Some(pos) = self.record_positions.get(self.current_record_idx) else { + // Get the next record + let Some(record) = self.records.get(self.current_record_idx) else { self.no_more_data = true; return Ok(None); }; self.current_record_idx += 1; - // Extract the record data from the segment content - let record_data = &self.segment_content[pos.start..pos.end]; - - // Decode the record - let record = self - .record_encoding - .decode(record_data) - .box_err() - .context(InvalidRecord)?; - // Filter by sequence number if record.sequence_num < self.start { continue; @@ -1122,7 +1124,7 @@ impl SegmentLogIterator { // Decode the value let value = self .log_encoding - .decode_value(record.value) + .decode_value(&record.value) .box_err() .context(InvalidRecord)?; @@ -1150,6 +1152,9 @@ pub struct MultiSegmentLogIterator { /// Current segment iterator. current_iterator: Option, + /// Future iterator for preloading the next segment. + next_segment_iterator: Option>>, + /// Encoding method for common log. log_encoding: CommonLogEncoding, @@ -1167,6 +1172,9 @@ pub struct MultiSegmentLogIterator { /// The raw payload data of the current record. current_payload: Vec, + + /// Runtime for preloading segments + runtime: Arc, } impl MultiSegmentLogIterator { @@ -1177,6 +1185,7 @@ impl MultiSegmentLogIterator { table_id: Option, start: SequenceNumber, end: SequenceNumber, + runtime: Arc, ) -> Result { let relevant_segments = segment_manager.get_relevant_segments(table_id, start, end)?; @@ -1185,12 +1194,14 @@ impl MultiSegmentLogIterator { segments: relevant_segments, current_segment_idx: 0, current_iterator: None, + next_segment_iterator: None, log_encoding, record_encoding, table_id, start, end, current_payload: Vec::new(), + runtime, }; // Load the first segment iterator @@ -1199,25 +1210,88 @@ impl MultiSegmentLogIterator { Ok(iter) } + fn preload_next_segment(&mut self) { + assert!(self.next_segment_iterator.is_none()); + if self.current_segment_idx >= self.segments.len() { + return; + } + + let next_segment_idx = self.current_segment_idx; + let segment = self.segments[next_segment_idx].clone(); + let segment_manager = self.segment_manager.clone(); + let log_encoding = self.log_encoding.clone(); + let record_encoding = self.record_encoding.clone(); + let table_id = self.table_id; + let start = self.start; + let end = self.end; + + // Spawn an async task to preload the next SegmentLogIterator + let handle = self.runtime.spawn(async move { + let mut guard = segment.lock().unwrap(); + // Open the segment if it is not open + segment_manager.open_segment(&mut guard, segment.clone())?; + let segment_content = guard.read(0, guard.current_size)?; + let table_ranges = guard.table_ranges.clone(); + let records = + decode_segment_content(&segment_content, &guard.record_position, &record_encoding)?; + let iterator = SegmentLogIterator::new_with_records( + log_encoding, + record_encoding, + records, + table_ranges, + table_id, + start, + end, + )?; + Ok(iterator) + }); + + self.next_segment_iterator = Some(handle); + } + fn load_next_segment_iterator(&mut self) -> Result { if self.current_segment_idx >= self.segments.len() { self.current_iterator = None; return Ok(false); } - let segment = self.segments[self.current_segment_idx].clone(); - let iterator = SegmentLogIterator::new( - self.log_encoding.clone(), - self.record_encoding.clone(), - segment, - self.segment_manager.clone(), - self.table_id, - self.start, - self.end, - )?; + if let Some(handle) = self.next_segment_iterator.take() { + // Wait for the future to complete + let iterator = self + .runtime + .block_on(handle) + .map_err(anyhow::Error::new) + .context(Internal)??; + self.current_iterator = Some(iterator); + self.current_segment_idx += 1; + } else { + // Preload was not set, load synchronously + let segment = self.segments[self.current_segment_idx].clone(); + let mut guard = segment.lock().unwrap(); + self.segment_manager + .open_segment(&mut guard, segment.clone())?; + let segment_content = guard.read(0, guard.current_size)?; + let table_ranges = guard.table_ranges.clone(); + let records = decode_segment_content( + &segment_content, + &guard.record_position, + &self.record_encoding, + )?; + let iterator = SegmentLogIterator::new_with_records( + self.log_encoding.clone(), + self.record_encoding.clone(), + records, + table_ranges, + self.table_id, + self.start, + self.end, + )?; + self.current_iterator = Some(iterator); + self.current_segment_idx += 1; + } - self.current_iterator = Some(iterator); - self.current_segment_idx += 1; + // Preload the next segment + self.preload_next_segment(); Ok(true) } diff --git a/src/wal/src/manager.rs b/src/wal/src/manager.rs index 9c4a960b51..fcd017dc25 100644 --- a/src/wal/src/manager.rs +++ b/src/wal/src/manager.rs @@ -27,6 +27,7 @@ use common_types::{ }; pub use error::*; use generic_error::BoxError; +use rayon::{iter::ParallelIterator, prelude::IntoParallelIterator}; use runtime::Runtime; use snafu::ResultExt; @@ -428,13 +429,29 @@ impl BatchLogIteratorAdapter { let batch_size = self.batch_size; let (log_entries, iter_opt) = runtime .spawn_blocking(move || { - while buffer.len() < batch_size { + let mut raw_entries = Vec::new(); + + while raw_entries.len() < batch_size { if let Some(raw_log_entry) = iter.next_log_entry()? { if !filter(raw_log_entry.table_id) { continue; } - let mut raw_payload = raw_log_entry.payload; + raw_entries.push(LogEntry { + table_id: raw_log_entry.table_id, + sequence: raw_log_entry.sequence, + payload: raw_log_entry.payload.to_vec(), + }); + } else { + break; + } + } + + // Decoding is time-consuming, so we do it in parallel. + let result: Result> = raw_entries + .into_par_iter() + .map(|raw_log_entry| { + let mut raw_payload = raw_log_entry.payload.as_slice(); let ctx = PayloadDecodeContext { table_id: raw_log_entry.table_id, }; @@ -442,18 +459,20 @@ impl BatchLogIteratorAdapter { .decode(&ctx, &mut raw_payload) .box_err() .context(error::Decoding)?; - let log_entry = LogEntry { + Ok(LogEntry { table_id: raw_log_entry.table_id, sequence: raw_log_entry.sequence, payload, - }; - buffer.push_back(log_entry); - } else { - return Ok((buffer, None)); - } - } + }) + }) + .collect(); - Ok((buffer, Some(iter))) + let log_entries = result?; + if log_entries.len() < batch_size { + Ok((log_entries, None)) + } else { + Ok((log_entries, Some(iter))) + } }) .await .context(RuntimeExec)??;