Make the snapshot creation part of the Transaction (apache#446)
* Make the snapshot creation part of the `Transaction`

This is also how it is done in Java, and I really like it
since it allows you to easily queue up updates in a transaction.
For example, an update to the schema.

* Extend the API
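
For illustration only (not part of the original commit message): a minimal sketch of how the new API lets a snapshot update be staged inside an open transaction alongside other changes, with everything committed together at the end. The table `tbl`, the list `data_files`, and the `add_column` arguments are hypothetical; the `transaction()`, `update_snapshot()`, `fast_append()`, and `update_schema()` calls are the ones shown in the diff below.

    from pyiceberg.types import StringType

    with tbl.transaction() as txn:
        # Stage new data files; they are committed together with the rest
        # of the transaction when the outer with-block exits.
        with txn.update_snapshot().fast_append() as append_files:
            for data_file in data_files:
                append_files.append_data_file(data_file)
        # Queue another update in the same transaction, e.g. a schema change.
        txn.update_schema().add_column("new_column", StringType()).commit()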
Fokko authored Feb 20, 2024
1 parent c23c24d commit 015226d
Showing 3 changed files with 213 additions and 91 deletions.
11 changes: 6 additions & 5 deletions pyiceberg/io/pyarrow.py
@@ -1714,7 +1714,7 @@ def fill_parquet_file_metadata(
data_file.split_offsets = split_offsets


def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[Schema] = None) -> Iterator[DataFile]:
task = next(tasks)

try:
@@ -1727,7 +1727,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)

file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
file_schema = schema_to_pyarrow(table.schema())
file_schema = file_schema or table.schema()
arrow_file_schema = schema_to_pyarrow(file_schema)

fo = table.io.new_output(file_path)
row_group_size = PropertyUtil.property_as_int(
@@ -1736,7 +1737,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=file_schema, **parquet_writer_kwargs) as writer:
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
writer.write_table(task.df, row_group_size=row_group_size)

data_file = DataFile(
@@ -1758,8 +1759,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
fill_parquet_file_metadata(
data_file=data_file,
parquet_metadata=writer.writer.metadata,
stats_columns=compute_statistics_plan(table.schema(), table.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
stats_columns=compute_statistics_plan(file_schema, table.properties),
parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
)
return iter([data_file])

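A sketch of what the new `file_schema` parameter enables (hypothetical caller; `table`, `txn`, and `arrow_table` are assumed, and `txn.schema()` refers to the Transaction.schema() helper added in the table/__init__.py diff below): a writer inside an open transaction can be pointed at a pending schema instead of the table's current one, which also keeps the statistics plan and field-id mapping consistent with the file actually written.

    import uuid

    from pyiceberg.io.pyarrow import write_file
    from pyiceberg.table import WriteTask

    # file_schema falls back to table.schema() when omitted; here we pass
    # the schema staged in the open transaction instead.
    tasks = iter([WriteTask(uuid.uuid4(), 0, arrow_table)])
    data_files = list(write_file(table, tasks, file_schema=txn.schema()))
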
251 changes: 165 additions & 86 deletions pyiceberg/table/__init__.py
@@ -329,6 +329,14 @@ def update_schema(self) -> UpdateSchema:
"""
return UpdateSchema(self._table, self)

def update_snapshot(self) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
Returns:
A new UpdateSnapshot
"""
return UpdateSnapshot(self._table, self)

def remove_properties(self, *removals: str) -> Transaction:
"""Remove properties.
@@ -351,6 +359,12 @@ def update_location(self, location: str) -> Transaction:
"""
raise NotImplementedError("Not yet implemented")

def schema(self) -> Schema:
try:
return next(update for update in self._updates if isinstance(update, AddSchemaUpdate)).schema_
except StopIteration:
return self._table.schema()

def commit_transaction(self) -> Table:
"""Commit the changes to the catalog.
@@ -965,8 +979,21 @@ def history(self) -> List[SnapshotLogEntry]:
return self.metadata.snapshot_log

def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
"""Create a new UpdateSchema to alter the columns of this table.
Returns:
A new UpdateSchema.
"""
return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)

def update_snapshot(self) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
Returns:
A new UpdateSnapshot
"""
return UpdateSnapshot(self)

def name_mapping(self) -> NameMapping:
"""Return the table's field-id NameMapping."""
if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
@@ -976,7 +1003,7 @@ def name_mapping(self) -> NameMapping:

def append(self, df: pa.Table) -> None:
"""
Append data to the table.
Shorthand API for appending a PyArrow table to the table.

Args:
    df: The Arrow dataframe that will be appended to the table
@@ -992,19 +1019,16 @@ def append(self, df: pa.Table) -> None:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)

# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(self, df=df)
for data_file in data_files:
merge.append_data_file(data_file)

merge.commit()
with self.update_snapshot().fast_append() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(self, df=df)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
"""
Overwrite all the data in the table.
Shorthand for overwriting the table with a PyArrow table.

Args:
    df: The Arrow dataframe that will be used to overwrite the table
@@ -1025,18 +1049,12 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

merge = _MergingSnapshotProducer(
operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
table=self,
)

# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(self, df=df)
for data_file in data_files:
merge.append_data_file(data_file)

merge.commit()
with self.update_snapshot().overwrite() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(self, df=df)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def refs(self) -> Dict[str, SnapshotRef]:
"""Return the snapshot references in the table."""
@@ -2331,7 +2349,12 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int, commit_uuid: uuid.UUID) -> str:
return f'{location}/metadata/snap-{snapshot_id}-{attempt}-{commit_uuid}.avro'


def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
def _dataframe_to_data_files(table: Table, df: pa.Table, file_schema: Optional[Schema] = None) -> Iterable[DataFile]:
"""Convert a PyArrow table into a DataFile.
Returns:
An iterable that supplies datafiles that represent the table.
"""
from pyiceberg.io.pyarrow import write_file

if len(table.spec().fields) > 0:
@@ -2342,7 +2365,7 @@ def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:

# This is an iter, so we don't have to materialize everything every time
# This will be more relevant when we start doing partitioned writes
yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]), file_schema=file_schema)


class _MergingSnapshotProducer:
@@ -2352,55 +2375,35 @@ class _MergingSnapshotProducer:
_parent_snapshot_id: Optional[int]
_added_data_files: List[DataFile]
_commit_uuid: uuid.UUID
_transaction: Optional[Transaction]

def __init__(self, operation: Operation, table: Table) -> None:
def __init__(self, operation: Operation, table: Table, transaction: Optional[Transaction] = None) -> None:
self._operation = operation
self._table = table
self._snapshot_id = table.new_snapshot_id()
# Since we only support the main branch for now
self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None
self._added_data_files = []
self._commit_uuid = uuid.uuid4()
self._transaction = transaction

def __enter__(self) -> _MergingSnapshotProducer:
"""Start a transaction to update the table."""
return self

def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
"""Close and commit the transaction."""
self.commit()

def append_data_file(self, data_file: DataFile) -> _MergingSnapshotProducer:
self._added_data_files.append(data_file)
return self

def _deleted_entries(self) -> List[ManifestEntry]:
"""To determine if we need to record any deleted entries.
With partial overwrites we have to use the predicate to evaluate
which entries are affected.
"""
if self._operation == Operation.OVERWRITE:
if self._parent_snapshot_id is not None:
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
if previous_snapshot is None:
# This should never happen since you cannot overwrite an empty table
raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")

executor = ExecutorFactory.get_or_create()

def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
return [
ManifestEntry(
status=ManifestEntryStatus.DELETED,
snapshot_id=entry.snapshot_id,
data_sequence_number=entry.data_sequence_number,
file_sequence_number=entry.file_sequence_number,
data_file=entry.data_file,
)
for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
if entry.data_file.content == DataFileContent.DATA
]
@abstractmethod
def _deleted_entries(self) -> List[ManifestEntry]: ...

list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
return list(chain(*list_of_entries))
return []
elif self._operation == Operation.APPEND:
return []
else:
raise ValueError(f"Not implemented for: {self._operation}")
@abstractmethod
def _existing_manifests(self) -> List[ManifestFile]: ...

def _manifests(self) -> List[ManifestFile]:
def _write_added_manifest() -> List[ManifestFile]:
Expand Down Expand Up @@ -2430,7 +2433,7 @@ def _write_added_manifest() -> List[ManifestFile]:
def _write_delete_manifest() -> List[ManifestFile]:
# Check if we need to mark the files as deleted
deleted_entries = self._deleted_entries()
if deleted_entries:
if len(deleted_entries) > 0:
output_file_location = _new_manifest_path(location=self._table.location(), num=1, commit_uuid=self._commit_uuid)
with write_manifest(
format_version=self._table.format_version,
@@ -2445,32 +2448,11 @@ def _write_delete_manifest() -> List[ManifestFile]:
else:
return []

def _fetch_existing_manifests() -> List[ManifestFile]:
existing_manifests = []

# Add existing manifests
if self._operation == Operation.APPEND and self._parent_snapshot_id is not None:
# In case we want to append, just add the existing manifests
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)

if previous_snapshot is None:
raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")

for manifest in previous_snapshot.manifests(io=self._table.io):
if (
manifest.has_added_files()
or manifest.has_existing_files()
or manifest.added_snapshot_id == self._snapshot_id
):
existing_manifests.append(manifest)

return existing_manifests

executor = ExecutorFactory.get_or_create()

added_manifests = executor.submit(_write_added_manifest)
delete_manifests = executor.submit(_write_delete_manifest)
existing_manifests = executor.submit(_fetch_existing_manifests)
existing_manifests = executor.submit(self._existing_manifests)

return added_manifests.result() + delete_manifests.result() + existing_manifests.result()

@@ -2515,10 +2497,107 @@ def commit(self) -> Snapshot:
schema_id=self._table.schema().schema_id,
)

with self._table.transaction() as tx:
tx.add_snapshot(snapshot=snapshot)
tx.set_ref_snapshot(
if self._transaction is not None:
self._transaction.add_snapshot(snapshot=snapshot)
self._transaction.set_ref_snapshot(
snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
)
else:
with self._table.transaction() as tx:
tx.add_snapshot(snapshot=snapshot)
tx.set_ref_snapshot(
snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
)

return snapshot


class FastAppendFiles(_MergingSnapshotProducer):
def _existing_manifests(self) -> List[ManifestFile]:
"""To determine if there are any existing manifest files.
A fast append will add another ManifestFile to the ManifestList.
All the existing manifest files are considered existing.
"""
existing_manifests = []

if self._parent_snapshot_id is not None:
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)

if previous_snapshot is None:
raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")

for manifest in previous_snapshot.manifests(io=self._table.io):
if manifest.has_added_files() or manifest.has_existing_files() or manifest.added_snapshot_id == self._snapshot_id:
existing_manifests.append(manifest)

return existing_manifests

def _deleted_entries(self) -> List[ManifestEntry]:
"""To determine if we need to record any deleted manifest entries.
In case of an append, nothing is deleted.
"""
return []


class OverwriteFiles(_MergingSnapshotProducer):
def _existing_manifests(self) -> List[ManifestFile]:
"""To determine if there are any existing manifest files.
In the of a full overwrite, all the previous manifests are
considered deleted.
"""
return []

def _deleted_entries(self) -> List[ManifestEntry]:
"""To determine if we need to record any deleted entries.
With a full overwrite all the entries are considered deleted.
With partial overwrites we have to use the predicate to evaluate
which entries are affected.
"""
if self._parent_snapshot_id is not None:
previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
if previous_snapshot is None:
# This should never happen since you cannot overwrite an empty table
raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")

executor = ExecutorFactory.get_or_create()

def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
return [
ManifestEntry(
status=ManifestEntryStatus.DELETED,
snapshot_id=entry.snapshot_id,
data_sequence_number=entry.data_sequence_number,
file_sequence_number=entry.file_sequence_number,
data_file=entry.data_file,
)
for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
if entry.data_file.content == DataFileContent.DATA
]

list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
return list(chain(*list_of_entries))
else:
return []


class UpdateSnapshot:
_table: Table
_transaction: Optional[Transaction]

def __init__(self, table: Table, transaction: Optional[Transaction] = None) -> None:
self._table = table
self._transaction = transaction

def fast_append(self) -> FastAppendFiles:
return FastAppendFiles(table=self._table, operation=Operation.APPEND, transaction=self._transaction)

def overwrite(self) -> OverwriteFiles:
return OverwriteFiles(
table=self._table,
operation=Operation.OVERWRITE if self._table.current_snapshot() is not None else Operation.APPEND,
transaction=self._transaction,
)
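
For completeness, a sketch of standalone use outside an explicit transaction (`tbl` and `arrow_table` are placeholders): with no Transaction attached, commit() opens and commits its own transaction on exit, which is exactly what the new Table.append and Table.overwrite shorthands above do.

    with tbl.update_snapshot().fast_append() as append_files:
        for data_file in _dataframe_to_data_files(tbl, df=arrow_table):
            append_files.append_data_file(data_file)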
