From f2949b7218d4e2cd601e1f7f7db47483cda88d7c Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Tue, 18 Jun 2024 10:01:11 -0700
Subject: [PATCH 1/8] remove arrow schema cast

---
 pyiceberg/table/__init__.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 62440c4773..b43dc3206b 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -484,10 +484,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         manifest_merge_enabled = PropertyUtil.property_as_bool(
             self.table_metadata.properties,
@@ -545,10 +541,6 @@ def overwrite(
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)

From 397e31e099f98e6cbad754e238d4afda13907823 Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Tue, 18 Jun 2024 11:33:15 -0700
Subject: [PATCH 2/8] add test for writing out of order schema

---
 tests/integration/test_writes/test_writes.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index af626718f7..f3bc7bc569 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -975,6 +975,20 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
     assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2
 
 
+@pytest.mark.parametrize("format_version", [1, 2])
+def table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    import random
+    identifier = "default.table_write_out_of_order_schema"
+    shuffled_schema = pa.schema(random.shuffle(arrow_table_with_null.schema))
+
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=shuffled_schema)
+
+    tbl.overwrite(arrow_table_with_null)
+    tbl.append(arrow_table_with_null)
+    # overwrite and then append should produce twice the data
+    assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None:
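Patches 1 and 2 belong together: `pa.Table.cast` matches fields positionally rather than by name, so the eager cast removed in patch 1 could never reconcile a dataframe whose columns arrive in a different order than the table schema, which is exactly the scenario the new test exercises. A minimal standalone sketch of that failure mode (the two-column table here is illustrative, not the `arrow_table_with_null` fixture):

```python
import pyarrow as pa

table = pa.table({"foo": ["a", "b"], "bar": [1, 2]})

# Same fields, same types, different order.
reordered = pa.schema([
    pa.field("bar", pa.int64()),
    pa.field("foo", pa.string()),
])

try:
    table.cast(reordered)
except ValueError as e:
    # pyarrow rejects the cast because the target field names
    # do not match the table's field names positionally.
    print(e)
```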
From fd24c5686e72a3bc25ca98ebaf16de9c553967df Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Tue, 18 Jun 2024 11:37:27 -0700
Subject: [PATCH 3/8] lint

---
 tests/integration/test_writes/test_writes.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index f3bc7bc569..f06f781ab1 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -978,6 +978,7 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
 
 @pytest.mark.parametrize("format_version", [1, 2])
 def table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
     import random
+
     identifier = "default.table_write_out_of_order_schema"
     shuffled_schema = pa.schema(random.shuffle(arrow_table_with_null.schema))

From 2c52276cf55f5781633be6df20976af6671c8f8e Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Tue, 18 Jun 2024 11:50:31 -0700
Subject: [PATCH 4/8] properly name tests

---
 tests/integration/test_writes/test_writes.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index f06f781ab1..92b0c06213 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -964,8 +964,8 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
 
 
 @pytest.mark.parametrize("format_version", [1, 2])
-def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
-    identifier = "default.table_append_subset_of_schema"
+def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    identifier = "default.test_table_write_subset_of_schema"
     tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null])
     arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0])
     assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns)
@@ -976,10 +976,10 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
 
 
 @pytest.mark.parametrize("format_version", [1, 2])
-def table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
     import random
 
-    identifier = "default.table_write_out_of_order_schema"
+    identifier = "default.test_table_write_out_of_order_schema"
     shuffled_schema = pa.schema(random.shuffle(arrow_table_with_null.schema))
 

From 6a9b612e3cf1cc1d58fd6fe790d1233cb306a7b0 Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Tue, 18 Jun 2024 12:23:23 -0700
Subject: [PATCH 5/8] rewrite test_table_write_out_of_order_schema

---
 tests/integration/test_writes/test_writes.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 92b0c06213..8692d714be 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -977,12 +977,13 @@ def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with
 
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
-    import random
-
     identifier = "default.test_table_write_out_of_order_schema"
-    shuffled_schema = pa.schema(random.shuffle(arrow_table_with_null.schema))
-
-    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=shuffled_schema)
+    # rotate the schema fields by 1
+    fields = list(arrow_table_with_null.schema)
+    rotated_fields = fields[1:] + fields[:1]
+    rotated_schema = pa.schema(rotated_fields)
+    assert arrow_table_with_null.schema != rotated_schema
 
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema)
     tbl.overwrite(arrow_table_with_null)
     tbl.append(arrow_table_with_null)
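The rewrite in patch 5 fixes two problems with the original test: `random.shuffle` mutates a sequence in place and returns `None` (and a `pyarrow.Schema` is not mutable in the first place), so `pa.schema(random.shuffle(...))` could never produce a schema, and a random permutation would not be reproducible anyway. Rotating the field list by one is deterministic and always yields a different order. A standalone sketch of the same idea (the `foo`/`bar`/`baz` schema is illustrative, not the `arrow_table_with_null` fixture):

```python
import pyarrow as pa

schema = pa.schema([
    pa.field("foo", pa.string()),
    pa.field("bar", pa.int32(), nullable=False),
    pa.field("baz", pa.bool_()),
])

# random.shuffle(schema) would fail and return None; rotating is deterministic.
fields = list(schema)
rotated_schema = pa.schema(fields[1:] + fields[:1])  # bar, baz, foo

assert rotated_schema != schema                         # field order differs
assert set(rotated_schema.names) == set(schema.names)   # same columns
```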
From 9ee11c2f8c7b342a3a6afb5d7c0bca6b43a34b73 Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Tue, 18 Jun 2024 16:38:56 -0700
Subject: [PATCH 6/8] reimplement _check_schema_compatible

---
 pyiceberg/io/pyarrow.py  | 19 ++++++++++++-------
 tests/io/test_pyarrow.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 56f2242514..86e0a39713 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2036,24 +2036,29 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down
     """
     Check if the `table_schema` is compatible with `other_schema`.
 
-    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+    The schemas are compatible if:
+    - All fields in `other_schema` are present in `table_schema`. (other_schema <= table_schema)
+    - All required fields in `table_schema` are present in `other_schema`.
 
     Raises:
         ValueError: If the schemas are not compatible.
     """
+    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
+
     name_mapping = table_schema.name_mapping
     try:
-        task_schema = pyarrow_to_schema(
-            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
-        )
+        other_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
     except ValueError as e:
-        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        other_schema = _pyarrow_to_schema_without_ids(other_schema)
         additional_names = set(other_schema.column_names) - set(table_schema.column_names)
         raise ValueError(
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
         ) from e
 
-    if table_schema.as_struct() != task_schema.as_struct():
+    missing_table_schema_fields = {field for field in other_schema.fields if field not in table_schema.fields}
+    required_table_schema_fields = {field for field in table_schema.fields if field.required}
+    missing_required_fields = {field for field in required_table_schema_fields if field not in other_schema.fields}
+    if missing_table_schema_fields or missing_required_fields:
         from rich.console import Console
         from rich.table import Table as RichTable
 
@@ -2066,7 +2071,7 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down
 
         for lhs in table_schema.fields:
             try:
-                rhs = task_schema.find_field(lhs.field_id)
+                rhs = other_schema.find_field(lhs.field_id)
                 rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
             except ValueError:
                 rich_table.add_row("❌", str(lhs), "Missing")
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 326eeff195..3ce38555aa 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -1799,6 +1799,39 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
         _check_schema_compatible(table_schema_simple, other_schema)
 
 
+def test_schema_compatible(table_schema_simple: Schema) -> None:
+    try:
+        _check_schema_compatible(table_schema_simple, table_schema_simple.as_arrow())
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+
+
+def test_schema_projection(table_schema_simple: Schema) -> None:
+    # remove optional `baz` field from `table_schema_simple`
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+    ))
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+    # large_string type is compatible with string type
+    other_schema = pa.schema((
+        pa.field("foo", pa.large_string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+
+
 def test_schema_downcast(table_schema_simple: Schema) -> None:
     # large_string type is compatible with string type
     other_schema = pa.schema((

From 84ad3da45d2872015241f57fa33617f9e4a9da83 Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Mon, 24 Jun 2024 10:16:53 -0700
Subject: [PATCH 7/8] mark test as integration

---
 tests/integration/test_writes/test_writes.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 8692d714be..2fd5a8d4c7 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -963,6 +963,7 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
     assert len(tbl.scan().to_arrow()) == 22
 
 
+@pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
     identifier = "default.test_table_write_subset_of_schema"
@@ -975,6 +976,7 @@ def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with
     assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2
 
 
+@pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
     identifier = "default.test_table_write_out_of_order_schema"
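With patches 6 and 7 in place, `_check_schema_compatible` enforces the contract spelled out in its docstring: the incoming schema may be any subset of the table schema as long as every required table field is present, while extra columns are rejected with a hint to evolve the table first. A hedged sketch of that behavior (the schema mirrors the `table_schema_simple` fixture used by the new unit tests; `new_col` is a made-up column name):

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import _check_schema_compatible
from pyiceberg.schema import Schema
from pyiceberg.types import BooleanType, IntegerType, NestedField, StringType

table_schema = Schema(
    NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
    NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
    NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
)

# Dropping the optional "baz" column passes: other_schema <= table_schema
# and the required "bar" field is still present.
subset = pa.schema((
    pa.field("foo", pa.string(), nullable=True),
    pa.field("bar", pa.int32(), nullable=False),
))
_check_schema_compatible(table_schema, subset)  # no exception

# An extra column the table does not know about is rejected.
superset = subset.append(pa.field("new_col", pa.date32(), nullable=True))
try:
    _check_schema_compatible(table_schema, superset)
except ValueError as e:
    print(e)  # suggests updating the table schema first (union_by_name)
```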
From 2ce6db31f4560e2f53d735da9141ab9629952d7a Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Thu, 11 Jul 2024 22:03:16 -0700
Subject: [PATCH 8/8] merge main

---
 pyiceberg/io/pyarrow.py  | 25 ++++++++++++++-----------
 tests/io/test_pyarrow.py | 14 --------------
 2 files changed, 14 insertions(+), 25 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 86e0a39713..4ad9309666 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2034,7 +2034,7 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
 
 def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
     """
-    Check if the `table_schema` is compatible with `other_schema`.
+    Check if the `table_schema` is compatible with `other_schema` in terms of the Iceberg Schema representation.
 
     The schemas are compatible if:
     - All fields in `other_schema` are present in `table_schema`. (other_schema <= table_schema)
@@ -2043,22 +2043,22 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down
     Raises:
         ValueError: If the schemas are not compatible.
     """
-    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
-
     name_mapping = table_schema.name_mapping
     try:
-        other_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
+        other_schema = pyarrow_to_schema(
+            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        )
     except ValueError as e:
-        other_schema = _pyarrow_to_schema_without_ids(other_schema)
+        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
         additional_names = set(other_schema.column_names) - set(table_schema.column_names)
         raise ValueError(
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
         ) from e
 
-    missing_table_schema_fields = {field for field in other_schema.fields if field not in table_schema.fields}
-    required_table_schema_fields = {field for field in table_schema.fields if field.required}
-    missing_required_fields = {field for field in required_table_schema_fields if field not in other_schema.fields}
-    if missing_table_schema_fields or missing_required_fields:
+    fields_missing_from_table = {field for field in other_schema.fields if field not in table_schema.fields}
+    required_fields_in_table = {field for field in table_schema.fields if field.required}
+    missing_required_fields_in_other = {field for field in required_fields_in_table if field not in other_schema.fields}
+    if fields_missing_from_table or missing_required_fields_in_other:
         from rich.console import Console
         from rich.table import Table as RichTable
 
@@ -2182,17 +2182,20 @@ def _dataframe_to_data_files(
         default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
     )
 
+    # projects schema to match the pyarrow table
+    write_schema = pyarrow_to_schema(df.schema, name_mapping=table_metadata.schema().name_mapping)
+
     if table_metadata.spec().is_unpartitioned():
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
             tasks=iter([
-                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=write_schema)
                 for batches in bin_pack_arrow_table(df, target_file_size)
             ]),
         )
     else:
-        partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
+        partitions = _determine_partitions(spec=table_metadata.spec(), schema=write_schema, arrow_table=df)
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 3ce38555aa..d0045854b3 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -1832,20 +1832,6 @@ def test_schema_downcast(table_schema_simple: Schema) -> None:
         pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
 
 
-def test_schema_downcast(table_schema_simple: Schema) -> None:
-    # large_string type is compatible with string type
-    other_schema = pa.schema((
-        pa.field("foo", pa.large_string(), nullable=True),
-        pa.field("bar", pa.int32(), nullable=False),
-        pa.field("baz", pa.bool_(), nullable=True),
-    ))
-
-    try:
-        _check_schema_compatible(table_schema_simple, other_schema)
-    except Exception:
-        pytest.fail("Unexpected Exception raised when calling `_check_schema`")
-
-
 def test_partition_for_demo() -> None:
     test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
     test_schema = Schema(
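The substantive change buried in the merge commit is the projected write schema in `_dataframe_to_data_files`: instead of stamping every data file with `table_metadata.schema()`, the writer derives a schema from the incoming dataframe through the table's name mapping, so field IDs are resolved by column name and the dataframe's physical column order no longer has to match the table. Readers resolve columns by field ID, so files written in either order read back identically. A hedged sketch of the projection step (the two-field schema is illustrative):

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField, StringType

table_schema = Schema(
    NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
    NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
)

# The dataframe arrives with its columns in the opposite order.
df_schema = pa.schema((
    pa.field("bar", pa.int32(), nullable=False),
    pa.field("foo", pa.string(), nullable=True),
))

# Name mapping assigns the table's field IDs by column name while keeping
# the dataframe's column order in the resulting write schema.
write_schema = pyarrow_to_schema(df_schema, name_mapping=table_schema.name_mapping)
print([(field.field_id, field.name) for field in write_schema.fields])
# expected: [(2, 'bar'), (1, 'foo')]
```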