From 4b911057f13491f30f89f133544c063133133fa5 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+syun64@users.noreply.github.com>
Date: Sat, 13 Apr 2024 15:07:51 -0400
Subject: [PATCH] Support Time Travel in InspectTable.entries (#599)

* time travel in entries table

* undo

* Update pyiceberg/table/__init__.py

Co-authored-by: Fokko Driesprong

* adopt review feedback

* docs

---------

Co-authored-by: Fokko Driesprong
---
 mkdocs/docs/api.md                      |  12 ++
 pyiceberg/table/__init__.py             | 128 ++++++++++---------
 tests/integration/test_inspect_table.py | 156 ++++++++++++------------
 3 files changed, 162 insertions(+), 134 deletions(-)

diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 15931d02fb..9bdb6dcdaa 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -342,6 +342,18 @@ table.append(df)
 
 To explore the table metadata, tables can be inspected.
 
+
+
+!!! tip "Time Travel"
+    To inspect a table's metadata with the time travel feature, call the inspect table method with the `snapshot_id` argument.
+    Time travel is supported on all metadata tables except `snapshots` and `refs`.
+
+    ```python
+    table.inspect.entries(snapshot_id=805611270568163028)
+    ```
+
+
+
 ### Snapshots
 
 Inspect the snapshots of the table:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index ea813176fc..da4b1465be 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -3253,6 +3253,18 @@ def __init__(self, tbl: Table) -> None:
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For metadata operations PyArrow needs to be installed") from e
 
+    def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot:
+        if snapshot_id is not None:
+            if snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id):
+                return snapshot
+            else:
+                raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")
+
+        if snapshot := self.tbl.metadata.current_snapshot():
+            return snapshot
+        else:
+            raise ValueError("Cannot get a snapshot as the table does not have any.")
+
     def snapshots(self) -> "pa.Table":
         import pyarrow as pa
 
@@ -3287,7 +3299,7 @@ def snapshots(self) -> "pa.Table":
             schema=snapshots_schema,
         )
 
-    def entries(self) -> "pa.Table":
+    def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         import pyarrow as pa
 
         from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3346,64 +3358,64 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
         ])
 
         entries = []
-        if snapshot := self.tbl.metadata.current_snapshot():
-            for manifest in snapshot.manifests(self.tbl.io):
-                for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
-                    column_sizes = entry.data_file.column_sizes or {}
-                    value_counts = entry.data_file.value_counts or {}
-                    null_value_counts = entry.data_file.null_value_counts or {}
-                    nan_value_counts = entry.data_file.nan_value_counts or {}
-                    lower_bounds = entry.data_file.lower_bounds or {}
-                    upper_bounds = entry.data_file.upper_bounds or {}
-                    readable_metrics = {
-                        schema.find_column_name(field.field_id): {
-                            "column_size": column_sizes.get(field.field_id),
-                            "value_count": value_counts.get(field.field_id),
-                            "null_value_count": null_value_counts.get(field.field_id),
-                            "nan_value_count": nan_value_counts.get(field.field_id),
-                            # Makes them readable
-                            "lower_bound": from_bytes(field.field_type, lower_bound)
-                            if (lower_bound := lower_bounds.get(field.field_id))
-                            else None,
-                            "upper_bound": from_bytes(field.field_type, upper_bound)
-                            if (upper_bound := upper_bounds.get(field.field_id))
-                            else None,
-                        }
-                        for field in self.tbl.metadata.schema().fields
-                    }
-
-                    partition = entry.data_file.partition
-                    partition_record_dict = {
-                        field.name: partition[pos]
-                        for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
-                    }
-
-                    entries.append({
-                        'status': entry.status.value,
-                        'snapshot_id': entry.snapshot_id,
-                        'sequence_number': entry.data_sequence_number,
-                        'file_sequence_number': entry.file_sequence_number,
-                        'data_file': {
-                            "content": entry.data_file.content,
-                            "file_path": entry.data_file.file_path,
-                            "file_format": entry.data_file.file_format,
-                            "partition": partition_record_dict,
-                            "record_count": entry.data_file.record_count,
-                            "file_size_in_bytes": entry.data_file.file_size_in_bytes,
-                            "column_sizes": dict(entry.data_file.column_sizes),
-                            "value_counts": dict(entry.data_file.value_counts),
-                            "null_value_counts": dict(entry.data_file.null_value_counts),
-                            "nan_value_counts": entry.data_file.nan_value_counts,
-                            "lower_bounds": entry.data_file.lower_bounds,
-                            "upper_bounds": entry.data_file.upper_bounds,
-                            "key_metadata": entry.data_file.key_metadata,
-                            "split_offsets": entry.data_file.split_offsets,
-                            "equality_ids": entry.data_file.equality_ids,
-                            "sort_order_id": entry.data_file.sort_order_id,
-                            "spec_id": entry.data_file.spec_id,
-                        },
-                        'readable_metrics': readable_metrics,
-                    })
+        snapshot = self._get_snapshot(snapshot_id)
+        for manifest in snapshot.manifests(self.tbl.io):
+            for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+                column_sizes = entry.data_file.column_sizes or {}
+                value_counts = entry.data_file.value_counts or {}
+                null_value_counts = entry.data_file.null_value_counts or {}
+                nan_value_counts = entry.data_file.nan_value_counts or {}
+                lower_bounds = entry.data_file.lower_bounds or {}
+                upper_bounds = entry.data_file.upper_bounds or {}
+                readable_metrics = {
+                    schema.find_column_name(field.field_id): {
+                        "column_size": column_sizes.get(field.field_id),
+                        "value_count": value_counts.get(field.field_id),
+                        "null_value_count": null_value_counts.get(field.field_id),
+                        "nan_value_count": nan_value_counts.get(field.field_id),
+                        # Makes them readable
+                        "lower_bound": from_bytes(field.field_type, lower_bound)
+                        if (lower_bound := lower_bounds.get(field.field_id))
+                        else None,
+                        "upper_bound": from_bytes(field.field_type, upper_bound)
+                        if (upper_bound := upper_bounds.get(field.field_id))
+                        else None,
+                    }
+                    for field in self.tbl.metadata.schema().fields
+                }
+
+                partition = entry.data_file.partition
+                partition_record_dict = {
+                    field.name: partition[pos]
+                    for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+                }
+
+                entries.append({
+                    'status': entry.status.value,
+                    'snapshot_id': entry.snapshot_id,
+                    'sequence_number': entry.data_sequence_number,
+                    'file_sequence_number': entry.file_sequence_number,
+                    'data_file': {
+                        "content": entry.data_file.content,
+                        "file_path": entry.data_file.file_path,
+                        "file_format": entry.data_file.file_format,
+                        "partition": partition_record_dict,
+                        "record_count": entry.data_file.record_count,
+                        "file_size_in_bytes": entry.data_file.file_size_in_bytes,
+                        "column_sizes": dict(entry.data_file.column_sizes),
+                        "value_counts": dict(entry.data_file.value_counts),
+                        "null_value_counts": dict(entry.data_file.null_value_counts),
+                        "nan_value_counts": entry.data_file.nan_value_counts,
+                        "lower_bounds": entry.data_file.lower_bounds,
+                        "upper_bounds": entry.data_file.upper_bounds,
+                        "key_metadata": entry.data_file.key_metadata,
+                        "split_offsets": entry.data_file.split_offsets,
+                        "equality_ids": entry.data_file.equality_ids,
+                        "sort_order_id": entry.data_file.sort_order_id,
+                        "spec_id": entry.data_file.spec_id,
+                    },
+                    'readable_metrics': readable_metrics,
+                })
 
         return pa.Table.from_pylist(
             entries,
diff --git a/tests/integration/test_inspect_table.py b/tests/integration/test_inspect_table.py
index f2515caee8..7cbfc6da08 100644
--- a/tests/integration/test_inspect_table.py
+++ b/tests/integration/test_inspect_table.py
@@ -22,7 +22,7 @@
 import pyarrow as pa
 import pytest
 import pytz
-from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame, SparkSession
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
@@ -148,81 +148,85 @@ def test_inspect_entries(
     # Write some data
     tbl.append(arrow_table_with_null)
 
-    df = tbl.inspect.entries()
-
-    assert df.column_names == [
-        'status',
-        'snapshot_id',
-        'sequence_number',
-        'file_sequence_number',
-        'data_file',
-        'readable_metrics',
-    ]
-
-    # Make sure that they are filled properly
-    for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
-        for value in df[int_column]:
-            assert isinstance(value.as_py(), int)
-
-    for snapshot_id in df['snapshot_id']:
-        assert isinstance(snapshot_id.as_py(), int)
-
-    lhs = df.to_pandas()
-    rhs = spark.table(f"{identifier}.entries").toPandas()
-    for column in df.column_names:
-        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
-            if column == 'data_file':
-                right = right.asDict(recursive=True)
-                for df_column in left.keys():
-                    if df_column == 'partition':
-                        # Spark leaves out the partition if the table is unpartitioned
-                        continue
-
-                    df_lhs = left[df_column]
-                    df_rhs = right[df_column]
-                    if isinstance(df_rhs, dict):
-                        # Arrow turns dicts into lists of tuple
-                        df_lhs = dict(df_lhs)
-
-                    assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
-            elif column == 'readable_metrics':
-                right = right.asDict(recursive=True)
-
-                assert list(left.keys()) == [
-                    'bool',
-                    'string',
-                    'string_long',
-                    'int',
-                    'long',
-                    'float',
-                    'double',
-                    'timestamp',
-                    'timestamptz',
-                    'date',
-                    'binary',
-                    'fixed',
-                ]
-
-                assert left.keys() == right.keys()
-
-                for rm_column in left.keys():
-                    rm_lhs = left[rm_column]
-                    rm_rhs = right[rm_column]
-
-                    assert rm_lhs['column_size'] == rm_rhs['column_size']
-                    assert rm_lhs['value_count'] == rm_rhs['value_count']
-                    assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
-                    assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']
-
-                    if rm_column == 'timestamptz':
-                        # PySpark does not correctly set the timstamptz
-                        rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
-                        rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
-
-                    assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
-                    assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
-            else:
-                assert left == right, f"Difference in column {column}: {left} != {right}"
+    def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
+        assert df.column_names == [
+            'status',
+            'snapshot_id',
+            'sequence_number',
+            'file_sequence_number',
+            'data_file',
+            'readable_metrics',
+        ]
+
+        # Make sure that they are filled properly
+        for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
+            for value in df[int_column]:
+                assert isinstance(value.as_py(), int)
+
+        for snapshot_id in df['snapshot_id']:
+            assert isinstance(snapshot_id.as_py(), int)
+
+        lhs = df.to_pandas()
+        rhs = spark_df.toPandas()
+        for column in df.column_names:
+            for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+                if column == 'data_file':
+                    right = right.asDict(recursive=True)
+                    for df_column in left.keys():
+                        if df_column == 'partition':
+                            # Spark leaves out the partition if the table is unpartitioned
+                            continue
+
+                        df_lhs = left[df_column]
+                        df_rhs = right[df_column]
+                        if isinstance(df_rhs, dict):
+                            # Arrow turns dicts into lists of tuple
+                            df_lhs = dict(df_lhs)
+
+                        assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
+                elif column == 'readable_metrics':
+                    right = right.asDict(recursive=True)
+
+                    assert list(left.keys()) == [
+                        'bool',
+                        'string',
+                        'string_long',
+                        'int',
+                        'long',
+                        'float',
+                        'double',
+                        'timestamp',
+                        'timestamptz',
+                        'date',
+                        'binary',
+                        'fixed',
+                    ]
+
+                    assert left.keys() == right.keys()
+
+                    for rm_column in left.keys():
+                        rm_lhs = left[rm_column]
+                        rm_rhs = right[rm_column]
+
+                        assert rm_lhs['column_size'] == rm_rhs['column_size']
+                        assert rm_lhs['value_count'] == rm_rhs['value_count']
+                        assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
+                        assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']
+
+                        if rm_column == 'timestamptz':
+                            # PySpark does not correctly set the timestamptz
+                            rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
+                            rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
+
+                        assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
+                        assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
+                else:
+                    assert left == right, f"Difference in column {column}: {left} != {right}"
+
+    for snapshot in tbl.metadata.snapshots:
+        df = tbl.inspect.entries(snapshot_id=snapshot.snapshot_id)
+        spark_df = spark.sql(f"SELECT * FROM {identifier}.entries VERSION AS OF {snapshot.snapshot_id}")
+        check_pyiceberg_df_equals_spark_df(df, spark_df)
 
 
 @pytest.mark.integration
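
For reference, a minimal sketch of how the new `snapshot_id` argument introduced by this patch might be exercised from user code; the catalog name "default" and the table identifier "default.my_table" are placeholders, not part of the patch:

```python
from pyiceberg.catalog import load_catalog

# Assumes a catalog named "default" is configured and that
# "default.my_table" already has a few snapshots.
catalog = load_catalog("default")
tbl = catalog.load_table("default.my_table")

# Without an argument, entries() still reflects the current snapshot.
current_entries = tbl.inspect.entries()

# Time travel: inspect the manifest entries of every historical snapshot,
# mirroring the loop in the new integration test.
for snapshot in tbl.metadata.snapshots:
    entries_at_snapshot = tbl.inspect.entries(snapshot_id=snapshot.snapshot_id)
    print(snapshot.snapshot_id, entries_at_snapshot.num_rows)

# Per _get_snapshot above, an unknown ID raises
# ValueError("Cannot find snapshot with ID ...").
```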