diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index 0a3d70e29c7a7..cfea29fb63303 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -236,8 +236,8 @@ impl HashAggExecutor { self.mem_context.add(memory_usage_diff); } - // generate output data chunks - let mut result = groups.into_iter(); + // Don't use `into_iter` here, it may cause memory leak. + let mut result = groups.iter_mut(); let cardinality = self.chunk_size; loop { let mut group_builders: Vec<_> = self @@ -259,9 +259,9 @@ impl HashAggExecutor { array_len += 1; key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?; states - .into_iter() + .iter_mut() .zip_eq_fast(&mut agg_builders) - .try_for_each(|(mut aggregator, builder)| aggregator.output(builder))?; + .try_for_each(|(aggregator, builder)| aggregator.output(builder))?; } if !has_next { break; // exit loop @@ -281,6 +281,11 @@ impl HashAggExecutor { #[cfg(test)] mod tests { + use std::alloc::{AllocError, Allocator, Global, Layout}; + use std::ptr::NonNull; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use prometheus::IntGauge; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::test_prelude::DataChunkTestExt; @@ -296,9 +301,11 @@ mod tests { #[tokio::test] async fn execute_int32_grouped() { - let src_exec = Box::new(MockExecutor::with_chunk( - DataChunk::from_pretty( - "i i i + let parent_mem = MemoryContext::root(IntGauge::new("root_memory_usage", " ").unwrap()); + { + let src_exec = Box::new(MockExecutor::with_chunk( + DataChunk::from_pretty( + "i i i 0 1 1 1 1 1 0 0 1 @@ -307,67 +314,74 @@ mod tests { 0 0 2 1 1 3 0 1 2", - ), - Schema::new(vec![ - Field::unnamed(DataType::Int32), - Field::unnamed(DataType::Int32), - Field::unnamed(DataType::Int64), - ]), - )); - - let agg_call = AggCall { - r#type: Type::Sum as i32, - args: vec![InputRef { - index: 2, - r#type: Some(PbDataType { - type_name: TypeName::Int32 as i32, + ), + Schema::new(vec![ + Field::unnamed(DataType::Int32), + Field::unnamed(DataType::Int32), + Field::unnamed(DataType::Int64), + ]), + )); + + let agg_call = AggCall { + r#type: Type::Sum as i32, + args: vec![InputRef { + index: 2, + r#type: Some(PbDataType { + type_name: TypeName::Int32 as i32, + ..Default::default() + }), + }], + return_type: Some(PbDataType { + type_name: TypeName::Int64 as i32, ..Default::default() }), - }], - return_type: Some(PbDataType { - type_name: TypeName::Int64 as i32, - ..Default::default() - }), - distinct: false, - order_by: vec![], - filter: None, - }; - - let agg_prost = HashAggNode { - group_key: vec![0, 1], - agg_calls: vec![agg_call], - }; - - let mem_context = MemoryContext::root(IntGauge::new("memory_usage", " ").unwrap()); - let actual_exec = HashAggExecutorBuilder::deserialize( - &agg_prost, - src_exec, - TaskId::default(), - "HashAggExecutor".to_string(), - CHUNK_SIZE, - mem_context.clone(), - ) - .unwrap(); - - // TODO: currently the order is fixed unless the hasher is changed - let expect_exec = Box::new(MockExecutor::with_chunk( - DataChunk::from_pretty( - "i i I + distinct: false, + order_by: vec![], + filter: None, + }; + + let agg_prost = HashAggNode { + group_key: vec![0, 1], + agg_calls: vec![agg_call], + }; + + let mem_context = MemoryContext::new( + Some(parent_mem.clone()), + IntGauge::new("memory_usage", " ").unwrap(), + ); + let actual_exec = HashAggExecutorBuilder::deserialize( + &agg_prost, + src_exec, + TaskId::default(), + "HashAggExecutor".to_string(), + CHUNK_SIZE, + mem_context.clone(), + ) + .unwrap(); + + // TODO: currently the order is fixed unless the hasher is changed + let expect_exec = Box::new(MockExecutor::with_chunk( + DataChunk::from_pretty( + "i i I 1 0 1 0 0 3 0 1 3 1 1 6", - ), - Schema::new(vec![ - Field::unnamed(DataType::Int32), - Field::unnamed(DataType::Int32), - Field::unnamed(DataType::Int64), - ]), - )); - diff_executor_output(actual_exec, expect_exec).await; - - // check estimated memory usage = 4 groups x state size - assert_eq!(mem_context.get_bytes_used() as usize, 4 * 72); + ), + Schema::new(vec![ + Field::unnamed(DataType::Int32), + Field::unnamed(DataType::Int32), + Field::unnamed(DataType::Int64), + ]), + )); + diff_executor_output(actual_exec, expect_exec).await; + + // check estimated memory usage = 4 groups x state size + assert_eq!(mem_context.get_bytes_used() as usize, 4 * 72); + } + + // Ensure that agg memory counter has been dropped. + assert_eq!(0, parent_mem.get_bytes_used()); } #[tokio::test] @@ -423,4 +437,61 @@ mod tests { ); diff_executor_output(actual_exec, Box::new(expect_exec)).await; } + + /// A test to verify that `HashMap` may leak memory counter when using `into_iter`. + #[test] + fn test_hashmap_into_iter_bug() { + let dropped: Arc = Arc::new(AtomicBool::new(false)); + + { + struct MyAllocInner { + drop_flag: Arc, + } + + #[derive(Clone)] + struct MyAlloc { + inner: Arc, + } + + impl Drop for MyAllocInner { + fn drop(&mut self) { + println!("MyAlloc freed."); + self.drop_flag.store(true, Ordering::SeqCst); + } + } + + unsafe impl Allocator for MyAlloc { + fn allocate( + &self, + layout: Layout, + ) -> std::result::Result, AllocError> { + let g = Global; + g.allocate(layout) + } + + unsafe fn deallocate(&self, ptr: NonNull, layout: Layout) { + let g = Global; + g.deallocate(ptr, layout) + } + } + + let mut map = hashbrown::HashMap::with_capacity_in( + 10, + MyAlloc { + inner: Arc::new(MyAllocInner { + drop_flag: dropped.clone(), + }), + }, + ); + for i in 0..10 { + map.entry(i).or_insert_with(|| "i".to_string()); + } + + for (k, v) in map { + println!("{}, {}", k, v); + } + } + + assert!(!dropped.load(Ordering::SeqCst)); + } } diff --git a/src/batch/src/executor/join/hash_join.rs b/src/batch/src/executor/join/hash_join.rs index f63d54c30bb85..ccf45fe920028 100644 --- a/src/batch/src/executor/join/hash_join.rs +++ b/src/batch/src/executor/join/hash_join.rs @@ -1844,6 +1844,7 @@ impl HashJoinExecutor { mod tests { use futures::StreamExt; use futures_async_stream::for_await; + use prometheus::IntGauge; use risingwave_common::array::{ArrayBuilderImpl, DataChunk}; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::error::Result; @@ -2050,6 +2051,7 @@ mod tests { left_child: BoxedExecutor, right_child: BoxedExecutor, shutdown_rx: Option>, + parent_mem_ctx: Option, ) -> BoxedExecutor { let join_type = self.join_type; @@ -2066,6 +2068,8 @@ mod tests { None }; + let mem_ctx = + MemoryContext::new(parent_mem_ctx, IntGauge::new("memory_usage", " ").unwrap()); Box::new(HashJoinExecutor::::new( join_type, output_indices, @@ -2078,7 +2082,7 @@ mod tests { "HashJoinExecutor".to_string(), chunk_size, shutdown_rx, - MemoryContext::none(), + mem_ctx, )) } @@ -2105,45 +2109,53 @@ mod tests { left_executor: BoxedExecutor, right_executor: BoxedExecutor, ) { - let join_executor = self.create_join_executor_with_chunk_size_and_executors( - has_non_equi_cond, - null_safe, - chunk_size, - left_executor, - right_executor, - None, - ); - - let mut data_chunk_merger = DataChunkMerger::new(self.output_data_types()).unwrap(); + let parent_mem_context = + MemoryContext::root(IntGauge::new("total_memory_usage", " ").unwrap()); + + { + let join_executor = self.create_join_executor_with_chunk_size_and_executors( + has_non_equi_cond, + null_safe, + chunk_size, + left_executor, + right_executor, + None, + Some(parent_mem_context.clone()), + ); + + let mut data_chunk_merger = DataChunkMerger::new(self.output_data_types()).unwrap(); + + let fields = &join_executor.schema().fields; + + if self.join_type.keep_all() { + assert_eq!(fields[1].data_type, DataType::Float32); + assert_eq!(fields[3].data_type, DataType::Float64); + } else if self.join_type.keep_left() { + assert_eq!(fields[1].data_type, DataType::Float32); + } else if self.join_type.keep_right() { + assert_eq!(fields[1].data_type, DataType::Float64) + } else { + unreachable!() + } - let fields = &join_executor.schema().fields; + let mut stream = join_executor.execute(); - if self.join_type.keep_all() { - assert_eq!(fields[1].data_type, DataType::Float32); - assert_eq!(fields[3].data_type, DataType::Float64); - } else if self.join_type.keep_left() { - assert_eq!(fields[1].data_type, DataType::Float32); - } else if self.join_type.keep_right() { - assert_eq!(fields[1].data_type, DataType::Float64) - } else { - unreachable!() - } + while let Some(data_chunk) = stream.next().await { + let data_chunk = data_chunk.unwrap(); + let data_chunk = data_chunk.compact(); + data_chunk_merger.append(&data_chunk).unwrap(); + } - let mut stream = join_executor.execute(); + let result_chunk = data_chunk_merger.finish().unwrap(); + println!("expected: {:?}", expected); + println!("result: {:?}", result_chunk); - while let Some(data_chunk) = stream.next().await { - let data_chunk = data_chunk.unwrap(); - let data_chunk = data_chunk.compact(); - data_chunk_merger.append(&data_chunk).unwrap(); + // TODO: Replace this with unsorted comparison + // assert_eq!(expected, result_chunk); + assert!(is_data_chunk_eq(&expected, &result_chunk)); } - let result_chunk = data_chunk_merger.finish().unwrap(); - println!("expected: {:?}", expected); - println!("result: {:?}", result_chunk); - - // TODO: Replace this with unsorted comparison - // assert_eq!(expected, result_chunk); - assert!(is_data_chunk_eq(&expected, &result_chunk)); + assert_eq!(0, parent_mem_context.get_bytes_used()); } async fn do_test_shutdown(&self, has_non_equi_cond: bool) { @@ -2158,6 +2170,7 @@ mod tests { left_executor, right_executor, Some(shutdown_rx), + None, ); shutdown_tx.send(ShutdownMsg::Cancel).unwrap(); #[for_await] @@ -2177,6 +2190,7 @@ mod tests { left_executor, right_executor, Some(shutdown_rx), + None, ); shutdown_tx .send(ShutdownMsg::Abort("Test".to_string())) diff --git a/src/common/src/memory/mem_context.rs b/src/common/src/memory/mem_context.rs index ba708f649dd99..d9b5eaa29cace 100644 --- a/src/common/src/memory/mem_context.rs +++ b/src/common/src/memory/mem_context.rs @@ -31,6 +31,7 @@ pub struct MemoryContext { inner: Option>, } +#[derive(Debug)] pub enum MemCounter { /// Used when the add/sub operation don't have much conflicts. Atomic(IntGauge), @@ -68,11 +69,9 @@ impl From for MemCounter { impl MemoryContext { pub fn new(parent: Option, counter: impl Into) -> Self { + let c = counter.into(); Self { - inner: Some(Arc::new(MemoryContextInner { - counter: counter.into(), - parent, - })), + inner: Some(Arc::new(MemoryContextInner { counter: c, parent })), } } diff --git a/src/common/src/metrics.rs b/src/common/src/metrics.rs index bddd4447a7098..c6e3f11572366 100644 --- a/src/common/src/metrics.rs +++ b/src/common/src/metrics.rs @@ -19,6 +19,7 @@ use tracing::Subscriber; use tracing_subscriber::layer::Context; use tracing_subscriber::registry::LookupSpan; use tracing_subscriber::Layer; +#[derive(Debug)] pub struct TrAdderAtomic(TrAdder); impl Atomic for TrAdderAtomic {