Skip to content

Commit

Permalink
Move determine_partitions and helper methods to io.pyarrow (#906)
Browse files Browse the repository at this point in the history
  • Loading branch information
soumya-ghosh authored Jul 11, 2024
1 parent 5aa451d commit 8f47dfd
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 189 deletions.
101 changes: 98 additions & 3 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
DataFileContent,
FileFormat,
)
from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
from pyiceberg.schema import (
PartnerAccessor,
PreOrderSchemaVisitor,
Expand Down Expand Up @@ -2125,8 +2125,6 @@ def _dataframe_to_data_files(
]),
)
else:
from pyiceberg.table import _determine_partitions

partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
yield from write_file(
io=io,
Expand All @@ -2143,3 +2141,100 @@ def _dataframe_to_data_files(
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
]),
)


@dataclass(frozen=True)
class _TablePartition:
partition_key: PartitionKey
arrow_table_partition: pa.Table


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
schema: Schema,
slice_instructions: list[dict[str, Any]],
) -> list[_TablePartition]:
sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])

partition_fields = partition_spec.fields

offsets = [inst["offset"] for inst in sorted_slice_instructions]
projected_and_filtered = {
partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
.take(offsets)
.to_pylist()
for partition_field in partition_fields
}

table_partitions = []
for idx, inst in enumerate(sorted_slice_instructions):
partition_slice = arrow_table.slice(**inst)
fieldvalues = [
PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
for partition_field in partition_fields
]
partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
return table_partitions


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
Example:
Input:
An arrow table with partition key of ['n_legs', 'year'] and with data of
{'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
The algorithm:
Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
and null_placement of "at_end".
This gives the same table as raw input.
Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
and null_placement : "at_start".
This gives:
[8, 7, 4, 5, 6, 3, 1, 2, 0]
Based on this we get partition groups of indices:
[{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
We then retrieve the partition keys by offsets.
And slice the arrow table by offsets and lengths of each partition.
"""
partition_columns: List[Tuple[PartitionField, NestedField]] = [
(partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
]
partition_values_table = pa.table({
str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
for partition, field in partition_columns
})

# Sort by partitions
sort_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
null_placement="at_end",
).to_pylist()
arrow_table = arrow_table.take(sort_indices)

# Get slice_instructions to group by partitions
partition_values_table = partition_values_table.take(sort_indices)
reversed_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "descending") for col in partition_values_table.column_names],
null_placement="at_start",
).to_pylist()
slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions
100 changes: 0 additions & 100 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
PARTITION_FIELD_ID_START,
UNPARTITIONED_PARTITION_SPEC,
PartitionField,
PartitionFieldValue,
PartitionKey,
PartitionSpec,
_PartitionNameGenerator,
Expand Down Expand Up @@ -4412,105 +4411,6 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
)


@dataclass(frozen=True)
class TablePartition:
partition_key: PartitionKey
arrow_table_partition: pa.Table


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
schema: Schema,
slice_instructions: list[dict[str, Any]],
) -> list[TablePartition]:
sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])

partition_fields = partition_spec.fields

offsets = [inst["offset"] for inst in sorted_slice_instructions]
projected_and_filtered = {
partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
.take(offsets)
.to_pylist()
for partition_field in partition_fields
}

table_partitions = []
for idx, inst in enumerate(sorted_slice_instructions):
partition_slice = arrow_table.slice(**inst)
fieldvalues = [
PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
for partition_field in partition_fields
]
partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
return table_partitions


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
Example:
Input:
An arrow table with partition key of ['n_legs', 'year'] and with data of
{'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
The algorithm:
Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
and null_placement of "at_end".
This gives the same table as raw input.
Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
and null_placement : "at_start".
This gives:
[8, 7, 4, 5, 6, 3, 1, 2, 0]
Based on this we get partition groups of indices:
[{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
We then retrieve the partition keys by offsets.
And slice the arrow table by offsets and lengths of each partition.
"""
import pyarrow as pa

partition_columns: List[Tuple[PartitionField, NestedField]] = [
(partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
]
partition_values_table = pa.table({
str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
for partition, field in partition_columns
})

# Sort by partitions
sort_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
null_placement="at_end",
).to_pylist()
arrow_table = arrow_table.take(sort_indices)

# Get slice_instructions to group by partitions
partition_values_table = partition_values_table.take(sort_indices)
reversed_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "descending") for col in partition_values_table.column_names],
null_placement="at_start",
).to_pylist()
slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions


class _ManifestMergeManager(Generic[U]):
_target_size_bytes: int
_min_count_to_merge: int
Expand Down
84 changes: 82 additions & 2 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
PyArrowFileIO,
StatsAggregator,
_ConvertToArrowSchema,
_determine_partitions,
_primitive_to_physical,
_read_deletes,
bin_pack_arrow_table,
Expand All @@ -69,11 +70,12 @@
schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.typedef import UTF8
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import UTF8, Record
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down Expand Up @@ -1718,3 +1720,81 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
# and will produce half the number of files if we double the target size
bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2)
assert len(list(bin_packed)) == 5


def test_partition_for_demo() -> None:
test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
test_schema = Schema(
NestedField(field_id=1, name="year", field_type=StringType(), required=False),
NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
schema_id=1,
)
test_data = {
"year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
"n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
"animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
}
arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
partition_spec = PartitionSpec(
PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
)
result = _determine_partitions(partition_spec, test_schema, arrow_table)
assert {table_partition.partition_key.partition for table_partition in result} == {
Record(n_legs_identity=2, year_identity=2020),
Record(n_legs_identity=100, year_identity=2021),
Record(n_legs_identity=4, year_identity=2021),
Record(n_legs_identity=4, year_identity=2022),
Record(n_legs_identity=2, year_identity=2022),
Record(n_legs_identity=5, year_identity=2019),
}
assert (
pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows
)


def test_identity_partition_on_multi_columns() -> None:
test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
test_schema = Schema(
NestedField(field_id=1, name="born_year", field_type=StringType(), required=False),
NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
schema_id=1,
)
# 5 partitions, 6 unique row values, 12 rows
test_rows = [
(2021, 4, "Dog"),
(2022, 4, "Horse"),
(2022, 4, "Another Horse"),
(2021, 100, "Centipede"),
(None, 4, "Kirin"),
(2021, None, "Fish"),
] * 2
expected = {Record(n_legs_identity=test_rows[i][1], year_identity=test_rows[i][0]) for i in range(len(test_rows))}
partition_spec = PartitionSpec(
PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
)
import random

# there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
for _ in range(1000):
random.shuffle(test_rows)
test_data = {
"born_year": [row[0] for row in test_rows],
"n_legs": [row[1] for row in test_rows],
"animal": [row[2] for row in test_rows],
}
arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)

result = _determine_partitions(partition_spec, test_schema, arrow_table)

assert {table_partition.partition_key.partition for table_partition in result} == expected
concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
assert concatenated_arrow_table.num_rows == arrow_table.num_rows
assert concatenated_arrow_table.sort_by([
("born_year", "ascending"),
("n_legs", "ascending"),
("animal", "ascending"),
]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")])
Loading

0 comments on commit 8f47dfd

Please sign in to comment.