[BUG] Fix bug with merge tasks that allows for tasks larger than max size allowed #1882

Merged · 4 commits · Feb 14, 2024
119 changes: 57 additions & 62 deletions src/daft-scan/src/scan_task_iters.rs
@@ -45,74 +45,69 @@ struct MergeByFileSize {
     accumulator: Option<ScanTaskRef>,
 }
 
+impl MergeByFileSize {
+    fn accumulator_ready(&self) -> bool {
+        if let Some(acc) = &self.accumulator
+            && let Some(acc_bytes) = acc.size_bytes()
+            && acc_bytes >= self.min_size_bytes
+        {
+            true
+        } else {
+            false
+        }
+    }
+
+    fn can_merge(&self, other: &ScanTask) -> bool {
+        let accumulator = self
+            .accumulator
+            .as_ref()
+            .expect("accumulator should be populated");
+        let child_matches_accumulator = other.partition_spec() == accumulator.partition_spec()
+            && other.file_format_config == accumulator.file_format_config
+            && other.schema == accumulator.schema
+            && other.storage_config == accumulator.storage_config
+            && other.pushdowns == accumulator.pushdowns;
+
+        let sum_smaller_than_max_size_bytes = if let Some(child_bytes) = other.size_bytes()
+            && let Some(accumulator_bytes) = accumulator.size_bytes()
+        {
+            child_bytes + accumulator_bytes <= self.max_size_bytes
+        } else {
+            false
+        };
+
+        child_matches_accumulator && sum_smaller_than_max_size_bytes
+    }
+}
 
 impl Iterator for MergeByFileSize {
     type Item = DaftResult<ScanTaskRef>;
 
     fn next(&mut self) -> Option<Self::Item> {
         loop {
-            // Grabs the accumulator, leaving a `None` in its place
-            let accumulator = self.accumulator.take();
-
-            match (self.iter.next(), accumulator) {
-                // When no accumulator exists, trivially place the ScanTask into the accumulator
-                (Some(Ok(child_item)), None) => {
-                    self.accumulator = Some(child_item);
-                    continue;
-                }
-                // When an accumulator exists, attempt a merge and yield the result
-                (Some(Ok(child_item)), Some(accumulator)) => {
-                    // Whether or not the accumulator and the current item should be merged
-                    let should_merge = {
-                        let child_matches_accumulator = child_item.partition_spec()
-                            == accumulator.partition_spec()
-                            && child_item.file_format_config == accumulator.file_format_config
-                            && child_item.schema == accumulator.schema
-                            && child_item.storage_config == accumulator.storage_config
-                            && child_item.pushdowns == accumulator.pushdowns;
-                        let smaller_than_max_size_bytes = matches!(
-                            (child_item.size_bytes(), accumulator.size_bytes()),
-                            (Some(child_item_size), Some(buffered_item_size)) if child_item_size + buffered_item_size <= self.max_size_bytes
-                        );
-                        child_matches_accumulator && smaller_than_max_size_bytes
-                    };
-
-                    if should_merge {
-                        let merged_result = Some(Arc::new(
-                            ScanTask::merge(accumulator.as_ref(), child_item.as_ref())
-                                .expect("ScanTasks should be mergeable in MergeByFileSize"),
-                        ));
-
-                        // Whether or not we should immediately yield the merged result, or keep accumulating
-                        let should_yield = matches!(
-                            (child_item.size_bytes(), accumulator.size_bytes()),
-                            (Some(child_item_size), Some(buffered_item_size)) if child_item_size + buffered_item_size >= self.min_size_bytes
-                        );
-
-                        // Either yield eagerly, or keep looping with a merged accumulator
-                        if should_yield {
-                            return Ok(merged_result).transpose();
-                        } else {
-                            self.accumulator = merged_result;
-                            continue;
-                        }
-                    } else {
-                        self.accumulator = Some(child_item);
-                        return Some(Ok(accumulator));
-                    }
-                }
-                // Bubble up errors from child iterator, making sure to replace the accumulator which we moved
-                (Some(Err(e)), acc) => {
-                    self.accumulator = acc;
-                    return Some(Err(e));
-                }
-                // Iterator ran out of elements: ensure that we flush the last buffered ScanTask
-                (None, Some(last_scan_task)) => {
-                    return Some(Ok(last_scan_task));
-                }
-                (None, None) => {
-                    return None;
-                }
-            }
+            if self.accumulator.is_none() {
+                self.accumulator = match self.iter.next() {
+                    Some(Ok(item)) => Some(item),
+                    e @ Some(Err(_)) => return e,
+                    None => return None,
+                };
+            }
+
+            if self.accumulator_ready() {
+                return self.accumulator.take().map(Ok);
+            }
+
+            let next_item = match self.iter.next() {
+                Some(Ok(item)) => item,
+                e @ Some(Err(_)) => return e,
+                None => return self.accumulator.take().map(Ok),
+            };
+
+            if next_item.size_bytes().is_none() || !self.can_merge(&next_item) {
+                return self.accumulator.replace(next_item).map(Ok);
+            }
+
+            self.accumulator = Some(Arc::new(
+                ScanTask::merge(
+                    self.accumulator
+                        .as_ref()
+                        .expect("accumulator should be populated"),
+                    next_item.as_ref(),
+                )
+                .expect("ScanTasks should be mergeable in MergeByFileSize"),
+            ));
         }
     }
 }
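
To see the rewritten control flow in isolation, here is a small, self-contained Python model of the merging policy. It is a sketch, not Daft's API: task sizes are plain integers, only the size checks are modeled (the real can_merge also requires matching partition spec, file format, schema, storage config, and pushdowns, and the real loop yields immediately when a task's size is unknown), and all names are illustrative. The key property mirrors the fix: a merge only happens when the combined size stays at or below the maximum, and a buffered group is flushed as soon as it reaches the minimum.

    def merge_by_size(sizes, min_size, max_size):
        """Model of the rewritten loop: flush a buffered group once it reaches
        min_size, and only merge when the combined size stays <= max_size."""
        acc = None  # buffered group: (total_size, [member_sizes])
        it = iter(sizes)
        while True:
            if acc is None:
                # Fill an empty accumulator from the underlying iterator.
                try:
                    size = next(it)
                except StopIteration:
                    return
                acc = (size, [size])
            if acc[0] >= min_size:
                # accumulator_ready: yield eagerly, before merging anything else.
                yield acc
                acc = None
                continue
            try:
                size = next(it)
            except StopIteration:
                # Iterator exhausted: flush the final partial group.
                yield acc
                return
            if acc[0] + size > max_size:
                # can_merge fails: emit the buffered group, re-buffer the new task.
                yield acc
                acc = (size, [size])
            else:
                # Merge and keep looping.
                acc = (acc[0] + size, acc[1] + [size])

    # Three hypothetical ~11-byte CSV tasks with min=21, max=22 (the updated
    # test's thresholds): the first two merge, the third stands alone.
    print(list(merge_by_size([11, 11, 11], 21, 22)))
    # -> [(22, [11, 11]), (11, [11])]

Because every merge is guarded before it happens, no yielded group can exceed max_size unless a single input task already does.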
8 changes: 4 additions & 4 deletions tests/io/test_merge_scan_tasks.py
@@ -44,20 +44,20 @@ def test_merge_scan_task_exceed_max(csv_files):
 
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
 def test_merge_scan_task_below_max(csv_files):
-    with override_merge_scan_tasks_configs(1, 20):
+    with override_merge_scan_tasks_configs(21, 22):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
-        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>20 bytes)"
+        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>22 bytes)"
 
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
 def test_merge_scan_task_above_min(csv_files):
-    with override_merge_scan_tasks_configs(0, 40):
+    with override_merge_scan_tasks_configs(19, 40):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
-        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>0 bytes)"
+        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>19 bytes)"
 
 
 @pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions")
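
The threshold changes follow from the reworked semantics: min_size_bytes now also acts as a flush threshold, so an accumulator at or above it is yielded before any further merging is attempted. Under the old values (a min of 0 or 1), every file would already be "ready" on arrival and nothing would merge. Reusing the merge_by_size model sketched after the Rust diff above, with the same hypothetical ~11-byte file sizes (the real CSV fixture sizes are not shown in this diff):

    # Old min of 0: each buffered task is immediately "ready", so no merging
    # happens and every file becomes its own partition.
    print(list(merge_by_size([11, 11, 11], 0, 40)))
    # -> [(11, [11]), (11, [11]), (11, [11])]

    # Updated min of 19: the first two files merge before the accumulator
    # becomes ready, giving the expected 2 partitions.
    print(list(merge_by_size([11, 11, 11], 19, 40)))
    # -> [(22, [11, 11]), (11, [11])]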