Skip to content

Commit

Permalink
[BUG] Fix reading partition key columns in DeltaLake (#2118)
Browse files Browse the repository at this point in the history
Fixes pushdowns for column selection on partition keys in DeltaLake and
Iceberg.

When table formats such as Iceberg and Delta Lake store the data for a
partition column, they will strip the column from the actual Parquet
data files that they write out. Then when we want to select only
specific columns, our Parquet reader fails when it is not able to find
those columns in the file.

NOTE: Seems like Iceberg only does this for identity transformed
partition columns

Follow-on issue for selection of **only** the partition keys:
#2129

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Apr 15, 2024
1 parent 9c441bb commit a05dfb5
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 6 deletions.
41 changes: 35 additions & 6 deletions src/daft-micropartition/src/micropartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,14 @@ fn materialize_scan_task(
scan_task: Arc<ScanTask>,
io_stats: Option<IOStatsRef>,
) -> crate::Result<(Vec<Table>, SchemaRef)> {
let column_names = scan_task
let pushdown_columns = scan_task
.pushdowns
.columns
.as_ref()
.map(|v| v.iter().map(|s| s.as_ref()).collect::<Vec<&str>>());
.map(|v| v.iter().map(|s| s.as_str()).collect::<Vec<&str>>());
let file_column_names =
_get_file_column_names(pushdown_columns.as_deref(), scan_task.partition_spec());

let urls = scan_task.sources.iter().map(|s| s.get_path());

let mut table_values = match scan_task.storage_config.as_ref() {
Expand Down Expand Up @@ -130,7 +133,7 @@ fn materialize_scan_task(
let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice());
daft_parquet::read::read_parquet_bulk(
urls.as_slice(),
column_names.as_deref(),
file_column_names.as_deref(),
None,
scan_task.pushdowns.limit,
row_groups,
Expand Down Expand Up @@ -163,7 +166,7 @@ fn materialize_scan_task(
};
let convert_options = CsvConvertOptions::new_internal(
scan_task.pushdowns.limit,
column_names
file_column_names
.as_ref()
.map(|cols| cols.iter().map(|col| col.to_string()).collect()),
col_names
Expand Down Expand Up @@ -204,7 +207,7 @@ fn materialize_scan_task(
FileFormatConfig::Json(cfg) => {
let convert_options = JsonConvertOptions::new_internal(
scan_task.pushdowns.limit,
column_names
file_column_names
.as_ref()
.map(|cols| cols.iter().map(|col| col.to_string()).collect()),
Some(scan_task.schema.clone()),
Expand Down Expand Up @@ -757,6 +760,31 @@ pub(crate) fn read_json_into_micropartition(
}
}

fn _get_file_column_names<'a>(
columns: Option<&'a [&'a str]>,
partition_spec: Option<&PartitionSpec>,
) -> Option<Vec<&'a str>> {
match (columns, partition_spec.map(|ps| ps.to_fill_map())) {
(None, _) => None,
(Some(columns), None) => Some(columns.to_vec()),

// If the ScanTask has a partition_spec, we elide reads of partition columns from the file
(Some(columns), Some(partition_fillmap)) => Some(
columns
.as_ref()
.iter()
.filter_map(|s| {
if partition_fillmap.contains_key(s) {
None
} else {
Some(*s)
}
})
.collect::<Vec<&str>>(),
),
}
}

#[allow(clippy::too_many_arguments)]
fn _read_parquet_into_loaded_micropartition(
io_client: Arc<IOClient>,
Expand All @@ -774,9 +802,10 @@ fn _read_parquet_into_loaded_micropartition(
catalog_provided_schema: Option<SchemaRef>,
field_id_mapping: Option<Arc<BTreeMap<i32, Field>>>,
) -> DaftResult<MicroPartition> {
let file_column_names = _get_file_column_names(columns, partition_spec);
let all_tables = read_parquet_bulk(
uris,
columns,
file_column_names.as_deref(),
start_offset,
num_rows,
row_groups,
Expand Down
67 changes: 67 additions & 0 deletions tests/integration/iceberg/test_table_load.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import datetime

import pytest

pyiceberg = pytest.importorskip("pyiceberg")
Expand Down Expand Up @@ -82,3 +84,68 @@ def test_daft_iceberg_table_renamed_column_pushdown_collect_correct(local_iceber
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[["idx_renamed"]]
assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])


@pytest.mark.integration()
def test_daft_iceberg_table_read_partition_column_identity(local_iceberg_catalog):
tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity")
df = daft.read_iceberg(tab)
df = df.select("ts", "number")
daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[["ts", "number"]]
assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])


@pytest.mark.integration()
def test_daft_iceberg_table_read_partition_column_identity_filter(local_iceberg_catalog):
tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity")
df = daft.read_iceberg(tab)
df = df.where(df["number"] > 0)
df = df.select("ts")
daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[iceberg_pandas["number"] > 0][["ts"]]
assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])


@pytest.mark.skip(
reason="Selecting just the identity-transformed partition key in an iceberg table is not yet supported. "
"Issue: https://github.com/Eventual-Inc/Daft/issues/2129"
)
@pytest.mark.integration()
def test_daft_iceberg_table_read_partition_column_identity_filter_on_partkey(local_iceberg_catalog):
tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity")
df = daft.read_iceberg(tab)
df = df.select("ts")
df = df.where(df["ts"] > datetime.date(2022, 3, 1))
daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[iceberg_pandas["ts"] > datetime.date(2022, 3, 1)][["ts"]]
assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])


