Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(scantask-1): add a config flag for new scantask splitting algorithm #3615

Merged
merged 4 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def set_execution_config(
shuffle_algorithm: str | None = None,
pre_shuffle_merge_threshold: int | None = None,
enable_ray_tracing: bool | None = None,
scantask_splitting_level: int | None = None,
) -> DaftContext:
"""Globally sets various configuration parameters which control various aspects of Daft execution.

Expand Down Expand Up @@ -395,6 +396,7 @@ def set_execution_config(
shuffle_algorithm: The shuffle algorithm to use. Defaults to "map_reduce". Other options are "pre_shuffle_merge".
pre_shuffle_merge_threshold: Memory threshold in bytes for pre-shuffle merge. Defaults to 1GB
enable_ray_tracing: Enable tracing for Ray. Accessible in `/tmp/ray/session_latest/logs/daft` after the run completes. Defaults to False.
scantask_splitting_level: How aggressively to split scan tasks. Setting this to `2` will use a more aggressive ScanTask splitting algorithm which might be more expensive to run but results in more even splits of partitions. Defaults to 1.
"""
# Replace values in the DaftExecutionConfig with user-specified overrides
ctx = get_context()
Expand Down Expand Up @@ -425,6 +427,7 @@ def set_execution_config(
shuffle_algorithm=shuffle_algorithm,
pre_shuffle_merge_threshold=pre_shuffle_merge_threshold,
enable_ray_tracing=enable_ray_tracing,
scantask_splitting_level=scantask_splitting_level,
)

ctx._daft_execution_config = new_daft_execution_config
Expand Down
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,7 @@ class PyDaftExecutionConfig:
enable_ray_tracing: bool | None = None,
shuffle_algorithm: str | None = None,
pre_shuffle_merge_threshold: int | None = None,
scantask_splitting_level: int | None = None,
) -> PyDaftExecutionConfig: ...
@property
def scan_tasks_min_size_bytes(self) -> int: ...
Expand Down
6 changes: 6 additions & 0 deletions src/common/daft-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
pub shuffle_algorithm: String,
pub pre_shuffle_merge_threshold: usize,
pub enable_ray_tracing: bool,
pub scantask_splitting_level: i32,
}

impl Default for DaftExecutionConfig {
Expand Down Expand Up @@ -81,6 +82,7 @@
shuffle_algorithm: "map_reduce".to_string(),
pre_shuffle_merge_threshold: 1024 * 1024 * 1024, // 1GB
enable_ray_tracing: false,
scantask_splitting_level: 1,
}
}
}
Expand Down Expand Up @@ -118,6 +120,10 @@
if let Ok(val) = std::env::var(shuffle_algorithm_env_var_name) {
cfg.shuffle_algorithm = val;
}
let enable_aggressive_scantask_splitting_env_var_name = "DAFT_SCANTASK_SPLITTING_LEVEL";
if let Ok(val) = std::env::var(enable_aggressive_scantask_splitting_env_var_name) {
cfg.scantask_splitting_level = val.parse::<i32>().unwrap_or(0);

Check warning on line 125 in src/common/daft-config/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/daft-config/src/lib.rs#L125

Added line #L125 was not covered by tests
}
cfg
}
}
Expand Down
15 changes: 15 additions & 0 deletions src/common/daft-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
shuffle_algorithm: Option<&str>,
pre_shuffle_merge_threshold: Option<usize>,
enable_ray_tracing: Option<bool>,
scantask_splitting_level: Option<i32>,
) -> PyResult<Self> {
let mut config = self.config.as_ref().clone();

Expand Down Expand Up @@ -184,6 +185,15 @@
config.enable_ray_tracing = enable_ray_tracing;
}

if let Some(scantask_splitting_level) = scantask_splitting_level {
if !matches!(scantask_splitting_level, 1 | 2) {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"scantask_splitting_level must be 1 or 2",
));
}
config.scantask_splitting_level = scantask_splitting_level;

Check warning on line 194 in src/common/daft-config/src/python.rs

View check run for this annotation

Codecov / codecov/patch

src/common/daft-config/src/python.rs#L189-L194

Added lines #L189 - L194 were not covered by tests
}

