diff --git a/datafusion/core/src/physical_plan/metrics/value.rs b/datafusion/core/src/physical_plan/metrics/value.rs
index 5360f272c314..5c3aeb4dcdca 100644
--- a/datafusion/core/src/physical_plan/metrics/value.rs
+++ b/datafusion/core/src/physical_plan/metrics/value.rs
@@ -122,6 +122,13 @@ impl Gauge {
         self.value.fetch_add(n, Ordering::Relaxed);
     }
 
+    /// Subtract `n` from the metric's value
+    pub fn sub(&self, n: usize) {
+        // relaxed ordering for operations on `value` poses no issues
+        // we're purely using atomic ops with no associated memory ops
+        self.value.fetch_sub(n, Ordering::Relaxed);
+    }
+
     /// Set the metric's value to `n` and return the previous value
     pub fn set(&self, n: usize) -> usize {
         // relaxed ordering for operations on `value` poses no issues
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index fb2ad091900d..fb9bb10a38ec 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -124,6 +124,21 @@ impl ExternalSorter {
             // calls to `timer.done()` below.
             let _timer = tracking_metrics.elapsed_compute().timer();
             let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?;
+            // The resulting batch might be smaller than the input batch if there
+            // is a propagated limit.
+
+            if self.fetch.is_some() {
+                let new_size = batch_byte_size(&partial.sorted_batch);
+                let size_delta = size.checked_sub(new_size).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "The size of the sorted batch is larger than the size of the input batch: {} > {}",
+                        new_size,
+                        size
+                    ))
+                })?;
+                self.shrink(size_delta);
+                self.metrics.mem_used().sub(size_delta);
+            }
             in_mem_batches.push(partial);
         }
         Ok(())
@@ -1062,6 +1077,65 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_sort_fetch_memory_calculation() -> Result<()> {
+        // This test mirrors the batch size from the example above.
+        let avg_batch_size = 5336;
+        let partitions = 4;
+
+        // Test cases as (fetch, expect_spillage) pairs
+        let test_options = vec![
+            // Since we don't have a limit (and the memory limit is less than the total
+            // size of all the batches we are processing), we expect it to spill.
+            (None, true),
+            // When we have a limit, however, the buffered size of batches should fit in memory
+            // since it is much lower than the total size of the input batches.
+            (Some(1), false),
+        ];
+
+        for (fetch, expect_spillage) in test_options {
+            let config = RuntimeConfig::new()
+                .with_memory_limit(avg_batch_size * (partitions - 1), 1.0);
+            let runtime = Arc::new(RuntimeEnv::new(config)?);
+            let session_ctx =
+                SessionContext::with_config_rt(SessionConfig::new(), runtime);
+
+            let csv = test::scan_partitioned_csv(partitions)?;
+            let schema = csv.schema();
+
+            let sort_exec = Arc::new(SortExec::try_new(
+                vec![
+                    // c1 string column
+                    PhysicalSortExpr {
+                        expr: col("c1", &schema)?,
+                        options: SortOptions::default(),
+                    },
+                    // c2 uint32 column
+                    PhysicalSortExpr {
+                        expr: col("c2", &schema)?,
+                        options: SortOptions::default(),
+                    },
+                    // c7 uint8 column
+                    PhysicalSortExpr {
+                        expr: col("c7", &schema)?,
+                        options: SortOptions::default(),
+                    },
+                ],
+                Arc::new(CoalescePartitionsExec::new(csv)),
+                fetch,
+            )?);
+
+            let task_ctx = session_ctx.task_ctx();
+            let result = collect(sort_exec.clone(), task_ctx).await?;
+            assert_eq!(result.len(), 1);
+
+            let metrics = sort_exec.metrics().unwrap();
+            let did_it_spill = metrics.spill_count().unwrap() > 0;
+            assert_eq!(did_it_spill, expect_spillage, "with fetch: {:?}", fetch);
+        }
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_sort_metadata() -> Result<()> {
         let session_ctx = SessionContext::new();
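
The core idea of the sort.rs change is to release only the difference between the pre-sort and post-sort batch sizes, using `checked_sub` so a sorted batch that somehow grew surfaces as an internal error instead of underflowing the reservation. Below is a minimal, self-contained sketch of that accounting pattern; the simplified `Gauge` stand-in and the `release_after_limit` helper are illustrative only and are not DataFusion APIs.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Illustrative stand-in for the `mem_used` gauge touched by the patch.
struct Gauge {
    value: AtomicUsize,
}

impl Gauge {
    fn new() -> Self {
        Self { value: AtomicUsize::new(0) }
    }
    fn add(&self, n: usize) {
        self.value.fetch_add(n, Ordering::Relaxed);
    }
    fn sub(&self, n: usize) {
        self.value.fetch_sub(n, Ordering::Relaxed);
    }
    fn value(&self) -> usize {
        self.value.load(Ordering::Relaxed)
    }
}

/// Hypothetical helper mirroring the accounting step in `insert_batch`:
/// release the bytes freed when a limit shrank the sorted batch.
fn release_after_limit(
    gauge: &Gauge,
    input_size: usize,
    sorted_size: usize,
) -> Result<(), String> {
    // `checked_sub` guards against underflow: a sorted batch larger than its
    // input indicates a bookkeeping bug, not memory that can be freed.
    let size_delta = input_size.checked_sub(sorted_size).ok_or_else(|| {
        format!(
            "The size of the sorted batch is larger than the size of the input batch: {} > {}",
            sorted_size, input_size
        )
    })?;
    gauge.sub(size_delta);
    Ok(())
}

fn main() {
    let mem_used = Gauge::new();
    mem_used.add(5336); // account for the unsorted input batch
    // A fetch of 1 keeps only a single row, so most of the reservation is released.
    release_after_limit(&mem_used, 5336, 120).expect("sorted batch should not exceed input");
    assert_eq!(mem_used.value(), 120);
    println!("memory still accounted for: {} bytes", mem_used.value());
}
```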