Prevent memory overflows (and spills) on sorts with a fixed limit (#3593)
isidentical authored Sep 24, 2022
1 parent 8bcc965 commit 696a0b5
Showing 2 changed files with 81 additions and 0 deletions.
7 changes: 7 additions & 0 deletions datafusion/core/src/physical_plan/metrics/value.rs
@@ -122,6 +122,13 @@ impl Gauge {
        self.value.fetch_add(n, Ordering::Relaxed);
    }

    /// Sub `n` from the metric's value
    pub fn sub(&self, n: usize) {
        // relaxed ordering for operations on `value` poses no issues
        // we're purely using atomic ops with no associated memory ops
        self.value.fetch_sub(n, Ordering::Relaxed);
    }

    /// Set the metric's value to `n` and return the previous value
    pub fn set(&self, n: usize) -> usize {
        // relaxed ordering for operations on `value` poses no issues
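For context, here is a minimal standalone sketch (a hypothetical `GaugeSketch`, not the DataFusion `Gauge` type) of the relaxed-atomic counter pattern the new `sub` method follows: the value is only a reported metric and never synchronizes other memory, so relaxed ordering is sufficient.

use std::sync::atomic::{AtomicUsize, Ordering};

// Sketch only: a stand-in for the real metrics gauge.
struct GaugeSketch {
    value: AtomicUsize,
}

impl GaugeSketch {
    fn add(&self, n: usize) {
        // Relaxed ordering is fine: the counter never guards other data.
        self.value.fetch_add(n, Ordering::Relaxed);
    }

    fn sub(&self, n: usize) {
        self.value.fetch_sub(n, Ordering::Relaxed);
    }
}

fn main() {
    let gauge = GaugeSketch { value: AtomicUsize::new(0) };
    gauge.add(100);
    gauge.sub(40);
    assert_eq!(gauge.value.load(Ordering::Relaxed), 60);
}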
74 changes: 74 additions & 0 deletions datafusion/core/src/physical_plan/sorts/sort.rs
@@ -124,6 +124,21 @@ impl ExternalSorter {
            // calls to `timer.done()` below.
            let _timer = tracking_metrics.elapsed_compute().timer();
            let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?;
            // The resulting batch might be smaller than the input batch if there
            // is a propagated limit.
            if self.fetch.is_some() {
                let new_size = batch_byte_size(&partial.sorted_batch);
                let size_delta = size.checked_sub(new_size).ok_or_else(|| {
                    DataFusionError::Internal(format!(
                        "The size of the sorted batch is larger than the size of the input batch: {} > {}",
                        new_size,
                        size
                    ))
                })?;
                self.shrink(size_delta);
                self.metrics.mem_used().sub(size_delta);
            }
            in_mem_batches.push(partial);
        }
        Ok(())
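The heart of the change above is the size bookkeeping when a fetch limit truncates the sorted batch. A minimal sketch of that accounting, using hypothetical names rather than the actual `ExternalSorter` API:

// Sketch: given the tracked input size and the (possibly smaller) sorted size,
// compute how many bytes can be handed back to the memory tracker.
fn bytes_to_release(input_size: usize, sorted_size: usize) -> Result<usize, String> {
    // `checked_sub` guards against the sorted batch somehow being larger,
    // which would indicate an accounting bug rather than a normal truncation.
    input_size.checked_sub(sorted_size).ok_or_else(|| {
        format!(
            "sorted batch ({sorted_size} bytes) is larger than the input batch ({input_size} bytes)"
        )
    })
}

fn main() {
    // With a fetch limit of 1 row, a ~5 KiB input might sort down to a few bytes.
    let delta = bytes_to_release(5336, 96).unwrap();
    assert_eq!(delta, 5240);
}

In the real hunk, that delta is what `self.shrink` releases from the reservation and what `mem_used().sub` subtracts from the gauge.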
@@ -1062,6 +1077,65 @@ mod tests {
        Ok(())
    }

    #[tokio::test]
    async fn test_sort_fetch_memory_calculation() -> Result<()> {
        // This test mirrors the batch size from the example above.
        let avg_batch_size = 5336;
        let partitions = 4;

        // A tuple of (fetch, expect_spillage)
        let test_options = vec![
            // Since we don't have a limit (and the memory is less than the total size of
            // all the batches we are processing), we expect it to spill.
            (None, true),
            // When we have a limit however, the buffered size of batches should fit in memory
            // since it is much lower than the total size of the input batches.
            (Some(1), false),
        ];

        for (fetch, expect_spillage) in test_options {
            let config = RuntimeConfig::new()
                .with_memory_limit(avg_batch_size * (partitions - 1), 1.0);
            let runtime = Arc::new(RuntimeEnv::new(config)?);
            let session_ctx =
                SessionContext::with_config_rt(SessionConfig::new(), runtime);

            let csv = test::scan_partitioned_csv(partitions)?;
            let schema = csv.schema();

            let sort_exec = Arc::new(SortExec::try_new(
                vec![
                    // c1 string column
                    PhysicalSortExpr {
                        expr: col("c1", &schema)?,
                        options: SortOptions::default(),
                    },
                    // c2 uint32 column
                    PhysicalSortExpr {
                        expr: col("c2", &schema)?,
                        options: SortOptions::default(),
                    },
                    // c7 uint8 column
                    PhysicalSortExpr {
                        expr: col("c7", &schema)?,
                        options: SortOptions::default(),
                    },
                ],
                Arc::new(CoalescePartitionsExec::new(csv)),
                fetch,
            )?);

            let task_ctx = session_ctx.task_ctx();
            let result = collect(sort_exec.clone(), task_ctx).await?;
            assert_eq!(result.len(), 1);

            let metrics = sort_exec.metrics().unwrap();
            let did_it_spill = metrics.spill_count().unwrap() > 0;
            assert_eq!(did_it_spill, expect_spillage, "with fetch: {:?}", fetch);
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_sort_metadata() -> Result<()> {
        let session_ctx = SessionContext::new();
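As a rough check of the memory limit the new test picks (assuming the ~5336-byte average batch size it states), a small sketch of the arithmetic:

fn main() {
    // Values stated in the test above; exact byte counts are assumptions here.
    let avg_batch_size = 5336usize;
    let partitions = 4usize;

    // The pool is sized to hold one batch fewer than the number of partitions...
    let limit = avg_batch_size * (partitions - 1); // 16008 bytes

    // ...so buffering all four unsorted batches must exceed it and spill,
    assert!(avg_batch_size * partitions > limit);

    // while with `fetch = Some(1)` each sorted batch shrinks to a single row and
    // the freed delta is returned via `Gauge::sub`, so everything fits in memory.
}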
