Move determine_partitions and helper methods to io.pyarrow #906

Merged 2 commits on Jul 11, 2024
pyiceberg/io/pyarrow.py (98 additions, 3 deletions)

@@ -113,7 +113,7 @@
    DataFileContent,
    FileFormat,
)
from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
from pyiceberg.schema import (
    PartnerAccessor,
    PreOrderSchemaVisitor,
@@ -2125,8 +2125,6 @@ def _dataframe_to_data_files(
            ]),
        )
    else:
        from pyiceberg.table import _determine_partitions

        partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
        yield from write_file(
            io=io,
@@ -2143,3 +2141,100 @@
                for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
            ]),
        )


@dataclass(frozen=True)
class _TablePartition:
    partition_key: PartitionKey
    arrow_table_partition: pa.Table


def _get_table_partitions(
    arrow_table: pa.Table,
    partition_spec: PartitionSpec,
    schema: Schema,
    slice_instructions: list[dict[str, Any]],
) -> list[_TablePartition]:
    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])

    partition_fields = partition_spec.fields

    offsets = [inst["offset"] for inst in sorted_slice_instructions]
    projected_and_filtered = {
        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
        .take(offsets)
        .to_pylist()
        for partition_field in partition_fields
    }

    table_partitions = []
    for idx, inst in enumerate(sorted_slice_instructions):
        partition_slice = arrow_table.slice(**inst)
        fieldvalues = [
            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
            for partition_field in partition_fields
        ]
        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
    return table_partitions
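
A note on `_get_table_partitions`: since the slice instructions are sorted by offset, taking just the offset rows from each partition column yields exactly one representative value per partition group, which is all `PartitionKey` needs. A minimal standalone sketch (editorial example, not part of the diff; the column name and values are hypothetical):

import pyarrow as pa

# Hypothetical data: three rows already grouped by "year".
tbl = pa.table({"year": [2020, 2020, 2021]})

# Slice instructions in the shape produced by _determine_partitions.
slice_instructions = [{"offset": 0, "length": 2}, {"offset": 2, "length": 1}]

# One representative "year" value per partition group.
offsets = [inst["offset"] for inst in slice_instructions]
print(tbl["year"].take(offsets).to_pylist())  # [2020, 2021]

# The matching row slices become each partition's arrow_table_partition.
print([tbl.slice(**inst).num_rows for inst in slice_instructions])  # [2, 1]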


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
    """Based on the Iceberg table partition spec, slice the Arrow table into partitions with their keys.

    Example:
        Input:
        An Arrow table with partition key ['n_legs', 'year'] and data
        {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
         'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
         'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}.
        The algorithm:
        First, we group the rows into partitions by sorting with sort order
        [('n_legs', 'ascending'), ('year', 'ascending')] and null_placement "at_end".
        For this example the input already happens to be in that order, so the sorted table equals the raw input.
        Then we compute sort_indices over the sorted table using the reverse order
        [('n_legs', 'descending'), ('year', 'descending')] and null_placement "at_start".
        This gives:
        [8, 7, 4, 5, 6, 3, 1, 2, 0]
        Scanning these indices yields one group of indices per distinct partition:
        [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
         {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
        We then retrieve the partition keys by offset and slice the Arrow table by the offset and length of each partition.
    """
    partition_columns: List[Tuple[PartitionField, NestedField]] = [
        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
    ]
    partition_values_table = pa.table({
        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
        for partition, field in partition_columns
    })

    # Sort by partitions
    sort_indices = pa.compute.sort_indices(
        partition_values_table,
        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
        null_placement="at_end",
    ).to_pylist()
    arrow_table = arrow_table.take(sort_indices)

    # Get slice_instructions to group by partitions
    partition_values_table = partition_values_table.take(sort_indices)
    reversed_indices = pa.compute.sort_indices(
        partition_values_table,
        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
        null_placement="at_start",
    ).to_pylist()
    slice_instructions: List[Dict[str, Any]] = []
    last = len(reversed_indices)
    reversed_indices_size = len(reversed_indices)
    ptr = 0
    while ptr < reversed_indices_size:
        group_size = last - reversed_indices[ptr]
        offset = reversed_indices[ptr]
        slice_instructions.append({"offset": offset, "length": group_size})
        last = reversed_indices[ptr]
        ptr = ptr + group_size

    table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

    return table_partitions
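
To make the two-pass sort concrete, here is a standalone sketch that reproduces the docstring example with plain PyArrow (editorial example, not part of the diff; it assumes the partition columns have already been projected, as `partition_values_table` is above):

import pyarrow as pa
import pyarrow.compute as pc

# Projected partition values from the docstring example.
partition_values = pa.table({
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
})

# Pass 1: ascending sort groups equal partition values together.
sort_indices = pc.sort_indices(
    partition_values,
    sort_keys=[(col, "ascending") for col in partition_values.column_names],
    null_placement="at_end",
).to_pylist()
sorted_values = partition_values.take(sort_indices)

# Pass 2: a descending (stable) sort of the already-sorted table reverses
# the groups, so the value at the scan pointer is the offset of that
# group's first row in ascending order, and (last - offset) is its length.
reversed_indices = pc.sort_indices(
    sorted_values,
    sort_keys=[(col, "descending") for col in sorted_values.column_names],
    null_placement="at_start",
).to_pylist()
print(reversed_indices)  # [8, 7, 4, 5, 6, 3, 1, 2, 0]

slices, last, ptr = [], len(reversed_indices), 0
while ptr < len(reversed_indices):
    offset = reversed_indices[ptr]
    length = last - offset
    slices.append({"offset": offset, "length": length})
    last = offset
    ptr += length
print(slices)
# [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
#  {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]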
pyiceberg/table/__init__.py (0 additions, 100 deletions)

@@ -92,7 +92,6 @@
    PARTITION_FIELD_ID_START,
    UNPARTITIONED_PARTITION_SPEC,
    PartitionField,
    PartitionFieldValue,
    PartitionKey,
    PartitionSpec,
    _PartitionNameGenerator,
@@ -4412,105 +4411,6 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
)


@dataclass(frozen=True)
class TablePartition:
    partition_key: PartitionKey
    arrow_table_partition: pa.Table


def _get_table_partitions(
    arrow_table: pa.Table,
    partition_spec: PartitionSpec,
    schema: Schema,
    slice_instructions: list[dict[str, Any]],
) -> list[TablePartition]:
    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])

    partition_fields = partition_spec.fields

    offsets = [inst["offset"] for inst in sorted_slice_instructions]
    projected_and_filtered = {
        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
        .take(offsets)
        .to_pylist()
        for partition_field in partition_fields
    }

    table_partitions = []
    for idx, inst in enumerate(sorted_slice_instructions):
        partition_slice = arrow_table.slice(**inst)
        fieldvalues = [
            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
            for partition_field in partition_fields
        ]
        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
    return table_partitions


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
    """Based on the Iceberg table partition spec, slice the Arrow table into partitions with their keys.

    Example:
        Input:
        An Arrow table with partition key ['n_legs', 'year'] and data
        {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
         'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
         'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}.
        The algorithm:
        First, we group the rows into partitions by sorting with sort order
        [('n_legs', 'ascending'), ('year', 'ascending')] and null_placement "at_end".
        For this example the input already happens to be in that order, so the sorted table equals the raw input.
        Then we compute sort_indices over the sorted table using the reverse order
        [('n_legs', 'descending'), ('year', 'descending')] and null_placement "at_start".
        This gives:
        [8, 7, 4, 5, 6, 3, 1, 2, 0]
        Scanning these indices yields one group of indices per distinct partition:
        [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
         {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
        We then retrieve the partition keys by offset and slice the Arrow table by the offset and length of each partition.
    """
    import pyarrow as pa

    partition_columns: List[Tuple[PartitionField, NestedField]] = [
        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
    ]
    partition_values_table = pa.table({
        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
        for partition, field in partition_columns
    })

    # Sort by partitions
    sort_indices = pa.compute.sort_indices(
        partition_values_table,
        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
        null_placement="at_end",
    ).to_pylist()
    arrow_table = arrow_table.take(sort_indices)

    # Get slice_instructions to group by partitions
    partition_values_table = partition_values_table.take(sort_indices)
    reversed_indices = pa.compute.sort_indices(
        partition_values_table,
        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
        null_placement="at_start",
    ).to_pylist()
    slice_instructions: List[Dict[str, Any]] = []
    last = len(reversed_indices)
    reversed_indices_size = len(reversed_indices)
    ptr = 0
    while ptr < reversed_indices_size:
        group_size = last - reversed_indices[ptr]
        offset = reversed_indices[ptr]
        slice_instructions.append({"offset": offset, "length": group_size})
        last = reversed_indices[ptr]
        ptr = ptr + group_size

    table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

    return table_partitions


class _ManifestMergeManager(Generic[U]):
    _target_size_bytes: int
    _min_count_to_merge: int
tests/io/test_pyarrow.py (82 additions, 2 deletions)

@@ -61,6 +61,7 @@
    PyArrowFileIO,
    StatsAggregator,
    _ConvertToArrowSchema,
    _determine_partitions,
    _primitive_to_physical,
    _read_deletes,
    bin_pack_arrow_table,
@@ -69,11 +70,12 @@
    schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.typedef import UTF8
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import UTF8, Record
from pyiceberg.types import (
    BinaryType,
    BooleanType,
@@ -1718,3 +1720,81 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
    # and will produce half the number of files if we double the target size
    bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2)
    assert len(list(bin_packed)) == 5
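
These context lines exercise `bin_pack_arrow_table`, which the partitioned write path uses to split each partition into file-sized chunks. A quick standalone sketch of that behavior (editorial example; the exact bin counts depend on the packing heuristic, hence the loose assertion):

import pyarrow as pa
from pyiceberg.io.pyarrow import bin_pack_arrow_table

tbl = pa.table({"x": list(range(10_000))})
# Halving the target file size should produce at least as many bins.
small_bins = list(bin_pack_arrow_table(tbl, target_file_size=tbl.nbytes // 4))
large_bins = list(bin_pack_arrow_table(tbl, target_file_size=tbl.nbytes // 2))
assert len(small_bins) >= len(large_bins)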


def test_partition_for_demo() -> None:
    test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
    test_schema = Schema(
        NestedField(field_id=1, name="year", field_type=StringType(), required=False),
        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
        schema_id=1,
    )
    test_data = {
        "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
        "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
        "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
    }
    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
    partition_spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
    )
    result = _determine_partitions(partition_spec, test_schema, arrow_table)
    assert {table_partition.partition_key.partition for table_partition in result} == {
        Record(n_legs_identity=2, year_identity=2020),
        Record(n_legs_identity=100, year_identity=2021),
        Record(n_legs_identity=4, year_identity=2021),
        Record(n_legs_identity=4, year_identity=2022),
        Record(n_legs_identity=2, year_identity=2022),
        Record(n_legs_identity=5, year_identity=2019),
    }
    assert (
        pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows
    )


def test_identity_partition_on_multi_columns() -> None:
    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
    test_schema = Schema(
        NestedField(field_id=1, name="born_year", field_type=StringType(), required=False),
        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
        schema_id=1,
    )
    # 5 partitions, 6 unique row values, 12 rows
    test_rows = [
        (2021, 4, "Dog"),
        (2022, 4, "Horse"),
        (2022, 4, "Another Horse"),
        (2021, 100, "Centipede"),
        (None, 4, "Kirin"),
        (2021, None, "Fish"),
    ] * 2
    expected = {Record(n_legs_identity=test_rows[i][1], year_identity=test_rows[i][0]) for i in range(len(test_rows))}
    partition_spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
    )
    import random

    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
    for _ in range(1000):
        random.shuffle(test_rows)
        test_data = {
            "born_year": [row[0] for row in test_rows],
            "n_legs": [row[1] for row in test_rows],
            "animal": [row[2] for row in test_rows],
        }
        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)

        result = _determine_partitions(partition_spec, test_schema, arrow_table)

        assert {table_partition.partition_key.partition for table_partition in result} == expected
        concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
        assert concatenated_arrow_table.sort_by([
            ("born_year", "ascending"),
            ("n_legs", "ascending"),
            ("animal", "ascending"),
        ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")])