Ok(Self {
config: Arc::new(config),
})
Expand Down Expand Up @@ -293,6 +303,11 @@
fn enable_ray_tracing(&self) -> PyResult<bool> {
Ok(self.config.enable_ray_tracing)
}

#[getter]
fn scantask_splitting_level(&self) -> PyResult<i32> {
Ok(self.config.scantask_splitting_level)
}

Check warning on line 310 in src/common/daft-config/src/python.rs

View check run for this annotation

Codecov / codecov/patch

src/common/daft-config/src/python.rs#L308-L310

Added lines #L308 - L310 were not covered by tests
}

impl_bincode_py_state_serialization!(PyDaftExecutionConfig);
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

use crate::{ChunkSpec, DataSource, Pushdowns, ScanTask, ScanTaskRef};

pub(crate) type BoxScanTaskIter<'a> = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>> + 'a>;
type BoxScanTaskIter<'a> = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>> + 'a>;

/// Coalesces ScanTasks by their [`ScanTask::estimate_in_memory_size_bytes()`]
///
Expand All @@ -25,7 +25,7 @@
/// * `min_size_bytes`: Minimum size in bytes of a ScanTask, after which no more merging will be performed
/// * `max_size_bytes`: Maximum size in bytes of a ScanTask, capping the maximum size of a merged ScanTask
#[must_use]
pub(crate) fn merge_by_sizes<'a>(
fn merge_by_sizes<'a>(
scan_tasks: BoxScanTaskIter<'a>,
pushdowns: &Pushdowns,
cfg: &'a DaftExecutionConfig,
Expand Down Expand Up @@ -176,7 +176,7 @@
}

