From 1b9b884e56f74c7b7d1802774317ee95d799c5f2 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 11 Jul 2024 12:45:20 +0200 Subject: [PATCH] PyArrow: Don't enforce the schema when reading/writing (#902) * PyArrow: Don't enforce the schema PyIceberg struggled with the different type of arrow, such as the `string` and `large_string`. They represent the same, but are different under the hood. My take is that we should hide these kind of details from the user as much as possible. Now we went down the road of passing in the Iceberg schema into Arrow, but when doing this, Iceberg has to decide if it is a large or non-large type. This PR removes passing down the schema in order to let Arrow decide unless: - The type should be evolved - In case of re-ordering, we reorder the original types * WIP * Reuse Table schema * Make linter happy * Squash some bugs * Thanks Sung! Co-authored-by: Sung Yun <107272191+syun64@users.noreply.github.com> * Moar code moar bugs * Remove the variables wrt file sizes * Linting * Go with large ones for now * Missed one there! --------- Co-authored-by: Sung Yun <107272191+syun64@users.noreply.github.com> --- pyiceberg/io/pyarrow.py | 73 ++++++++++------- pyiceberg/table/__init__.py | 3 +- tests/integration/test_add_files.py | 80 +++++++++++++++++-- tests/integration/test_deletes.py | 2 +- tests/integration/test_inspect_table.py | 9 ++- .../test_writes/test_partitioned_writes.py | 12 ++- tests/integration/test_writes/test_writes.py | 35 +++++--- 7 files changed, 156 insertions(+), 58 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index f28fe76bc0..142e9e5f08 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1047,8 +1047,10 @@ def _task_to_record_batches( fragment_scanner = ds.Scanner.from_fragment( fragment=fragment, - # We always use large types in memory as it uses larger offsets - # That can chunk more row values into the buffers + # With PyArrow 16.0.0 there is an issue with casting record-batches: + # https://github.com/apache/arrow/issues/41884 + # https://github.com/apache/arrow/issues/43183 + # Would be good to remove this later on schema=_pyarrow_schema_ensure_large_types(physical_schema), # This will push down the query to Arrow. # But in case there are positional deletes, we have to apply them first @@ -1084,11 +1086,17 @@ def _task_to_table( positional_deletes: Optional[List[ChunkedArray]], case_sensitive: bool, name_mapping: Optional[NameMapping] = None, -) -> pa.Table: - batches = _task_to_record_batches( - fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping +) -> Optional[pa.Table]: + batches = list( + _task_to_record_batches( + fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping + ) ) - return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) + + if len(batches) > 0: + return pa.Table.from_batches(batches) + else: + return None def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: @@ -1192,7 +1200,7 @@ def project_table( if len(tables) < 1: return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) - result = pa.concat_tables(tables) + result = pa.concat_tables(tables, promote_options="permissive") if limit is not None: return result.slice(0, limit) @@ -1271,54 +1279,62 @@ def project_batches( def to_requested_schema( - requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, downcast_ns_timestamp_to_us: bool = False + requested_schema: Schema, + file_schema: Schema, + batch: pa.RecordBatch, + downcast_ns_timestamp_to_us: bool = False, + include_field_ids: bool = False, ) -> pa.RecordBatch: + # We could re-use some of these visitors struct_array = visit_with_partner( - requested_schema, batch, ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us), ArrowAccessor(file_schema) + requested_schema, + batch, + ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids), + ArrowAccessor(file_schema), ) - - arrays = [] - fields = [] - for pos, field in enumerate(requested_schema.fields): - array = struct_array.field(pos) - arrays.append(array) - fields.append(pa.field(field.name, array.type, field.optional)) - return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields)) + return pa.RecordBatch.from_struct_array(struct_array) class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): file_schema: Schema + _include_field_ids: bool - def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False): + def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None: self.file_schema = file_schema + self._include_field_ids = include_field_ids self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: file_field = self.file_schema.find_field(field.field_id) + if field.field_type.is_primitive: if field.field_type != file_field.field_type: - return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False)) - elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type: - # if file_field and field_type (e.g. String) are the same - # but the pyarrow type of the array is different from the expected type - # (e.g. string vs larger_string), we want to cast the array to the larger type - safe = True + return values.cast( + schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids) + ) + elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type: + # Downcasting of nanoseconds to microseconds if ( pa.types.is_timestamp(target_type) and target_type.unit == "us" and pa.types.is_timestamp(values.type) and values.type.unit == "ns" ): - safe = False - return values.cast(target_type, safe=safe) + return values.cast(target_type, safe=False) return values def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: + metadata = {} + if field.doc: + metadata[PYARROW_FIELD_DOC_KEY] = field.doc + if self._include_field_ids: + metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id) + return pa.field( name=field.name, type=arrow_type, nullable=field.optional, - metadata={DOC: field.doc} if field.doc is not None else None, + metadata=metadata, ) def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]: @@ -1960,6 +1976,7 @@ def write_parquet(task: WriteTask) -> DataFile: file_schema=table_schema, batch=batch, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + include_field_ids=True, ) for batch in task.record_batches ] @@ -1967,7 +1984,7 @@ def write_parquet(task: WriteTask) -> DataFile: file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' fo = io.new_output(file_path) with fo.create(overwrite=True) as fos: - with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer: + with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer: writer.write(arrow_table, row_group_size=row_group_size) statistics = data_file_statistics_from_parquet_metadata( parquet_metadata=writer.writer.metadata, diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 7638200881..5342d37053 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2053,8 +2053,9 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow + target_schema = schema_to_pyarrow(self.projection()) return pa.RecordBatchReader.from_batches( - schema_to_pyarrow(self.projection()), + target_schema, project_batches( self.plan_files(), self.table_metadata, diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index fffdfc8ef9..825d17e924 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -18,7 +18,7 @@ import os from datetime import date -from typing import Iterator, Optional +from typing import Iterator import pyarrow as pa import pyarrow.parquet as pq @@ -28,7 +28,8 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.io import FileIO +from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform @@ -107,23 +108,32 @@ ) +def _write_parquet(io: FileIO, file_path: str, arrow_schema: pa.Schema, arrow_table: pa.Table) -> None: + fo = io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=arrow_schema) as writer: + writer.write_table(arrow_table) + + def _create_table( - session_catalog: Catalog, identifier: str, format_version: int, partition_spec: Optional[PartitionSpec] = None + session_catalog: Catalog, + identifier: str, + format_version: int, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + schema: Schema = TABLE_SCHEMA, ) -> Table: try: session_catalog.drop_table(identifier=identifier) except NoSuchTableError: pass - tbl = session_catalog.create_table( + return session_catalog.create_table( identifier=identifier, - schema=TABLE_SCHEMA, + schema=schema, properties={"format-version": str(format_version)}, - partition_spec=partition_spec if partition_spec else PartitionSpec(), + partition_spec=partition_spec, ) - return tbl - @pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")]) def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]: @@ -454,6 +464,60 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat @pytest.mark.integration +def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.unpartitioned_with_large_types{format_version}" + + iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=True)) + arrow_schema = pa.schema([ + pa.field("foo", pa.string(), nullable=False), + ]) + arrow_schema_large = pa.schema([ + pa.field("foo", pa.large_string(), nullable=False), + ]) + + tbl = _create_table(session_catalog, identifier, format_version, schema=iceberg_schema) + + file_path = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-0.parquet" + _write_parquet( + tbl.io, + file_path, + arrow_schema, + pa.Table.from_pylist( + [ + { + "foo": "normal", + } + ], + schema=arrow_schema, + ), + ) + + tbl.add_files([file_path]) + + table_schema = tbl.scan().to_arrow().schema + assert table_schema == arrow_schema_large + + file_path_large = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-1.parquet" + _write_parquet( + tbl.io, + file_path_large, + arrow_schema_large, + pa.Table.from_pylist( + [ + { + "foo": "normal", + } + ], + schema=arrow_schema_large, + ), + ) + + tbl.add_files([file_path_large]) + + table_schema = tbl.scan().to_arrow().schema + assert table_schema == arrow_schema_large + + def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None: nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType())) diff --git a/tests/integration/test_deletes.py b/tests/integration/test_deletes.py index ad3adedeca..d8fb01c447 100644 --- a/tests/integration/test_deletes.py +++ b/tests/integration/test_deletes.py @@ -291,7 +291,7 @@ def test_partitioned_table_positional_deletes_sequence_number(spark: SparkSessio assert snapshots[2].summary == Summary( Operation.OVERWRITE, **{ - "added-files-size": "1145", + "added-files-size": snapshots[2].summary["total-files-size"], "added-data-files": "1", "added-records": "2", "changed-partition-count": "1", diff --git a/tests/integration/test_inspect_table.py b/tests/integration/test_inspect_table.py index d8a83e0df7..9415d7146d 100644 --- a/tests/integration/test_inspect_table.py +++ b/tests/integration/test_inspect_table.py @@ -110,22 +110,25 @@ def test_inspect_snapshots( for manifest_list in df["manifest_list"]: assert manifest_list.as_py().startswith("s3://") + file_size = int(next(value for key, value in df["summary"][0].as_py() if key == "added-files-size")) + assert file_size > 0 + # Append assert df["summary"][0].as_py() == [ - ("added-files-size", "5459"), + ("added-files-size", str(file_size)), ("added-data-files", "1"), ("added-records", "3"), ("total-data-files", "1"), ("total-delete-files", "0"), ("total-records", "3"), - ("total-files-size", "5459"), + ("total-files-size", str(file_size)), ("total-position-deletes", "0"), ("total-equality-deletes", "0"), ] # Delete assert df["summary"][1].as_py() == [ - ("removed-files-size", "5459"), + ("removed-files-size", str(file_size)), ("deleted-data-files", "1"), ("deleted-records", "3"), ("total-data-files", "0"), diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 59bb76933e..12da9c928b 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -252,15 +252,19 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro assert operations == ["append", "append"] summaries = [row.summary for row in rows] + + file_size = int(summaries[0]["added-files-size"]) + assert file_size > 0 + assert summaries[0] == { "changed-partition-count": "3", "added-data-files": "3", - "added-files-size": "15029", + "added-files-size": str(file_size), "added-records": "3", "total-data-files": "3", "total-delete-files": "0", "total-equality-deletes": "0", - "total-files-size": "15029", + "total-files-size": str(file_size), "total-position-deletes": "0", "total-records": "3", } @@ -268,12 +272,12 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro assert summaries[1] == { "changed-partition-count": "3", "added-data-files": "3", - "added-files-size": "15029", + "added-files-size": str(file_size), "added-records": "3", "total-data-files": "6", "total-delete-files": "0", "total-equality-deletes": "0", - "total-files-size": "30058", + "total-files-size": str(file_size * 2), "total-position-deletes": "0", "total-records": "6", } diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 2542fbdb38..af626718f7 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -194,15 +194,18 @@ def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_wi summaries = [row.summary for row in rows] + file_size = int(summaries[0]["added-files-size"]) + assert file_size > 0 + # Append assert summaries[0] == { "added-data-files": "1", - "added-files-size": "5459", + "added-files-size": str(file_size), "added-records": "3", "total-data-files": "1", "total-delete-files": "0", "total-equality-deletes": "0", - "total-files-size": "5459", + "total-files-size": str(file_size), "total-position-deletes": "0", "total-records": "3", } @@ -210,12 +213,12 @@ def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_wi # Append assert summaries[1] == { "added-data-files": "1", - "added-files-size": "5459", + "added-files-size": str(file_size), "added-records": "3", "total-data-files": "2", "total-delete-files": "0", "total-equality-deletes": "0", - "total-files-size": "10918", + "total-files-size": str(file_size * 2), "total-position-deletes": "0", "total-records": "6", } @@ -224,7 +227,7 @@ def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_wi assert summaries[2] == { "deleted-data-files": "2", "deleted-records": "6", - "removed-files-size": "10918", + "removed-files-size": str(file_size * 2), "total-data-files": "0", "total-delete-files": "0", "total-equality-deletes": "0", @@ -236,12 +239,12 @@ def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_wi # Overwrite assert summaries[3] == { "added-data-files": "1", - "added-files-size": "5459", + "added-files-size": str(file_size), "added-records": "3", "total-data-files": "1", "total-delete-files": "0", "total-equality-deletes": "0", - "total-files-size": "5459", + "total-files-size": str(file_size), "total-position-deletes": "0", "total-records": "3", } @@ -576,6 +579,9 @@ def test_summaries_with_only_nulls( summaries = [row.summary for row in rows] + file_size = int(summaries[1]["added-files-size"]) + assert file_size > 0 + assert summaries[0] == { "total-data-files": "0", "total-delete-files": "0", @@ -587,12 +593,12 @@ def test_summaries_with_only_nulls( assert summaries[1] == { "added-data-files": "1", - "added-files-size": "4239", + "added-files-size": str(file_size), "added-records": "2", "total-data-files": "1", "total-delete-files": "0", "total-equality-deletes": "0", - "total-files-size": "4239", + "total-files-size": str(file_size), "total-position-deletes": "0", "total-records": "2", } @@ -600,7 +606,7 @@ def test_summaries_with_only_nulls( assert summaries[2] == { "deleted-data-files": "1", "deleted-records": "2", - "removed-files-size": "4239", + "removed-files-size": str(file_size), "total-data-files": "0", "total-delete-files": "0", "total-equality-deletes": "0", @@ -844,22 +850,25 @@ def test_inspect_snapshots( for manifest_list in df["manifest_list"]: assert manifest_list.as_py().startswith("s3://") + file_size = int(next(value for key, value in df["summary"][0].as_py() if key == "added-files-size")) + assert file_size > 0 + # Append assert df["summary"][0].as_py() == [ - ("added-files-size", "5459"), + ("added-files-size", str(file_size)), ("added-data-files", "1"), ("added-records", "3"), ("total-data-files", "1"), ("total-delete-files", "0"), ("total-records", "3"), - ("total-files-size", "5459"), + ("total-files-size", str(file_size)), ("total-position-deletes", "0"), ("total-equality-deletes", "0"), ] # Delete assert df["summary"][1].as_py() == [ - ("removed-files-size", "5459"), + ("removed-files-size", str(file_size)), ("deleted-data-files", "1"), ("deleted-records", "3"), ("total-data-files", "0"),