From 74591c1a5eaa88b798f2f8be85de117104f49cdd Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Wed, 14 Feb 2024 11:31:50 -0800
Subject: [PATCH] [BUG] Fix bug with merge tasks that allows for tasks larger
 than max size allowed (#1882)

---
 src/daft-scan/src/scan_task_iters.rs | 119 +++++++++++++--------------
 tests/io/test_merge_scan_tasks.py    |   8 +-
 2 files changed, 61 insertions(+), 66 deletions(-)

diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs
index f8410eb595..fe3d470a2f 100644
--- a/src/daft-scan/src/scan_task_iters.rs
+++ b/src/daft-scan/src/scan_task_iters.rs
@@ -45,74 +45,69 @@ struct MergeByFileSize {
     accumulator: Option<ScanTaskRef>,
 }
 
+impl MergeByFileSize {
+    fn accumulator_ready(&self) -> bool {
+        if let Some(acc) = &self.accumulator && let Some(acc_bytes) = acc.size_bytes() && acc_bytes >= self.min_size_bytes {
+            true
+        } else {
+            false
+        }
+    }
+
+    fn can_merge(&self, other: &ScanTask) -> bool {
+        let accumulator = self
+            .accumulator
+            .as_ref()
+            .expect("accumulator should be populated");
+        let child_matches_accumulator = other.partition_spec() == accumulator.partition_spec()
+            && other.file_format_config == accumulator.file_format_config
+            && other.schema == accumulator.schema
+            && other.storage_config == accumulator.storage_config
+            && other.pushdowns == accumulator.pushdowns;
+
+        let sum_smaller_than_max_size_bytes = if let Some(child_bytes) = other.size_bytes()
+            && let Some(accumulator_bytes) = accumulator.size_bytes() {child_bytes + accumulator_bytes <= self.max_size_bytes} else {false};
+
+        child_matches_accumulator && sum_smaller_than_max_size_bytes
+    }
+}
+
 impl Iterator for MergeByFileSize {
     type Item = DaftResult<ScanTaskRef>;
 
     fn next(&mut self) -> Option<Self::Item> {
         loop {
-            // Grabs the accumulator, leaving a `None` in its place
-            let accumulator = self.accumulator.take();
-
-            match (self.iter.next(), accumulator) {
-                // When no accumulator exists, trivially place the ScanTask into the accumulator
-                (Some(Ok(child_item)), None) => {
-                    self.accumulator = Some(child_item);
-                    continue;
-                }
-                // When an accumulator exists, attempt a merge and yield the result
-                (Some(Ok(child_item)), Some(accumulator)) => {
-                    // Whether or not the accumulator and the current item should be merged
-                    let should_merge = {
-                        let child_matches_accumulator = child_item.partition_spec()
-                            == accumulator.partition_spec()
-                            && child_item.file_format_config == accumulator.file_format_config
-                            && child_item.schema == accumulator.schema
-                            && child_item.storage_config == accumulator.storage_config
-                            && child_item.pushdowns == accumulator.pushdowns;
-                        let smaller_than_max_size_bytes = matches!(
-                            (child_item.size_bytes(), accumulator.size_bytes()),
-                            (Some(child_item_size), Some(buffered_item_size)) if child_item_size + buffered_item_size <= self.max_size_bytes
-                        );
-                        child_matches_accumulator && smaller_than_max_size_bytes
-                    };
-
-                    if should_merge {
-                        let merged_result = Some(Arc::new(
-                            ScanTask::merge(accumulator.as_ref(), child_item.as_ref())
-                                .expect("ScanTasks should be mergeable in MergeByFileSize"),
-                        ));
-
-                        // Whether or not we should immediately yield the merged result, or keep accumulating
-                        let should_yield = matches!(
-                            (child_item.size_bytes(), accumulator.size_bytes()),
-                            (Some(child_item_size), Some(buffered_item_size)) if child_item_size + buffered_item_size >= self.min_size_bytes
-                        );
-
-                        // Either yield eagerly, or keep looping with a merged accumulator
-                        if should_yield {
-                            return Ok(merged_result).transpose();
-                        } else {
-                            self.accumulator = merged_result;
-                            continue;
-                        }
-                    } else {
-                        self.accumulator = Some(child_item);
-                        return Some(Ok(accumulator));
-                    }
-                }
-                // Bubble up errors from child iterator, making sure to replace the accumulator which we moved
-                (Some(Err(e)), acc) => {
-                    self.accumulator = acc;
-                    return Some(Err(e));
-                }
-                // Iterator ran out of elements: ensure that we flush the last buffered ScanTask
-                (None, Some(last_scan_task)) => {
-                    return Some(Ok(last_scan_task));
-                }
-                (None, None) => {
-                    return None;
-                }
+            if self.accumulator.is_none() {
+                self.accumulator = match self.iter.next() {
+                    Some(Ok(item)) => Some(item),
+                    e @ Some(Err(_)) => return e,
+                    None => return None,
+                };
+            }
+
+            if self.accumulator_ready() {
+                return self.accumulator.take().map(Ok);
             }
+
+            let next_item = match self.iter.next() {
+                Some(Ok(item)) => item,
+                e @ Some(Err(_)) => return e,
+                None => return self.accumulator.take().map(Ok),
+            };
+
+            if next_item.size_bytes().is_none() || !self.can_merge(&next_item) {
+                return self.accumulator.replace(next_item).map(Ok);
+            }
+
+            self.accumulator = Some(Arc::new(
+                ScanTask::merge(
+                    self.accumulator
+                        .as_ref()
+                        .expect("accumulator should be populated"),
+                    next_item.as_ref(),
+                )
+                .expect("ScanTasks should be mergeable in MergeByFileSize"),
+            ));
         }
     }
 }
diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py
index c3fbd96dc3..75c426874b 100644
--- a/tests/io/test_merge_scan_tasks.py
+++ b/tests/io/test_merge_scan_tasks.py
@@ -44,20 +44,20 @@ def test_merge_scan_task_exceed_max(csv_files):
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
 def test_merge_scan_task_below_max(csv_files):
-    with override_merge_scan_tasks_configs(1, 20):
+    with override_merge_scan_tasks_configs(21, 22):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
-        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>20 bytes)"
+        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>22 bytes)"
 
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
 def test_merge_scan_task_above_min(csv_files):
-    with override_merge_scan_tasks_configs(0, 40):
+    with override_merge_scan_tasks_configs(19, 40):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
-        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>0 bytes)"
+        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>19 bytes)"
 
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
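
Reviewer note: the following is a minimal, standalone sketch (not Daft's actual
API) of the size-bounded merge policy the new iterator enforces: grow an
accumulator only while the combined size stays within max_size_bytes, and flush
it as soon as it reaches min_size_bytes. The function name merge_by_size and
the plain usize sizes are hypothetical stand-ins for ScanTasks with a known
size_bytes(); the real can_merge additionally requires matching partition spec,
file format config, schema, storage config, and pushdowns.

    fn merge_by_size(sizes: &[usize], min: usize, max: usize) -> Vec<usize> {
        let mut merged = Vec::new();
        let mut acc: Option<usize> = None;
        for &size in sizes {
            acc = match acc {
                // Nothing buffered yet: start a new accumulator.
                None => Some(size),
                // Accumulator already reached `min`: flush it before buffering more.
                Some(a) if a >= min => {
                    merged.push(a);
                    Some(size)
                }
                // Merging keeps the total within `max`: grow the accumulator.
                Some(a) if a + size <= max => Some(a + size),
                // Merging would exceed `max`: flush and start over with `size`.
                Some(a) => {
                    merged.push(a);
                    Some(size)
                }
            };
        }
        // Flush whatever is still buffered once the input runs out.
        if let Some(a) = acc {
            merged.push(a);
        }
        merged
    }

    fn main() {
        // Hypothetical sizes: three 10-byte tasks with min=21, max=22. The first
        // two merge (10 + 10 <= 22) but the third cannot (20 + 10 > 22), giving
        // the two-partition [(CSV1, CSV2), (CSV3)] shape that
        // test_merge_scan_task_below_max asserts.
        assert_eq!(merge_by_size(&[10, 10, 10], 21, 22), vec![20, 10]);
    }

Under this policy no flushed value ever exceeds max, because a merge is only
performed while the running total stays within it, which is the bound the
patch title promises.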