#[must_use]
pub(crate) fn split_by_row_groups(
fn split_by_row_groups(
scan_tasks: BoxScanTaskIter,
max_tasks: usize,
min_size_bytes: usize,
Expand Down Expand Up @@ -316,31 +316,40 @@
.iter()
.all(|st| st.as_any().downcast_ref::<ScanTask>().is_some())
{
// TODO(desmond): Here we downcast Arc<dyn ScanTaskLike> to Arc<ScanTask>. ScanTask and DummyScanTask (test only) are
// the only non-test implementer of ScanTaskLike. It might be possible to avoid the downcast by implementing merging
// at the trait level, but today that requires shifting around a non-trivial amount of code to avoid circular dependencies.
let iter: BoxScanTaskIter = Box::new(scan_tasks.as_ref().iter().map(|st| {
st.clone()
.as_any_arc()
.downcast::<ScanTask>()
.map_err(|e| DaftError::TypeError(format!("Expected Arc<ScanTask>, found {:?}", e)))
}));
let split_tasks = split_by_row_groups(
iter,
cfg.parquet_split_row_groups_max_files,
cfg.scan_tasks_min_size_bytes,
cfg.scan_tasks_max_size_bytes,
);
let merged_tasks = merge_by_sizes(split_tasks, pushdowns, cfg);
let scan_tasks: Vec<Arc<dyn ScanTaskLike>> = merged_tasks
.map(|st| st.map(|task| task as Arc<dyn ScanTaskLike>))
.collect::<DaftResult<Vec<_>>>()?;
Ok(Arc::new(scan_tasks))
if cfg.scantask_splitting_level == 1 {
// TODO(desmond): Here we downcast Arc<dyn ScanTaskLike> to Arc<ScanTask>. ScanTask and DummyScanTask (test only) are
// the only non-test implementer of ScanTaskLike. It might be possible to avoid the downcast by implementing merging
// at the trait level, but today that requires shifting around a non-trivial amount of code to avoid circular dependencies.
let iter: BoxScanTaskIter = Box::new(scan_tasks.as_ref().iter().map(|st| {
st.clone().as_any_arc().downcast::<ScanTask>().map_err(|e| {
DaftError::TypeError(format!("Expected Arc<ScanTask>, found {:?}", e))

Check warning on line 325 in src/daft-scan/src/scan_task_iters/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-scan/src/scan_task_iters/mod.rs#L325

Added line #L325 was not covered by tests
})
}));
let split_tasks = split_by_row_groups(
iter,
cfg.parquet_split_row_groups_max_files,
cfg.scan_tasks_min_size_bytes,
cfg.scan_tasks_max_size_bytes,
);
let merged_tasks = merge_by_sizes(split_tasks, pushdowns, cfg);
let scan_tasks: Vec<Arc<dyn ScanTaskLike>> = merged_tasks
.map(|st| st.map(|task| task as Arc<dyn ScanTaskLike>))
.collect::<DaftResult<Vec<_>>>()?;
Ok(Arc::new(scan_tasks))
} else if cfg.scantask_splitting_level == 2 {
todo!("Implement aggressive scantask splitting");

Check warning on line 340 in src/daft-scan/src/scan_task_iters/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-scan/src/scan_task_iters/mod.rs#L339-L340

Added lines #L339 - L340 were not covered by tests
} else {
panic!(
"DAFT_SCANTASK_SPLITTING_LEVEL must be either 1 or 2, received: {}",
cfg.scantask_splitting_level
);

Check warning on line 345 in src/daft-scan/src/scan_task_iters/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-scan/src/scan_task_iters/mod.rs#L342-L345

Added lines #L342 - L345 were not covered by tests
}
} else {
Ok(scan_tasks)
}
}

/// Sets ``SPLIT_AND_MERGE_PASS``, which is the publicly-available pass that the query optimizer can use
#[ctor::ctor]
fn set_pass() {
let _ = SPLIT_AND_MERGE_PASS.set(&split_and_merge_pass);
Expand Down
58 changes: 58 additions & 0 deletions tests/io/test_split_scan_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,61 @@ def test_split_parquet_read(parquet_files):
df = daft.read_parquet(str(parquet_files))
assert df.num_partitions() == 10, "Should have 10 partitions since we will split the file"
assert df.to_pydict() == {"data": ["aaa"] * 100}


@pytest.mark.skip(reason="Not implemented yet")
def test_split_parquet_read_some_splits(tmpdir):
with daft.execution_config_ctx(scantask_splitting_level=2):
# Write a mix of 20 large and 20 small files
# Small ones should not be split, large ones should be split into 10 rowgroups each
# This gives us a total of 200 + 20 scantasks

# Write 20 large files into tmpdir
large_file_paths = []
for i in range(20):
tbl = pa.table({"data": [str(f"large{i}") for i in range(100)]})
path = tmpdir / f"file.{i}.large.pq"
papq.write_table(tbl, str(path), row_group_size=10, use_dictionary=False)
large_file_paths.append(str(path))

# Write 20 small files into tmpdir
small_file_paths = []
for i in range(20):
tbl = pa.table({"data": ["small"]})
path = tmpdir / f"file.{i}.small.pq"
papq.write_table(tbl, str(path), row_group_size=1, use_dictionary=False)
small_file_paths.append(str(path))

# Test [large_paths, ..., small_paths, ...]
with daft.execution_config_ctx(
scan_tasks_min_size_bytes=20,
scan_tasks_max_size_bytes=100,
):
df = daft.read_parquet(large_file_paths + small_file_paths)
assert (
df.num_partitions() == 220
), "Should have 220 partitions since we will split all large files (20 * 10 rowgroups) but keep small files unsplit"
assert df.to_pydict() == {"data": [str(f"large{i}") for i in range(100)] * 20 + ["small"] * 20}

# Test interleaved [large_path, small_path, large_path, small_path, ...]
with daft.execution_config_ctx(
scan_tasks_min_size_bytes=20,
scan_tasks_max_size_bytes=100,
):
interleaved_paths = [path for pair in zip(large_file_paths, small_file_paths) for path in pair]
df = daft.read_parquet(interleaved_paths)
assert (
df.num_partitions() == 220
), "Should have 220 partitions since we will split all large files (20 * 10 rowgroups) but keep small files unsplit"
assert df.to_pydict() == {"data": ([str(f"large{i}") for i in range(100)] + ["small"]) * 20}

# Test [small_paths, ..., large_paths]
with daft.execution_config_ctx(
scan_tasks_min_size_bytes=20,
scan_tasks_max_size_bytes=100,
):
df = daft.read_parquet(small_file_paths + large_file_paths)
assert (
df.num_partitions() == 220
), "Should have 220 partitions since we will split all large files (20 * 10 rowgroups) but keep small files unsplit"
assert df.to_pydict() == {"data": ["small"] * 20 + [str(f"large{i}") for i in range(100)] * 20}
Loading