@pytest.mark.skip(
reason="Selecting just the identity-transformed partition key in an iceberg table is not yet supported. "
"Issue: https://github.com/Eventual-Inc/Daft/issues/2129"
)
@pytest.mark.integration()
def test_daft_iceberg_table_read_partition_column_identity_only(local_iceberg_catalog):
tab = local_iceberg_catalog.load_table("default.test_partitioned_by_identity")
df = daft.read_iceberg(tab)
df = df.select("ts")
daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[["ts"]]
assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])


@pytest.mark.integration()
def test_daft_iceberg_table_read_partition_column_transformed(local_iceberg_catalog):
tab = local_iceberg_catalog.load_table("default.test_partitioned_by_bucket")
df = daft.read_iceberg(tab)
df = df.select("number")
daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[["number"]]
assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])
51 changes: 51 additions & 0 deletions tests/io/delta_lake/test_table_read_pushdowns.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,54 @@ def test_read_predicate_pushdown_on_part_empty(deltalake_table, partition_genera
df.to_arrow().sort_by("part_idx"),
pa.concat_tables([table.filter(pc.field("part_idx") == part_value) for table in tables]),
)


def test_read_select_partition_key(deltalake_table):
path, catalog_table, io_config, tables = deltalake_table
df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config)

df = df.select("part_idx", "a")

assert df.schema().column_names() == ["part_idx", "a"]

assert_pyarrow_tables_equal(
df.to_arrow().sort_by([("part_idx", "ascending"), ("a", "ascending")]),
pa.concat_tables([table.select(["part_idx", "a"]) for table in tables]).sort_by(
[("part_idx", "ascending"), ("a", "ascending")]
),
)


def test_read_select_partition_key_with_filter(deltalake_table):
path, catalog_table, io_config, tables = deltalake_table
df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config)

df = df.select("part_idx", "a")
df = df.where(df["a"] < 5)

assert df.schema().column_names() == ["part_idx", "a"]

assert_pyarrow_tables_equal(
df.to_arrow().sort_by([("part_idx", "ascending"), ("a", "ascending")]),
pa.concat_tables([table.select(["part_idx", "a"]) for table in tables]).sort_by(
[("part_idx", "ascending"), ("a", "ascending")]
),
)


@pytest.mark.skip(
reason="Selecting just the partition key in a deltalake table is not yet supported. "
"Issue: https://github.com/Eventual-Inc/Daft/issues/2129"
)
def test_read_select_only_partition_key(deltalake_table):
path, catalog_table, io_config, tables = deltalake_table
df = daft.read_delta_lake(str(path) if catalog_table is None else catalog_table, io_config=io_config)

df = df.select("part_idx")

assert df.schema().column_names() == ["part_idx"]

assert_pyarrow_tables_equal(
df.to_arrow().sort_by("part_idx"),
pa.concat_tables([table.select(["part_idx"]) for table in tables]).sort_by("part_idx"),
)

0 comments on commit a05dfb5

Please sign in to comment.