From d69407ceed545e72c24643e2abcd2d6ec335c26c Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 4 Apr 2024 03:18:05 -0400 Subject: [PATCH] Move writes to the transaction class (#571) --- pyiceberg/table/__init__.py | 160 +++++++++++++++++++------------ tests/integration/test_writes.py | 28 ++++++ 2 files changed, 127 insertions(+), 61 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 4e968eb616..0f113f3b5e 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -356,6 +356,100 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> U """ return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties) + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + """ + Shorthand API for appending a PyArrow table to a table transaction. + + Args: + df: The Arrow dataframe that will be appended to overwrite the table + snapshot_properties: Custom properties to be added to the snapshot summary + """ + try: + import pyarrow as pa + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + if not isinstance(df, pa.Table): + raise ValueError(f"Expected PyArrow table, got: {df}") + + if len(self._table.spec().fields) > 0: + raise ValueError("Cannot write to partitioned tables") + + _check_schema_compatible(self._table.schema(), other_schema=df.schema) + # cast if the two schemas are compatible but not equal + table_arrow_schema = self._table.schema().as_arrow() + if table_arrow_schema != df.schema: + df = df.cast(table_arrow_schema) + + with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: + # skip writing data files if the dataframe is empty + if df.shape[0] > 0: + data_files = _dataframe_to_data_files( + table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io + ) + for data_file in data_files: + update_snapshot.append_data_file(data_file) + + def overwrite( + self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT + ) -> None: + """ + Shorthand for adding a table overwrite with a PyArrow table to the transaction. + + Args: + df: The Arrow dataframe that will be used to overwrite the table + overwrite_filter: ALWAYS_TRUE when you overwrite all the data, + or a boolean expression in case of a partial overwrite + snapshot_properties: Custom properties to be added to the snapshot summary + """ + try: + import pyarrow as pa + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + if not isinstance(df, pa.Table): + raise ValueError(f"Expected PyArrow table, got: {df}") + + if overwrite_filter != AlwaysTrue(): + raise NotImplementedError("Cannot overwrite a subset of a table") + + if len(self._table.spec().fields) > 0: + raise ValueError("Cannot write to partitioned tables") + + _check_schema_compatible(self._table.schema(), other_schema=df.schema) + # cast if the two schemas are compatible but not equal + table_arrow_schema = self._table.schema().as_arrow() + if table_arrow_schema != df.schema: + df = df.cast(table_arrow_schema) + + with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot: + # skip writing data files if the dataframe is empty + if df.shape[0] > 0: + data_files = _dataframe_to_data_files( + table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io + ) + for data_file in data_files: + update_snapshot.append_data_file(data_file) + + def add_files(self, file_paths: List[str]) -> None: + """ + Shorthand API for adding files as data files to the table transaction. + + Args: + file_paths: The list of full file paths to be added as data files to the table + + Raises: + FileNotFoundError: If the file does not exist. + """ + if self._table.name_mapping() is None: + self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()}) + with self.update_snapshot().fast_append() as update_snapshot: + data_files = _parquet_files_to_data_files( + table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io + ) + for data_file in data_files: + update_snapshot.append_data_file(data_file) + def update_spec(self) -> UpdateSpec: """Create a new UpdateSpec to update the partitioning of the table. @@ -1219,32 +1313,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) df: The Arrow dataframe that will be appended to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary """ - try: - import pyarrow as pa - except ModuleNotFoundError as e: - raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - - if not isinstance(df, pa.Table): - raise ValueError(f"Expected PyArrow table, got: {df}") - - if len(self.spec().fields) > 0: - raise ValueError("Cannot write to partitioned tables") - - _check_schema_compatible(self.schema(), other_schema=df.schema) - # cast if the two schemas are compatible but not equal - table_arrow_schema = self.schema().as_arrow() - if table_arrow_schema != df.schema: - df = df.cast(table_arrow_schema) - - with self.transaction() as txn: - with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: - # skip writing data files if the dataframe is empty - if df.shape[0] > 0: - data_files = _dataframe_to_data_files( - table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io - ) - for data_file in data_files: - update_snapshot.append_data_file(data_file) + with self.transaction() as tx: + tx.append(df=df, snapshot_properties=snapshot_properties) def overwrite( self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT @@ -1258,35 +1328,8 @@ def overwrite( or a boolean expression in case of a partial overwrite snapshot_properties: Custom properties to be added to the snapshot summary """ - try: - import pyarrow as pa - except ModuleNotFoundError as e: - raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - - if not isinstance(df, pa.Table): - raise ValueError(f"Expected PyArrow table, got: {df}") - - if overwrite_filter != AlwaysTrue(): - raise NotImplementedError("Cannot overwrite a subset of a table") - - if len(self.spec().fields) > 0: - raise ValueError("Cannot write to partitioned tables") - - _check_schema_compatible(self.schema(), other_schema=df.schema) - # cast if the two schemas are compatible but not equal - table_arrow_schema = self.schema().as_arrow() - if table_arrow_schema != df.schema: - df = df.cast(table_arrow_schema) - - with self.transaction() as txn: - with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot: - # skip writing data files if the dataframe is empty - if df.shape[0] > 0: - data_files = _dataframe_to_data_files( - table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io - ) - for data_file in data_files: - update_snapshot.append_data_file(data_file) + with self.transaction() as tx: + tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties) def add_files(self, file_paths: List[str]) -> None: """ @@ -1299,12 +1342,7 @@ def add_files(self, file_paths: List[str]) -> None: FileNotFoundError: If the file does not exist. """ with self.transaction() as tx: - if self.name_mapping() is None: - tx.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self.schema().name_mapping.model_dump_json()}) - with tx.update_snapshot().fast_append() as update_snapshot: - data_files = _parquet_files_to_data_files(table_metadata=self.metadata, file_paths=file_paths, io=self.io) - for data_file in data_files: - update_snapshot.append_data_file(data_file) + tx.add_files(file_paths=file_paths) def update_spec(self, case_sensitive: bool = True) -> UpdateSpec: return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive) diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py index e8ad6b08fa..7756702368 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes.py @@ -832,3 +832,31 @@ def test_inspect_snapshots( continue assert left == right, f"Difference in column {column}: {left} != {right}" + + +@pytest.mark.integration +def test_write_within_transaction(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.write_in_open_transaction" + tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) + + def get_metadata_entries_count(identifier: str) -> int: + return spark.sql( + f""" + SELECT * + FROM {identifier}.metadata_log_entries + """ + ).count() + + # one metadata entry from table creation + assert get_metadata_entries_count(identifier) == 1 + + # one more metadata entry from transaction + with tbl.transaction() as tx: + tx.set_properties({"test": "1"}) + tx.append(arrow_table_with_null) + assert get_metadata_entries_count(identifier) == 2 + + # two more metadata entries added from two separate transactions + tbl.transaction().set_properties({"test": "2"}).commit_transaction() + tbl.append(arrow_table_with_null) + assert get_metadata_entries_count(identifier) == 4