From cbb8cecee9a226a5b6568e36316f13d24e9acc3c Mon Sep 17 00:00:00 2001
From: HonahX
Date: Mon, 3 Jun 2024 03:32:21 +0000
Subject: [PATCH] add merge_append

---
 pyiceberg/table/__init__.py                  | 49 ++++++++++++++++++++
 tests/integration/test_writes/test_writes.py | 20 ++++----
 2 files changed, 59 insertions(+), 10 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 2c5bb977dd..851e589a87 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -428,6 +428,44 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
             for data_file in data_files:
                 update_snapshot.append_data_file(data_file)
 
+    def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+        """
+        Shorthand API for appending a PyArrow table to a table transaction, merging small manifests on commit.
+
+        Args:
+            df: The Arrow dataframe that will be appended to the table
+            snapshot_properties: Custom properties to be added to the snapshot summary
+        """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
+        if unsupported_partitions := [
+            field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
+        ]:
+            raise ValueError(
+                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
+            )
+
+        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        table_arrow_schema = self._table.schema().as_arrow()
+        if table_arrow_schema != df.schema:
+            df = df.cast(table_arrow_schema)
+
+        with self.update_snapshot(snapshot_properties=snapshot_properties).merge_append() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(
+                    table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+                )
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)
+
     def overwrite(
         self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
     ) -> None:
@@ -1352,6 +1390,17 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         with self.transaction() as tx:
             tx.append(df=df, snapshot_properties=snapshot_properties)
 
+    def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+        """
+        Shorthand API for appending a PyArrow table to the table, merging small manifests on commit.
+
+        Args:
+            df: The Arrow dataframe that will be appended to the table
+            snapshot_properties: Custom properties to be added to the snapshot summary
+        """
+        with self.transaction() as tx:
+            tx.merge_append(df=df, snapshot_properties=snapshot_properties)
+
     def overwrite(
         self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
     ) -> None:
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 4886aa2ee6..a5094a2657 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -876,7 +876,7 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_merge_manifest_min_count_to_merge(
-    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+    session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
 ) -> None:
     tbl_a = _create_table(
         session_catalog,
@@ -898,19 +898,19 @@ def test_merge_manifest_min_count_to_merge(
     )
 
     # tbl_a should merge all manifests into 1
-    tbl_a.append(arrow_table_with_null)
-    tbl_a.append(arrow_table_with_null)
-    tbl_a.append(arrow_table_with_null)
+    tbl_a.merge_append(arrow_table_with_null)
+    tbl_a.merge_append(arrow_table_with_null)
+    tbl_a.merge_append(arrow_table_with_null)
 
     # tbl_b should not merge any manifests because the target size is too small
-    tbl_b.append(arrow_table_with_null)
-    tbl_b.append(arrow_table_with_null)
-    tbl_b.append(arrow_table_with_null)
+    tbl_b.merge_append(arrow_table_with_null)
+    tbl_b.merge_append(arrow_table_with_null)
+    tbl_b.merge_append(arrow_table_with_null)
 
     # tbl_c should not merge any manifests because merging is disabled
-    tbl_c.append(arrow_table_with_null)
-    tbl_c.append(arrow_table_with_null)
-    tbl_c.append(arrow_table_with_null)
+    tbl_c.merge_append(arrow_table_with_null)
+    tbl_c.merge_append(arrow_table_with_null)
+    tbl_c.merge_append(arrow_table_with_null)
 
     assert len(tbl_a.current_snapshot().manifests(tbl_a.io)) == 1  # type: ignore
     assert len(tbl_b.current_snapshot().manifests(tbl_b.io)) == 3  # type: ignore
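
--
A minimal usage sketch for reviewers, assuming a configured catalog. The
catalog/table names and data below are hypothetical; load_catalog,
create_table, current_snapshot, and the commit.manifest* table properties
are existing PyIceberg/Iceberg APIs (the same properties the integration
test above exercises), and merge_append is the method added by this patch.

    import pyarrow as pa

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")  # assumes a catalog configured elsewhere
    df = pa.table({"number": pa.array([1, 2, 3], type=pa.int32())})

    # Manifest merging is driven by table properties; these mirror tbl_a in the test.
    tbl = catalog.create_table(
        "default.my_table",  # hypothetical identifier
        schema=df.schema,
        properties={
            "commit.manifest-merge.enabled": "true",
            "commit.manifest.min-count-to-merge": "2",
        },
    )

    # append() fast-appends one new manifest per commit; merge_append() may also
    # rewrite small manifests into one once min-count-to-merge is reached.
    tbl.merge_append(df)
    tbl.merge_append(df)

    # Expectation under these properties: the two small manifests were merged.
    assert len(tbl.current_snapshot().manifests(tbl.io)) == 1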