diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index ae7799cfde..f28fe76bc0 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -113,7 +113,7 @@
     DataFileContent,
     FileFormat,
 )
-from pyiceberg.partitioning import PartitionField, PartitionSpec, partition_record_value
+from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
     PartnerAccessor,
     PreOrderSchemaVisitor,
@@ -2125,8 +2125,6 @@ def _dataframe_to_data_files(
             ]),
         )
     else:
-        from pyiceberg.table import _determine_partitions
-
         partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
         yield from write_file(
             io=io,
@@ -2143,3 +2141,100 @@ def _dataframe_to_data_files(
                 for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
             ]),
         )
+
+
+@dataclass(frozen=True)
+class _TablePartition:
+    partition_key: PartitionKey
+    arrow_table_partition: pa.Table
+
+
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[_TablePartition]:
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
+        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
+    return table_partitions
+
+
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
+    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
+
+    Example:
+        Input:
+        An arrow table with partition key of ['n_legs', 'year'] and with data of
+        {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+         'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+         'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}.
+        The algorithm:
+        First we group the rows into partitions by sorting with sort order [('n_legs', 'ascending'), ('year', 'ascending')]
+        and null_placement of "at_end".
+        This gives the same table as the raw input.
+        Then we compute sort_indices in the reverse order, [('n_legs', 'descending'), ('year', 'descending')],
+        and null_placement of "at_start".
+        This gives:
+        [8, 7, 4, 5, 6, 3, 1, 2, 0]
+        Based on this we get partition groups of indices:
+        [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
+        We then retrieve the partition keys by offsets.
+        And slice the arrow table by offsets and lengths of each partition.
+    """
+    partition_columns: List[Tuple[PartitionField, NestedField]] = [
+        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
+    ]
+    partition_values_table = pa.table({
+        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+        for partition, field in partition_columns
+    })
+
+    # Sort by partitions
+    sort_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
+        null_placement="at_end",
+    ).to_pylist()
+    arrow_table = arrow_table.take(sort_indices)
+
+    # Get slice_instructions to group by partitions
+    partition_values_table = partition_values_table.take(sort_indices)
+    reversed_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
+        null_placement="at_start",
+    ).to_pylist()
+    slice_instructions: List[Dict[str, Any]] = []
+    last = len(reversed_indices)
+    reversed_indices_size = len(reversed_indices)
+    ptr = 0
+    while ptr < reversed_indices_size:
+        group_size = last - reversed_indices[ptr]
+        offset = reversed_indices[ptr]
+        slice_instructions.append({"offset": offset, "length": group_size})
+        last = reversed_indices[ptr]
+        ptr = ptr + group_size
+
+    table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+
+    return table_partitions
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 39bcfc2ef6..32c0af1b3c 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -92,7 +92,6 @@
     PARTITION_FIELD_ID_START,
     UNPARTITIONED_PARTITION_SPEC,
     PartitionField,
-    PartitionFieldValue,
     PartitionKey,
     PartitionSpec,
     _PartitionNameGenerator,
@@ -4412,105 +4411,6 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
     )


-@dataclass(frozen=True)
-class TablePartition:
-    partition_key: PartitionKey
-    arrow_table_partition: pa.Table
-
-
-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
-        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
-    return table_partitions
-
-
-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
-    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
-
-    Example:
-        Input:
-        An arrow table with partition key of ['n_legs', 'year'] and with data of
-        {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
-         'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
-         'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
-        The algorithm:
-        Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
-        and null_placement of "at_end".
-        This gives the same table as raw input.
-        Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
-        and null_placement : "at_start".
-        This gives:
-        [8, 7, 4, 5, 6, 3, 1, 2, 0]
-        Based on this we get partition groups of indices:
-        [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
-        We then retrieve the partition keys by offsets.
-        And slice the arrow table by offsets and lengths of each partition.
-    """
-    import pyarrow as pa
-
-    partition_columns: List[Tuple[PartitionField, NestedField]] = [
-        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
-    ]
-    partition_values_table = pa.table({
-        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
-        for partition, field in partition_columns
-    })
-
-    # Sort by partitions
-    sort_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
-        null_placement="at_end",
-    ).to_pylist()
-    arrow_table = arrow_table.take(sort_indices)
-
-    # Get slice_instructions to group by partitions
-    partition_values_table = partition_values_table.take(sort_indices)
-    reversed_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
-        null_placement="at_start",
-    ).to_pylist()
-    slice_instructions: List[Dict[str, Any]] = []
-    last = len(reversed_indices)
-    reversed_indices_size = len(reversed_indices)
-    ptr = 0
-    while ptr < reversed_indices_size:
-        group_size = last - reversed_indices[ptr]
-        offset = reversed_indices[ptr]
-        slice_instructions.append({"offset": offset, "length": group_size})
-        last = reversed_indices[ptr]
-        ptr = ptr + group_size
-
-    table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
-
-    return table_partitions
-
-
 class _ManifestMergeManager(Generic[U]):
     _target_size_bytes: int
     _min_count_to_merge: int
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index ecb946a98b..1b9468993c 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -61,6 +61,7 @@
     PyArrowFileIO,
     StatsAggregator,
     _ConvertToArrowSchema,
+    _determine_partitions,
     _primitive_to_physical,
     _read_deletes,
     bin_pack_arrow_table,
@@ -69,11 +70,12 @@
     schema_to_pyarrow,
 )
 from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
-from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema, make_compatible_name, visit
 from pyiceberg.table import FileScanTask, TableProperties
 from pyiceberg.table.metadata import TableMetadataV2
-from pyiceberg.typedef import UTF8
+from pyiceberg.transforms import IdentityTransform
+from pyiceberg.typedef import UTF8, Record
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1718,3 +1720,81 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
     # and will produce half the number of files if we double the target size
     bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2)
     assert len(list(bin_packed)) == 5
+
+
+def test_partition_for_demo() -> None:
+    test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name="year", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
+        schema_id=1,
+    )
+    test_data = {
+        "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
+        "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
+        "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
+    }
+    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
+    )
+    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    assert {table_partition.partition_key.partition for table_partition in result} == {
+        Record(n_legs_identity=2, year_identity=2020),
+        Record(n_legs_identity=100, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2022),
+        Record(n_legs_identity=2, year_identity=2022),
+        Record(n_legs_identity=5, year_identity=2019),
+    }
+    assert (
+        pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows
+    )
+
+
+def test_identity_partition_on_multi_columns() -> None:
+    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name="born_year", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
+        schema_id=1,
+    )
+    # 5 partitions, 6 unique row values, 12 rows
+    test_rows = [
+        (2021, 4, "Dog"),
+        (2022, 4, "Horse"),
+        (2022, 4, "Another Horse"),
+        (2021, 100, "Centipede"),
+        (None, 4, "Kirin"),
+        (2021, None, "Fish"),
+    ] * 2
+    expected = {Record(n_legs_identity=test_rows[i][1], year_identity=test_rows[i][0]) for i in range(len(test_rows))}
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
+    )
+    import random
+
+    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
+    for _ in range(1000):
+        random.shuffle(test_rows)
+        test_data = {
+            "born_year": [row[0] for row in test_rows],
+            "n_legs": [row[1] for row in test_rows],
+            "animal": [row[2] for row in test_rows],
+        }
+        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+
+        result = _determine_partitions(partition_spec, test_schema, arrow_table)
+
+        assert {table_partition.partition_key.partition for table_partition in result} == expected
+        concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
+        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
+        assert concatenated_arrow_table.sort_by([
+            ("born_year", "ascending"),
+            ("n_legs", "ascending"),
+            ("animal", "ascending"),
+        ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")])
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index d7c4ffeeaf..31a8bbf444 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -64,7 +64,6 @@
     UpdateSchema,
     _apply_table_update,
     _check_schema_compatible,
-    _determine_partitions,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
     update_table_metadata,
@@ -88,7 +87,6 @@
     BucketTransform,
     IdentityTransform,
 )
-from pyiceberg.typedef import Record
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1248,85 +1246,3 @@ def test_serialize_commit_table_request() -> None:
     deserialized_request = CommitTableRequest.model_validate_json(request.model_dump_json())

     assert request == deserialized_request
-
-
-def test_partition_for_demo() -> None:
-    import pyarrow as pa
-
-    test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
-    test_schema = Schema(
-        NestedField(field_id=1, name="year", field_type=StringType(), required=False),
-        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
-        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
-        schema_id=1,
-    )
-    test_data = {
-        "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
-        "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
-        "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
-    }
-    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
-    partition_spec = PartitionSpec(
-        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
-        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
-    )
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
-    assert {table_partition.partition_key.partition for table_partition in result} == {
-        Record(n_legs_identity=2, year_identity=2020),
-        Record(n_legs_identity=100, year_identity=2021),
-        Record(n_legs_identity=4, year_identity=2021),
-        Record(n_legs_identity=4, year_identity=2022),
-        Record(n_legs_identity=2, year_identity=2022),
-        Record(n_legs_identity=5, year_identity=2019),
-    }
-    assert (
-        pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows
-    )
-
-
-def test_identity_partition_on_multi_columns() -> None:
-    import pyarrow as pa
-
-    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
-    test_schema = Schema(
-        NestedField(field_id=1, name="born_year", field_type=StringType(), required=False),
-        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
-        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
-        schema_id=1,
-    )
-    # 5 partitions, 6 unique row values, 12 rows
-    test_rows = [
-        (2021, 4, "Dog"),
-        (2022, 4, "Horse"),
-        (2022, 4, "Another Horse"),
-        (2021, 100, "Centipede"),
-        (None, 4, "Kirin"),
-        (2021, None, "Fish"),
-    ] * 2
-    expected = {Record(n_legs_identity=test_rows[i][1], year_identity=test_rows[i][0]) for i in range(len(test_rows))}
-    partition_spec = PartitionSpec(
-        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
-        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
-    )
-    import random
-
-    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
-    for _ in range(1000):
-        random.shuffle(test_rows)
-        test_data = {
-            "born_year": [row[0] for row in test_rows],
-            "n_legs": [row[1] for row in test_rows],
-            "animal": [row[2] for row in test_rows],
-        }
-        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
-
-        result = _determine_partitions(partition_spec, test_schema, arrow_table)
-
-        assert {table_partition.partition_key.partition for table_partition in result} == expected
-        concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
-        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
-        assert concatenated_arrow_table.sort_by([
-            ("born_year", "ascending"),
-            ("n_legs", "ascending"),
-            ("animal", "ascending"),
-        ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")])
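
For readers tracing the _determine_partitions docstring, the standalone sketch below (plain PyArrow, not part of the patch; column names and data taken from the docstring example) walks through how the two sort_indices calls turn partition values into per-partition (offset, length) slice instructions: a stable ascending sort packs equal partition values into contiguous runs, and a stable descending sort of the already-sorted values exposes each run's start offset.

# Standalone illustration (not part of the patch) of the grouping trick in
# _determine_partitions, using the example data from its docstring.
import pyarrow as pa
import pyarrow.compute as pc

partition_values = pa.table({
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
})

# Step 1: an ascending sort groups equal partition values into contiguous runs.
sort_indices = pc.sort_indices(
    partition_values,
    sort_keys=[(col, "ascending") for col in partition_values.column_names],
    null_placement="at_end",
).to_pylist()
sorted_values = partition_values.take(sort_indices)  # already sorted here, so unchanged

# Step 2: a stable descending sort of the sorted values lists each run's indices
# contiguously, in reverse key order; the first index of each run is its start offset.
reversed_indices = pc.sort_indices(
    sorted_values,
    sort_keys=[(col, "descending") for col in partition_values.column_names],
    null_placement="at_start",
).to_pylist()
assert reversed_indices == [8, 7, 4, 5, 6, 3, 1, 2, 0]  # the value quoted in the docstring

# Step 3: walk the reversed indices to emit one slice instruction per partition.
slice_instructions = []
last = len(reversed_indices)
ptr = 0
while ptr < len(reversed_indices):
    offset = reversed_indices[ptr]
    group_size = last - offset
    slice_instructions.append({"offset": offset, "length": group_size})
    last = offset
    ptr += group_size

assert slice_instructions == [
    {"offset": 8, "length": 1},
    {"offset": 7, "length": 1},
    {"offset": 4, "length": 3},
    {"offset": 3, "length": 1},
    {"offset": 1, "length": 2},
    {"offset": 0, "length": 1},
]

_determine_partitions then slices the ascending-sorted data table with these offsets and lengths, and _get_table_partitions reads the partition source columns at each slice's offset row to build the PartitionKey for that slice.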