Skip to content

Commit

Permalink
add merge_append
Browse files Browse the repository at this point in the history
  • Loading branch information
HonahX committed Jun 3, 2024
1 parent f0fc260 commit cbb8cec
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
49 changes: 49 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,44 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for appending a PyArrow table to a table transaction.
Args:
df: The Arrow dataframe that will be appended to overwrite the table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if unsupported_partitions := [
field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
]:
raise ValueError(
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.update_snapshot(snapshot_properties=snapshot_properties).merge_append() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df,
io=self._table.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)
def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
) -> None:
Expand Down Expand Up @@ -1352,6 +1390,17 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
with self.transaction() as tx:
tx.append(df=df, snapshot_properties=snapshot_properties)

def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for appending a PyArrow table to the table.
Args:
df: The Arrow dataframe that will be appended to overwrite the table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
with self.transaction() as tx:
tx.merge_append(df=df, snapshot_properties=snapshot_properties)

def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
) -> None:
Expand Down
20 changes: 10 additions & 10 deletions tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_merge_manifest_min_count_to_merge(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
tbl_a = _create_table(
session_catalog,
Expand All @@ -898,19 +898,19 @@ def test_merge_manifest_min_count_to_merge(
)

# tbl_a should merge all manifests into 1
tbl_a.append(arrow_table_with_null)
tbl_a.append(arrow_table_with_null)
tbl_a.append(arrow_table_with_null)
tbl_a.merge_append(arrow_table_with_null)
tbl_a.merge_append(arrow_table_with_null)
tbl_a.merge_append(arrow_table_with_null)

# tbl_b should not merge any manifests because the target size is too small
tbl_b.append(arrow_table_with_null)
tbl_b.append(arrow_table_with_null)
tbl_b.append(arrow_table_with_null)
tbl_b.merge_append(arrow_table_with_null)
tbl_b.merge_append(arrow_table_with_null)
tbl_b.merge_append(arrow_table_with_null)

# tbl_c should not merge any manifests because merging is disabled
tbl_c.append(arrow_table_with_null)
tbl_c.append(arrow_table_with_null)
tbl_c.append(arrow_table_with_null)
tbl_c.merge_append(arrow_table_with_null)
tbl_c.merge_append(arrow_table_with_null)
tbl_c.merge_append(arrow_table_with_null)

assert len(tbl_a.current_snapshot().manifests(tbl_a.io)) == 1 # type: ignore
assert len(tbl_b.current_snapshot().manifests(tbl_b.io)) == 3 # type: ignore
Expand Down

0 comments on commit cbb8cec

Please sign in to comment.