Skip to content

Commit

Permalink
simpliy mem pool
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Oct 30, 2024
1 parent 8297f58 commit b652802
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 77 deletions.
54 changes: 2 additions & 52 deletions datafusion/execution/src/memory_pool/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ impl MemoryPool for UnboundedMemoryPool {
#[derive(Debug)]
pub struct GreedyMemoryPool {
pool_size: usize,
// Pool size limit for each consumer, if one of the consumer exceeds the limit, error is returned
pool_size_per_consumer: HashMap<String, usize>,
used: AtomicUsize,
// Memory usage for each consumer, used to check aginst `pool_size_per_consumer`
used_per_consumer: RwLock<HashMap<String, AtomicUsize>>,
}

impl GreedyMemoryPool {
Expand All @@ -71,67 +67,21 @@ impl GreedyMemoryPool {
debug!("Created new GreedyMemoryPool(pool_size={pool_size})");
Self {
pool_size,
pool_size_per_consumer: Default::default(),
used: AtomicUsize::new(0),
used_per_consumer: RwLock::new(HashMap::new()),
}
}

pub fn with_pool_size_per_consumer(
mut self,
pool_size_per_consumer: HashMap<String, usize>,
) -> Self {
self.pool_size_per_consumer = pool_size_per_consumer;
self
}
}

impl MemoryPool for GreedyMemoryPool {
fn grow(&self, reservation: &MemoryReservation, additional: usize) {
let consumer_name = reservation.consumer().name();
fn grow(&self, _reservation: &MemoryReservation, additional: usize) {
self.used.fetch_add(additional, Ordering::Relaxed);

let mut used_per_consumer = self.used_per_consumer.write();
let consumer_usage = used_per_consumer
.entry(consumer_name.to_string())
.or_insert_with(|| AtomicUsize::new(0));
consumer_usage.fetch_add(additional, Ordering::Relaxed);
}

fn shrink(&self, reservation: &MemoryReservation, shrink: usize) {
let consumer_name = reservation.consumer().name();

fn shrink(&self, _reservation: &MemoryReservation, shrink: usize) {
self.used.fetch_sub(shrink, Ordering::Relaxed);

let mut used_per_consumer = self.used_per_consumer.write();
let consumer_usage = used_per_consumer
.entry(consumer_name.to_string())
.or_insert_with(|| AtomicUsize::new(0));
consumer_usage.fetch_sub(shrink, Ordering::Relaxed);
}

fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> {
let consumer_name = reservation.consumer().name();

if let Some(pool_size) = self.pool_size_per_consumer.get(consumer_name) {
let mut used_per_consumer = self.used_per_consumer.write();
let consumer_usage = used_per_consumer
.entry(consumer_name.to_string())
.or_insert_with(|| AtomicUsize::new(0));
consumer_usage
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |used| {
let new_used = used + additional;
(new_used <= *pool_size).then_some(new_used)
})
.map_err(|used| {
insufficient_capacity_err(
reservation,
additional,
pool_size.saturating_sub(used),
)
})?;
}

self.used
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |used| {
let new_used = used + additional;
Expand Down
17 changes: 0 additions & 17 deletions datafusion/execution/src/runtime_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,23 +233,6 @@ impl RuntimeEnvBuilder {
)))
}

/// Set memory limit per consumer, if not set, by default is the same as the total pool size
/// For example, if pool size is 4000, repartition is 3000. Total pool size: 4000,
/// RepartitionExec pool size: 3000, SortPreservingMergeExec pool size: 4000
pub fn with_memory_limit_per_consumer(
self,
max_memory: usize,
memory_fraction: f64,
pool_size_per_consumer: HashMap<String, usize>,
) -> Self {
let pool_size = (max_memory as f64 * memory_fraction) as usize;
self.with_memory_pool(Arc::new(TrackConsumersPool::new(
GreedyMemoryPool::new(pool_size)
.with_pool_size_per_consumer(pool_size_per_consumer),
NonZeroUsize::new(5).unwrap(),
)))
}

/// Use the specified path to create any needed temporary files
pub fn with_temp_file_path(self, path: impl Into<PathBuf>) -> Self {
self.with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()]))
Expand Down
9 changes: 1 addition & 8 deletions datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,18 +355,11 @@ mod tests {
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

use futures::{FutureExt, Stream, StreamExt};
use hashbrown::HashMap;
use tokio::time::timeout;

fn generate_task_ctx_for_round_robin_tie_breaker() -> Result<Arc<TaskContext>> {
let mut pool_per_consumer = HashMap::new();
// Bytes from 660_000 to 30_000_000 (or even more) are all valid limits
pool_per_consumer.insert("RepartitionExec[0]".to_string(), 10_000_000);
pool_per_consumer.insert("RepartitionExec[1]".to_string(), 10_000_000);

let runtime = RuntimeEnvBuilder::new()
// Random large number for total mem limit, we only care about RepartitionExec only
.with_memory_limit_per_consumer(2_000_000_000, 1.0, pool_per_consumer)
.with_memory_limit(20_000_000, 1.0)
.build_arc()?;
let config = SessionConfig::new();
let task_ctx = TaskContext::default()
Expand Down

0 comments on commit b652802

Please sign in to comment.