Support Time Travel in InspectTable.entries (apache#599)
* time travel in entries table

* undo

* Update pyiceberg/table/__init__.py

Co-authored-by: Fokko Driesprong <[email protected]>

* adopt review feedback

* docs

---------

Co-authored-by: Fokko Driesprong <[email protected]>
sungwy and Fokko authored Apr 13, 2024
1 parent 35d4648 commit 4b91105
Showing 3 changed files with 162 additions and 134 deletions.
12 changes: 12 additions & 0 deletions mkdocs/docs/api.md
@@ -342,6 +342,18 @@ table.append(df)

To explore the table metadata, tables can be inspected.

<!-- prettier-ignore-start -->

!!! tip "Time Travel"
To inspect a table's metadata with the time travel feature, call the inspect table method with the `snapshot_id` argument.
Time travel is supported on all metadata tables except `snapshots` and `refs`.

```python
table.inspect.entries(snapshot_id=805611270568163028)
```

<!-- prettier-ignore-end -->
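
A snapshot ID can also be taken from the `snapshots` metadata table. The sketch below is illustrative; the first row of `snapshots` is typically the oldest snapshot:

```python
# List the table's snapshots and pick the ID of the first (typically the oldest) one.
snapshots = table.inspect.snapshots()
snapshot_id = snapshots["snapshot_id"][0].as_py()

# Inspect the manifest entries as of that snapshot.
table.inspect.entries(snapshot_id=snapshot_id)
```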

### Snapshots

Inspect the snapshots of the table:
128 changes: 70 additions & 58 deletions pyiceberg/table/__init__.py
@@ -3253,6 +3253,18 @@ def __init__(self, tbl: Table) -> None:
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For metadata operations PyArrow needs to be installed") from e

def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot:
if snapshot_id is not None:
if snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id):
return snapshot
else:
raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")

if snapshot := self.tbl.metadata.current_snapshot():
return snapshot
else:
raise ValueError("Cannot get a snapshot as the table does not have any.")

def snapshots(self) -> "pa.Table":
import pyarrow as pa

@@ -3287,7 +3299,7 @@ def snapshots(self) -> "pa.Table":
schema=snapshots_schema,
)

def entries(self) -> "pa.Table":
def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3346,64 +3358,64 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
])

entries = []
if snapshot := self.tbl.metadata.current_snapshot():
for manifest in snapshot.manifests(self.tbl.io):
for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
column_sizes = entry.data_file.column_sizes or {}
value_counts = entry.data_file.value_counts or {}
null_value_counts = entry.data_file.null_value_counts or {}
nan_value_counts = entry.data_file.nan_value_counts or {}
lower_bounds = entry.data_file.lower_bounds or {}
upper_bounds = entry.data_file.upper_bounds or {}
readable_metrics = {
schema.find_column_name(field.field_id): {
"column_size": column_sizes.get(field.field_id),
"value_count": value_counts.get(field.field_id),
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
# Makes them readable
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
}
for field in self.tbl.metadata.schema().fields
}

partition = entry.data_file.partition
partition_record_dict = {
field.name: partition[pos]
for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
snapshot = self._get_snapshot(snapshot_id)
for manifest in snapshot.manifests(self.tbl.io):
for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
column_sizes = entry.data_file.column_sizes or {}
value_counts = entry.data_file.value_counts or {}
null_value_counts = entry.data_file.null_value_counts or {}
nan_value_counts = entry.data_file.nan_value_counts or {}
lower_bounds = entry.data_file.lower_bounds or {}
upper_bounds = entry.data_file.upper_bounds or {}
readable_metrics = {
schema.find_column_name(field.field_id): {
"column_size": column_sizes.get(field.field_id),
"value_count": value_counts.get(field.field_id),
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
# Makes them readable
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
}

entries.append({
'status': entry.status.value,
'snapshot_id': entry.snapshot_id,
'sequence_number': entry.data_sequence_number,
'file_sequence_number': entry.file_sequence_number,
'data_file': {
"content": entry.data_file.content,
"file_path": entry.data_file.file_path,
"file_format": entry.data_file.file_format,
"partition": partition_record_dict,
"record_count": entry.data_file.record_count,
"file_size_in_bytes": entry.data_file.file_size_in_bytes,
"column_sizes": dict(entry.data_file.column_sizes),
"value_counts": dict(entry.data_file.value_counts),
"null_value_counts": dict(entry.data_file.null_value_counts),
"nan_value_counts": entry.data_file.nan_value_counts,
"lower_bounds": entry.data_file.lower_bounds,
"upper_bounds": entry.data_file.upper_bounds,
"key_metadata": entry.data_file.key_metadata,
"split_offsets": entry.data_file.split_offsets,
"equality_ids": entry.data_file.equality_ids,
"sort_order_id": entry.data_file.sort_order_id,
"spec_id": entry.data_file.spec_id,
},
'readable_metrics': readable_metrics,
})
for field in self.tbl.metadata.schema().fields
}

partition = entry.data_file.partition
partition_record_dict = {
field.name: partition[pos]
for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
}

entries.append({
'status': entry.status.value,
'snapshot_id': entry.snapshot_id,
'sequence_number': entry.data_sequence_number,
'file_sequence_number': entry.file_sequence_number,
'data_file': {
"content": entry.data_file.content,
"file_path": entry.data_file.file_path,
"file_format": entry.data_file.file_format,
"partition": partition_record_dict,
"record_count": entry.data_file.record_count,
"file_size_in_bytes": entry.data_file.file_size_in_bytes,
"column_sizes": dict(entry.data_file.column_sizes),
"value_counts": dict(entry.data_file.value_counts),
"null_value_counts": dict(entry.data_file.null_value_counts),
"nan_value_counts": entry.data_file.nan_value_counts,
"lower_bounds": entry.data_file.lower_bounds,
"upper_bounds": entry.data_file.upper_bounds,
"key_metadata": entry.data_file.key_metadata,
"split_offsets": entry.data_file.split_offsets,
"equality_ids": entry.data_file.equality_ids,
"sort_order_id": entry.data_file.sort_order_id,
"spec_id": entry.data_file.spec_id,
},
'readable_metrics': readable_metrics,
})

return pa.Table.from_pylist(
entries,
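
As an aside, the Arrow table returned by `entries` can be queried directly. A minimal sketch, reusing `table` and `snapshot_id` from the sketches above and assuming the table has an `int` column (as the integration-test table below does):

```python
# Decoded per-column metrics live in the nested `readable_metrics` struct.
entries = table.inspect.entries(snapshot_id=snapshot_id)
metrics = entries["readable_metrics"][0].as_py()  # metrics of the first manifest entry

# Lower/upper bounds are decoded from raw bytes via from_bytes, so they are plain Python values.
print(metrics["int"]["lower_bound"], metrics["int"]["upper_bound"])
```
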
156 changes: 80 additions & 76 deletions tests/integration/test_inspect_table.py
@@ -22,7 +22,7 @@
import pyarrow as pa
import pytest
import pytz
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame, SparkSession

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
@@ -148,81 +148,85 @@ def test_inspect_entries(
# Write some data
tbl.append(arrow_table_with_null)

df = tbl.inspect.entries()

assert df.column_names == [
'status',
'snapshot_id',
'sequence_number',
'file_sequence_number',
'data_file',
'readable_metrics',
]

# Make sure that they are filled properly
for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
for value in df[int_column]:
assert isinstance(value.as_py(), int)

for snapshot_id in df['snapshot_id']:
assert isinstance(snapshot_id.as_py(), int)

lhs = df.to_pandas()
rhs = spark.table(f"{identifier}.entries").toPandas()
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == 'data_file':
right = right.asDict(recursive=True)
for df_column in left.keys():
if df_column == 'partition':
# Spark leaves out the partition if the table is unpartitioned
continue

df_lhs = left[df_column]
df_rhs = right[df_column]
if isinstance(df_rhs, dict):
# Arrow turns dicts into lists of tuples
df_lhs = dict(df_lhs)

assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
elif column == 'readable_metrics':
right = right.asDict(recursive=True)

assert list(left.keys()) == [
'bool',
'string',
'string_long',
'int',
'long',
'float',
'double',
'timestamp',
'timestamptz',
'date',
'binary',
'fixed',
]

assert left.keys() == right.keys()

for rm_column in left.keys():
rm_lhs = left[rm_column]
rm_rhs = right[rm_column]

assert rm_lhs['column_size'] == rm_rhs['column_size']
assert rm_lhs['value_count'] == rm_rhs['value_count']
assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']

if rm_column == 'timestamptz':
# PySpark does not correctly set the timezone on timestamptz columns
rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)

assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
else:
assert left == right, f"Difference in column {column}: {left} != {right}"
def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
assert df.column_names == [
'status',
'snapshot_id',
'sequence_number',
'file_sequence_number',
'data_file',
'readable_metrics',
]

# Make sure that they are filled properly
for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
for value in df[int_column]:
assert isinstance(value.as_py(), int)

for snapshot_id in df['snapshot_id']:
assert isinstance(snapshot_id.as_py(), int)

lhs = df.to_pandas()
rhs = spark_df.toPandas()
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == 'data_file':
right = right.asDict(recursive=True)
for df_column in left.keys():
if df_column == 'partition':
# Spark leaves out the partition if the table is unpartitioned
continue

df_lhs = left[df_column]
df_rhs = right[df_column]
if isinstance(df_rhs, dict):
# Arrow turns dicts into lists of tuples
df_lhs = dict(df_lhs)

assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
elif column == 'readable_metrics':
right = right.asDict(recursive=True)

assert list(left.keys()) == [
'bool',
'string',
'string_long',
'int',
'long',
'float',
'double',
'timestamp',
'timestamptz',
'date',
'binary',
'fixed',
]

assert left.keys() == right.keys()

for rm_column in left.keys():
rm_lhs = left[rm_column]
rm_rhs = right[rm_column]

assert rm_lhs['column_size'] == rm_rhs['column_size']
assert rm_lhs['value_count'] == rm_rhs['value_count']
assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']

if rm_column == 'timestamptz':
# PySpark does not correctly set the timezone on timestamptz columns
rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)

assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
else:
assert left == right, f"Difference in column {column}: {left} != {right}"

for snapshot in tbl.metadata.snapshots:
df = tbl.inspect.entries(snapshot_id=snapshot.snapshot_id)
spark_df = spark.sql(f"SELECT * FROM {identifier}.entries VERSION AS OF {snapshot.snapshot_id}")
check_pyiceberg_df_equals_spark_df(df, spark_df)


@pytest.mark.integration
