[BUG] Fix bug with merge tasks that allows for tasks larger than max …
samster25 authored Feb 14, 2024
1 parent cf77fd2 commit 74591c1
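
Summary (as read from the diff below): the MergeByFileSize iterator in scan_task_iters.rs is restructured so that its merge conditions live in two named helper predicates — can_merge, which refuses any merge whose combined size would exceed max_size_bytes (or whose sizes are unknown), and accumulator_ready, which yields the buffered ScanTask as soon as it alone reaches min_size_bytes. can_merge is now consulted before every merge and accumulator_ready before pulling the next item, where previously the minimum-size test ran only after a merge had already been performed. The tests in test_merge_scan_tasks.py are updated with min/max thresholds that exercise the new yield rule.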
Showing 2 changed files with 61 additions and 66 deletions.
119 changes: 57 additions & 62 deletions src/daft-scan/src/scan_task_iters.rs
@@ -45,74 +45,69 @@ struct MergeByFileSize {
     accumulator: Option<ScanTaskRef>,
 }
 
+impl MergeByFileSize {
+    fn accumulator_ready(&self) -> bool {
+        if let Some(acc) = &self.accumulator
+            && let Some(acc_bytes) = acc.size_bytes()
+            && acc_bytes >= self.min_size_bytes
+        {
+            true
+        } else {
+            false
+        }
+    }
+
+    fn can_merge(&self, other: &ScanTask) -> bool {
+        let accumulator = self
+            .accumulator
+            .as_ref()
+            .expect("accumulator should be populated");
+        let child_matches_accumulator = other.partition_spec() == accumulator.partition_spec()
+            && other.file_format_config == accumulator.file_format_config
+            && other.schema == accumulator.schema
+            && other.storage_config == accumulator.storage_config
+            && other.pushdowns == accumulator.pushdowns;
+
+        let sum_smaller_than_max_size_bytes = if let Some(child_bytes) = other.size_bytes()
+            && let Some(accumulator_bytes) = accumulator.size_bytes()
+        {
+            child_bytes + accumulator_bytes <= self.max_size_bytes
+        } else {
+            false
+        };
+
+        child_matches_accumulator && sum_smaller_than_max_size_bytes
+    }
+}
+
 impl Iterator for MergeByFileSize {
     type Item = DaftResult<ScanTaskRef>;
 
     fn next(&mut self) -> Option<Self::Item> {
         loop {
-            // Grabs the accumulator, leaving a `None` in its place
-            let accumulator = self.accumulator.take();
-
-            match (self.iter.next(), accumulator) {
-                // When no accumulator exists, trivially place the ScanTask into the accumulator
-                (Some(Ok(child_item)), None) => {
-                    self.accumulator = Some(child_item);
-                    continue;
-                }
-                // When an accumulator exists, attempt a merge and yield the result
-                (Some(Ok(child_item)), Some(accumulator)) => {
-                    // Whether or not the accumulator and the current item should be merged
-                    let should_merge = {
-                        let child_matches_accumulator = child_item.partition_spec()
-                            == accumulator.partition_spec()
-                            && child_item.file_format_config == accumulator.file_format_config
-                            && child_item.schema == accumulator.schema
-                            && child_item.storage_config == accumulator.storage_config
-                            && child_item.pushdowns == accumulator.pushdowns;
-                        let smaller_than_max_size_bytes = matches!(
-                            (child_item.size_bytes(), accumulator.size_bytes()),
-                            (Some(child_item_size), Some(buffered_item_size))
-                                if child_item_size + buffered_item_size <= self.max_size_bytes
-                        );
-                        child_matches_accumulator && smaller_than_max_size_bytes
-                    };
-
-                    if should_merge {
-                        let merged_result = Some(Arc::new(
-                            ScanTask::merge(accumulator.as_ref(), child_item.as_ref())
-                                .expect("ScanTasks should be mergeable in MergeByFileSize"),
-                        ));
-
-                        // Whether or not we should immediately yield the merged result, or keep accumulating
-                        let should_yield = matches!(
-                            (child_item.size_bytes(), accumulator.size_bytes()),
-                            (Some(child_item_size), Some(buffered_item_size))
-                                if child_item_size + buffered_item_size >= self.min_size_bytes
-                        );
-
-                        // Either yield eagerly, or keep looping with a merged accumulator
-                        if should_yield {
-                            return Ok(merged_result).transpose();
-                        } else {
-                            self.accumulator = merged_result;
-                            continue;
-                        }
-                    } else {
-                        self.accumulator = Some(child_item);
-                        return Some(Ok(accumulator));
-                    }
-                }
-                // Bubble up errors from child iterator, making sure to replace the accumulator which we moved
-                (Some(Err(e)), acc) => {
-                    self.accumulator = acc;
-                    return Some(Err(e));
-                }
-                // Iterator ran out of elements: ensure that we flush the last buffered ScanTask
-                (None, Some(last_scan_task)) => {
-                    return Some(Ok(last_scan_task));
-                }
-                (None, None) => {
-                    return None;
-                }
-            }
+            if self.accumulator.is_none() {
+                self.accumulator = match self.iter.next() {
+                    Some(Ok(item)) => Some(item),
+                    e @ Some(Err(_)) => return e,
+                    None => return None,
+                };
+            }
+
+            if self.accumulator_ready() {
+                return self.accumulator.take().map(Ok);
+            }
+
+            let next_item = match self.iter.next() {
+                Some(Ok(item)) => item,
+                e @ Some(Err(_)) => return e,
+                None => return self.accumulator.take().map(Ok),
+            };
+
+            if next_item.size_bytes().is_none() || !self.can_merge(&next_item) {
+                return self.accumulator.replace(next_item).map(Ok);
+            }
+
+            self.accumulator = Some(Arc::new(
+                ScanTask::merge(
+                    self.accumulator
+                        .as_ref()
+                        .expect("accumulator should be populated"),
+                    next_item.as_ref(),
+                )
+                .expect("ScanTasks should be mergeable in MergeByFileSize"),
+            ));
         }
     }
 }
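
For orientation, here is a minimal, self-contained sketch of the same merge-by-size loop, with String standing in for ScanTask and len() for size_bytes(). The MergeBySize type and the thresholds below are illustrative inventions for this sketch, not part of the Daft codebase.

// A simplified stand-in for the merging iterator above: `String` plays the
// role of ScanTask and `len()` the role of `size_bytes()`. Illustrative only.
struct MergeBySize<I: Iterator<Item = String>> {
    iter: I,
    accumulator: Option<String>,
    min_size_bytes: usize,
    max_size_bytes: usize,
}

impl<I: Iterator<Item = String>> MergeBySize<I> {
    // Analogue of `accumulator_ready`: the buffered item alone has
    // reached the minimum size, so it can be yielded as-is.
    fn accumulator_ready(&self) -> bool {
        self.accumulator
            .as_ref()
            .map_or(false, |acc| acc.len() >= self.min_size_bytes)
    }

    // Analogue of `can_merge`: only merge while the combined size
    // stays at or below the maximum.
    fn can_merge(&self, other: &str) -> bool {
        let acc = self.accumulator.as_ref().expect("accumulator populated");
        acc.len() + other.len() <= self.max_size_bytes
    }
}

impl<I: Iterator<Item = String>> Iterator for MergeBySize<I> {
    type Item = String;

    fn next(&mut self) -> Option<String> {
        loop {
            // Refill the accumulator if it is empty.
            if self.accumulator.is_none() {
                self.accumulator = Some(self.iter.next()?);
            }
            // Yield as soon as the buffered item reaches the minimum size,
            // checked *before* any further merge is attempted.
            if self.accumulator_ready() {
                return self.accumulator.take();
            }
            let next_item = match self.iter.next() {
                Some(item) => item,
                // Input exhausted: flush whatever is buffered.
                None => return self.accumulator.take(),
            };
            if !self.can_merge(&next_item) {
                // Flush the buffer and start accumulating the new item.
                return self.accumulator.replace(next_item);
            }
            // Merge and keep looping until min_size_bytes is reached.
            self.accumulator.as_mut().unwrap().push_str(&next_item);
        }
    }
}

fn main() {
    let chunks = ["aa", "bb", "cccccc"].map(String::from);
    let merged = MergeBySize {
        iter: chunks.into_iter(),
        accumulator: None,
        min_size_bytes: 4,
        max_size_bytes: 6,
    };
    // "aa" + "bb" reaches the minimum of 4; "cccccc" already exceeds it alone.
    assert_eq!(merged.collect::<Vec<_>>(), vec!["aabb", "cccccc"]);
}

The ordering mirrors the fix: readiness against the minimum is checked before pulling another item, and a merge happens only when the combined size stays within the maximum, so a merged output never exceeds max_size_bytes (a single oversized input still passes through unchanged).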
8 changes: 4 additions & 4 deletions tests/io/test_merge_scan_tasks.py
@@ -44,20 +44,20 @@ def test_merge_scan_task_exceed_max(csv_files):
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
 def test_merge_scan_task_below_max(csv_files):
-    with override_merge_scan_tasks_configs(1, 20):
+    with override_merge_scan_tasks_configs(21, 22):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
-        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>20 bytes)"
+        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>22 bytes)"
 
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
 def test_merge_scan_task_above_min(csv_files):
-    with override_merge_scan_tasks_configs(0, 40):
+    with override_merge_scan_tasks_configs(19, 40):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
-        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>0 bytes)"
+        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>19 bytes)"
 
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
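
The raised thresholds appear to track the new yield rule: because the fixed iterator yields the buffered ScanTask as soon as it alone reaches min_size_bytes, the old minimums of 1 and 0 bytes would flush every CSV individually and nothing would ever merge. The new values keep the minimum above a single file's size but within reach of a two-file merge.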
