From 8f7927b840594adf44f74adaaea105c4cb241a42 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Thu, 18 Jan 2024 11:19:17 +0100
Subject: [PATCH] Write support (#41)

---
 mkdocs/docs/api.md                       |  98 +++++
 poetry.lock                              |  36 +-
 pyiceberg/io/pyarrow.py                  |  82 +++-
 pyiceberg/manifest.py                    | 100 +++--
 pyiceberg/table/__init__.py              | 335 ++++++++++++++-
 pyiceberg/table/snapshots.py             |  36 +-
 pyproject.toml                           |   1 +
 .../{test_catalogs.py => test_reads.py}  |   0
 tests/integration/test_writes.py         | 387 ++++++++++++++++++
 tests/io/test_pyarrow_stats.py           |  67 +--
 tests/table/test_init.py                 |   6 +-
 tests/table/test_snapshots.py            |  14 +-
 tests/utils/test_manifest.py             |   9 +-
 13 files changed, 1035 insertions(+), 136 deletions(-)
 rename tests/integration/{test_catalogs.py => test_reads.py} (100%)
 create mode 100644 tests/integration/test_writes.py

diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 517e52f185..9d97d4f676 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -175,6 +175,104 @@ static_table = StaticTable.from_metadata(
 
 The static-table is considered read-only.
 
+## Write support
+
+With PyIceberg 0.6.0, write support is added through Arrow. Let's consider an Arrow Table:
+
+```python
+import pyarrow as pa
+
+df = pa.Table.from_pylist(
+    [
+        {"city": "Amsterdam", "lat": 52.371807, "long": 4.896029},
+        {"city": "San Francisco", "lat": 37.773972, "long": -122.431297},
+        {"city": "Drachten", "lat": 53.11254, "long": 6.0989},
+        {"city": "Paris", "lat": 48.864716, "long": 2.349014},
+    ],
+)
+```
+
+Next, create a table based on the schema:
+
+```python
+from pyiceberg.catalog import load_catalog
+
+catalog = load_catalog("default")
+
+from pyiceberg.schema import Schema
+from pyiceberg.types import NestedField, StringType, DoubleType
+
+schema = Schema(
+    NestedField(1, "city", StringType(), required=False),
+    NestedField(2, "lat", DoubleType(), required=False),
+    NestedField(3, "long", DoubleType(), required=False),
+)
+
+tbl = catalog.create_table("default.cities", schema=schema)
+```
+
+Now write the data to the table:
+
+!!! note inline end "Fast append"
+    PyIceberg defaults to the [fast append](https://iceberg.apache.org/spec/#snapshots) to minimize the amount of data written. This enables quick writes, reducing the possibility of conflicts. The downside of the fast append is that it creates more metadata than a normal commit. [Compaction is planned](https://github.com/apache/iceberg-python/issues/270) and will automatically rewrite all the metadata when a threshold is hit, to maintain performant reads.
+
+```python
+tbl.append(df)
+
+# or
+
+tbl.overwrite(df)
+```
+
+The data is written to the table, and when the table is read using `tbl.scan().to_arrow()`:
+
+```
+pyarrow.Table
+city: string
+lat: double
+long: double
+----
+city: [["Amsterdam","San Francisco","Drachten","Paris"]]
+lat: [[52.371807,37.773972,53.11254,48.864716]]
+long: [[4.896029,-122.431297,6.0989,2.349014]]
+```
+
+Since there is no data yet, you can use either `append(df)` or `overwrite(df)`. If we want to add more data, we can use `.append()` again:
+
+```python
+df = pa.Table.from_pylist(
+    [{"city": "Groningen", "lat": 53.21917, "long": 6.56667}],
+)
+
+tbl.append(df)
+```
+
+When reading the table with `tbl.scan().to_arrow()`, you can see that `Groningen` is now also part of the table:
+
+```
+pyarrow.Table
+city: string
+lat: double
+long: double
+----
+city: [["Amsterdam","San Francisco","Drachten","Paris"],["Groningen"]]
+lat: [[52.371807,37.773972,53.11254,48.864716],[53.21917]]
+long: [[4.896029,-122.431297,6.0989,2.349014],[6.56667]]
+```
+
+The nested lists indicate the different Arrow buffers: the first write results in one buffer, and the second append in a separate one. This is expected, since the scan reads two Parquet files.
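+
+If a single contiguous buffer per column is preferred, the chunks can be merged on the Arrow side. A minimal sketch using plain PyArrow (`Table.combine_chunks`), not part of the PyIceberg API, assuming the `tbl` from the examples above:
+
+```python
+# Read the table and merge the per-file chunks into one buffer per column
+combined = tbl.scan().to_arrow().combine_chunks()
+
+# Every column is now backed by a single Arrow chunk
+assert combined["city"].num_chunks == 1
+```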
+
+!!! example "Under development"
+    Writing with PyIceberg is still under development. Support for [partial overwrites](https://github.com/apache/iceberg-python/issues/268) and writing to [partitioned tables](https://github.com/apache/iceberg-python/issues/208) is planned and being worked on.
+
 ## Schema evolution
 
 PyIceberg supports full schema evolution through the Python API. It takes care of setting the field-IDs and makes sure that only non-breaking changes are done (can be overriden).
diff --git a/poetry.lock b/poetry.lock
index 3a22111bd7..8fc927ce95 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2566,7 +2566,6 @@ files = [
     {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
     {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@@ -2575,8 +2574,6 @@ files = [
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
     {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
-    {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
     {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
     {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -2628,6 +2625,17 @@ files = [ [package.extras] dev = ["black (==22.6.0)", "flake8", "mypy", "pytest"] +[[package]] +name = "py4j" +version = "0.10.9.7" +description = "Enables Python programs to dynamically access arbitrary Java objects" +optional = false +python-versions = "*" +files = [ + {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"}, + {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"}, +] + [[package]] name = "pyarrow" version = "14.0.2" @@ -2910,6 +2918,26 @@ files = [ [package.dependencies] tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +[[package]] +name = "pyspark" +version = "3.4.2" +description = "Apache Spark Python API" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyspark-3.4.2.tar.gz", hash = "sha256:088db1b8ff33a748b802f1710ff6f6dcef0e0f2cca7d69bbbe55b187a0d55c3f"}, +] + +[package.dependencies] +py4j = "0.10.9.7" + +[package.extras] +connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.48.1)", "grpcio-status (>=1.48.1)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] +ml = ["numpy (>=1.15)"] +mllib = ["numpy (>=1.15)"] +pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] +sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] + [[package]] name = "pytest" version = "7.4.4" @@ -4189,4 +4217,4 @@ zstandard = ["zstandard"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "54d7c0ea7f06959ce2fb8c4c36b595881e56c59dc1e97bbc4a859bcc25dac542" +content-hash = "744b1c2e8e96c8626732849a1349f614ba7fdd01491a3488d850979b8192787c" diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index e116cd0a38..b4988c677a 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -105,7 +105,11 @@ OutputFile, OutputStream, ) -from pyiceberg.manifest import DataFile, FileFormat +from pyiceberg.manifest import ( + DataFile, + DataFileContent, + FileFormat, +) from pyiceberg.schema import ( PartnerAccessor, PreOrderSchemaVisitor, @@ -119,8 +123,9 @@ visit, visit_with_partner, ) +from pyiceberg.table import WriteTask from pyiceberg.transforms import TruncateTransform -from pyiceberg.typedef import EMPTY_DICT, Properties +from pyiceberg.typedef import EMPTY_DICT, Properties, Record from pyiceberg.types import ( BinaryType, BooleanType, @@ -1443,9 +1448,8 @@ def parquet_path_to_id_mapping( def fill_parquet_file_metadata( - df: DataFile, + data_file: DataFile, parquet_metadata: pq.FileMetaData, - file_size: int, stats_columns: Dict[int, StatisticsCollector], parquet_column_mapping: Dict[str, int], ) -> None: @@ -1453,8 +1457,6 @@ def fill_parquet_file_metadata( Compute and fill the following fields of the DataFile object. - file_format - - record_count - - file_size_in_bytes - column_sizes - value_counts - null_value_counts @@ -1464,11 +1466,8 @@ def fill_parquet_file_metadata( - split_offsets Args: - df (DataFile): A DataFile object representing the Parquet file for which metadata is to be filled. + data_file (DataFile): A DataFile object representing the Parquet file for which metadata is to be filled. 
parquet_metadata (pyarrow.parquet.FileMetaData): A pyarrow metadata object. - file_size (int): The total compressed file size cannot be retrieved from the metadata and hence has to - be passed here. Depending on the kind of file system and pyarrow library call used, different - ways to obtain this value might be appropriate. stats_columns (Dict[int, StatisticsCollector]): The statistics gathering plan. It is required to set the mode for column metrics collection """ @@ -1565,13 +1564,56 @@ def fill_parquet_file_metadata( del upper_bounds[field_id] del null_value_counts[field_id] - df.file_format = FileFormat.PARQUET - df.record_count = parquet_metadata.num_rows - df.file_size_in_bytes = file_size - df.column_sizes = column_sizes - df.value_counts = value_counts - df.null_value_counts = null_value_counts - df.nan_value_counts = nan_value_counts - df.lower_bounds = lower_bounds - df.upper_bounds = upper_bounds - df.split_offsets = split_offsets + data_file.record_count = parquet_metadata.num_rows + data_file.column_sizes = column_sizes + data_file.value_counts = value_counts + data_file.null_value_counts = null_value_counts + data_file.nan_value_counts = nan_value_counts + data_file.lower_bounds = lower_bounds + data_file.upper_bounds = upper_bounds + data_file.split_offsets = split_offsets + + +def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]: + task = next(tasks) + + try: + _ = next(tasks) + # If there are more tasks, raise an exception + raise NotImplementedError("Only unpartitioned writes are supported: https://github.com/apache/iceberg-python/issues/208") + except StopIteration: + pass + + file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}' + file_schema = schema_to_pyarrow(table.schema()) + + collected_metrics: List[pq.FileMetaData] = [] + fo = table.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=file_schema, version="1.0", metadata_collector=collected_metrics) as writer: + writer.write_table(task.df) + + data_file = DataFile( + content=DataFileContent.DATA, + file_path=file_path, + file_format=FileFormat.PARQUET, + partition=Record(), + file_size_in_bytes=len(fo), + sort_order_id=task.sort_order_id, + # Just copy these from the table for now + spec_id=table.spec().spec_id, + equality_ids=None, + key_metadata=None, + ) + + if len(collected_metrics) != 1: + # One file has been written + raise ValueError(f"Expected 1 entry, got: {collected_metrics}") + + fill_parquet_file_metadata( + data_file=data_file, + parquet_metadata=collected_metrics[0], + stats_columns=compute_statistics_plan(table.schema(), table.properties), + parquet_column_mapping=parquet_path_to_id_mapping(table.schema()), + ) + return iter([data_file]) diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 0ddfcd28d5..0504626d07 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -37,7 +37,7 @@ from pyiceberg.io import FileIO, InputFile, OutputFile from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.typedef import Record +from pyiceberg.typedef import EMPTY_DICT, Record from pyiceberg.types import ( BinaryType, BooleanType, @@ -60,6 +60,8 @@ DEFAULT_BLOCK_SIZE = 67108864 # 64 * 1024 * 1024 DEFAULT_READ_VERSION: Literal[2] = 2 +INITIAL_SEQUENCE_NUMBER = 0 + class DataFileContent(int, Enum): DATA = 0 @@ -491,25 +493,42 @@ def construct_partition_summaries(spec: PartitionSpec, schema: Schema, partition return [field.to_summary() for field 
in field_stats] -MANIFEST_FILE_SCHEMA: Schema = Schema( - NestedField(500, "manifest_path", StringType(), required=True, doc="Location URI with FS scheme"), - NestedField(501, "manifest_length", LongType(), required=True), - NestedField(502, "partition_spec_id", IntegerType(), required=True), - NestedField(517, "content", IntegerType(), required=False, initial_default=ManifestContent.DATA), - NestedField(515, "sequence_number", LongType(), required=False, initial_default=0), - NestedField(516, "min_sequence_number", LongType(), required=False, initial_default=0), - NestedField(503, "added_snapshot_id", LongType(), required=False), - NestedField(504, "added_files_count", IntegerType(), required=False), - NestedField(505, "existing_files_count", IntegerType(), required=False), - NestedField(506, "deleted_files_count", IntegerType(), required=False), - NestedField(512, "added_rows_count", LongType(), required=False), - NestedField(513, "existing_rows_count", LongType(), required=False), - NestedField(514, "deleted_rows_count", LongType(), required=False), - NestedField(507, "partitions", ListType(508, PARTITION_FIELD_SUMMARY_TYPE, element_required=True), required=False), - NestedField(519, "key_metadata", BinaryType(), required=False), -) +MANIFEST_LIST_FILE_SCHEMAS: Dict[int, Schema] = { + 1: Schema( + NestedField(500, "manifest_path", StringType(), required=True, doc="Location URI with FS scheme"), + NestedField(501, "manifest_length", LongType(), required=True), + NestedField(502, "partition_spec_id", IntegerType(), required=True), + NestedField(503, "added_snapshot_id", LongType(), required=True), + NestedField(504, "added_files_count", IntegerType(), required=False), + NestedField(505, "existing_files_count", IntegerType(), required=False), + NestedField(506, "deleted_files_count", IntegerType(), required=False), + NestedField(512, "added_rows_count", LongType(), required=False), + NestedField(513, "existing_rows_count", LongType(), required=False), + NestedField(514, "deleted_rows_count", LongType(), required=False), + NestedField(507, "partitions", ListType(508, PARTITION_FIELD_SUMMARY_TYPE, element_required=True), required=False), + NestedField(519, "key_metadata", BinaryType(), required=False), + ), + 2: Schema( + NestedField(500, "manifest_path", StringType(), required=True, doc="Location URI with FS scheme"), + NestedField(501, "manifest_length", LongType(), required=True), + NestedField(502, "partition_spec_id", IntegerType(), required=True), + NestedField(517, "content", IntegerType(), required=True, initial_default=ManifestContent.DATA), + NestedField(515, "sequence_number", LongType(), required=True, initial_default=0), + NestedField(516, "min_sequence_number", LongType(), required=True, initial_default=0), + NestedField(503, "added_snapshot_id", LongType(), required=True), + NestedField(504, "added_files_count", IntegerType(), required=True), + NestedField(505, "existing_files_count", IntegerType(), required=True), + NestedField(506, "deleted_files_count", IntegerType(), required=True), + NestedField(512, "added_rows_count", LongType(), required=True), + NestedField(513, "existing_rows_count", LongType(), required=True), + NestedField(514, "deleted_rows_count", LongType(), required=True), + NestedField(507, "partitions", ListType(508, PARTITION_FIELD_SUMMARY_TYPE, element_required=True), required=False), + NestedField(519, "key_metadata", BinaryType(), required=False), + ), +} + +MANIFEST_LIST_FILE_STRUCTS = {format_version: schema.as_struct() for format_version, schema in 
MANIFEST_LIST_FILE_SCHEMAS.items()} -MANIFEST_FILE_SCHEMA_STRUCT = MANIFEST_FILE_SCHEMA.as_struct() POSITIONAL_DELETE_SCHEMA = Schema( NestedField(2147483546, "file_path", StringType()), NestedField(2147483545, "pos", IntegerType()) @@ -551,7 +570,7 @@ class ManifestFile(Record): key_metadata: Optional[bytes] def __init__(self, *data: Any, **named_data: Any) -> None: - super().__init__(*data, **{"struct": MANIFEST_FILE_SCHEMA_STRUCT, **named_data}) + super().__init__(*data, **{"struct": MANIFEST_LIST_FILE_STRUCTS[DEFAULT_READ_VERSION], **named_data}) def has_added_files(self) -> bool: return self.added_files_count is None or self.added_files_count > 0 @@ -596,7 +615,7 @@ def read_manifest_list(input_file: InputFile) -> Iterator[ManifestFile]: """ with AvroFile[ManifestFile]( input_file, - MANIFEST_FILE_SCHEMA, + MANIFEST_LIST_FILE_SCHEMAS[DEFAULT_READ_VERSION], read_types={-1: ManifestFile, 508: PartitionFieldSummary}, read_enums={517: ManifestContent}, ) as reader: @@ -659,7 +678,9 @@ class ManifestWriter(ABC): _min_data_sequence_number: Optional[int] _partitions: List[Record] - def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int, meta: Dict[str, str]): + def __init__( + self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int, meta: Dict[str, str] = EMPTY_DICT + ) -> None: self.closed = False self._spec = spec self._schema = schema @@ -737,7 +758,7 @@ def to_manifest_file(self) -> ManifestFile: existing_rows_count=self._existing_rows, deleted_rows_count=self._deleted_rows, partitions=construct_partition_summaries(self._spec, self._schema, self._partitions), - key_metadatas=None, + key_metadata=None, ) def add_entry(self, entry: ManifestEntry) -> ManifestWriter: @@ -836,13 +857,15 @@ def write_manifest( class ManifestListWriter(ABC): + _format_version: Literal[1, 2] _output_file: OutputFile _meta: Dict[str, str] _manifest_files: List[ManifestFile] _commit_snapshot_id: int _writer: AvroOutputFile[ManifestFile] - def __init__(self, output_file: OutputFile, meta: Dict[str, str]): + def __init__(self, format_version: Literal[1, 2], output_file: OutputFile, meta: Dict[str, Any]): + self._format_version = format_version self._output_file = output_file self._meta = meta self._manifest_files = [] @@ -850,7 +873,11 @@ def __init__(self, output_file: OutputFile, meta: Dict[str, str]): def __enter__(self) -> ManifestListWriter: """Open the writer for writing.""" self._writer = AvroOutputFile[ManifestFile]( - output_file=self._output_file, file_schema=MANIFEST_FILE_SCHEMA, schema_name="manifest_file", metadata=self._meta + output_file=self._output_file, + record_schema=MANIFEST_LIST_FILE_SCHEMAS[DEFAULT_READ_VERSION], + file_schema=MANIFEST_LIST_FILE_SCHEMAS[self._format_version], + schema_name="manifest_file", + metadata=self._meta, ) self._writer.__enter__() return self @@ -874,9 +901,11 @@ def add_manifests(self, manifest_files: List[ManifestFile]) -> ManifestListWrite class ManifestListWriterV1(ManifestListWriter): - def __init__(self, output_file: OutputFile, snapshot_id: int, parent_snapshot_id: int): + def __init__(self, output_file: OutputFile, snapshot_id: int, parent_snapshot_id: Optional[int]): super().__init__( - output_file, {"snapshot-id": str(snapshot_id), "parent-snapshot-id": str(parent_snapshot_id), "format-version": "1"} + format_version=1, + output_file=output_file, + meta={"snapshot-id": str(snapshot_id), "parent-snapshot-id": str(parent_snapshot_id), "format-version": "1"}, ) def prepare_manifest(self, 
manifest_file: ManifestFile) -> ManifestFile: @@ -889,10 +918,11 @@ class ManifestListWriterV2(ManifestListWriter): _commit_snapshot_id: int _sequence_number: int - def __init__(self, output_file: OutputFile, snapshot_id: int, parent_snapshot_id: int, sequence_number: int): + def __init__(self, output_file: OutputFile, snapshot_id: int, parent_snapshot_id: Optional[int], sequence_number: int): super().__init__( - output_file, - { + format_version=2, + output_file=output_file, + meta={ "snapshot-id": str(snapshot_id), "parent-snapshot-id": str(parent_snapshot_id), "sequence-number": str(sequence_number), @@ -910,7 +940,7 @@ def prepare_manifest(self, manifest_file: ManifestFile) -> ManifestFile: # To validate this, check that the snapshot id matches the current commit if self._commit_snapshot_id != wrapped_manifest_file.added_snapshot_id: raise ValueError( - f"Found unassigned sequence number for a manifest from snapshot: {wrapped_manifest_file.added_snapshot_id}" + f"Found unassigned sequence number for a manifest from snapshot: {self._commit_snapshot_id} != {wrapped_manifest_file.added_snapshot_id}" ) wrapped_manifest_file.sequence_number = self._sequence_number @@ -926,11 +956,17 @@ def prepare_manifest(self, manifest_file: ManifestFile) -> ManifestFile: def write_manifest_list( - format_version: Literal[1, 2], output_file: OutputFile, snapshot_id: int, parent_snapshot_id: int, sequence_number: int + format_version: Literal[1, 2], + output_file: OutputFile, + snapshot_id: int, + parent_snapshot_id: Optional[int], + sequence_number: Optional[int], ) -> ManifestListWriter: if format_version == 1: return ManifestListWriterV1(output_file, snapshot_id, parent_snapshot_id) elif format_version == 2: + if sequence_number is None: + raise ValueError(f"Sequence-number is required for V2 tables: {sequence_number}") return ManifestListWriterV2(output_file, snapshot_id, parent_snapshot_id, sequence_number) else: raise ValueError(f"Cannot write manifest list for table version: {format_version}") diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5292c5182c..7c1b0bdecc 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -62,7 +62,10 @@ DataFileContent, ManifestContent, ManifestEntry, + ManifestEntryStatus, ManifestFile, + write_manifest, + write_manifest_list, ) from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import ( @@ -79,7 +82,14 @@ TableMetadataUtil, ) from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef -from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry +from pyiceberg.table.snapshots import ( + Operation, + Snapshot, + SnapshotLogEntry, + SnapshotSummaryCollector, + Summary, + update_snapshot_summaries, +) from pyiceberg.table.sorting import SortOrder from pyiceberg.typedef import ( EMPTY_DICT, @@ -111,6 +121,8 @@ ALWAYS_TRUE = AlwaysTrue() TABLE_ROOT_ID = -1 +_JAVA_LONG_MAX = 9223372036854775807 + class Transaction: _table: Table @@ -210,6 +222,47 @@ def set_properties(self, **updates: str) -> Transaction: """ return self._append_updates(SetPropertiesUpdate(updates=updates)) + def add_snapshot(self, snapshot: Snapshot) -> Transaction: + """Add a new snapshot to the table. + + Returns: + The transaction with the add-snapshot staged. 
+ """ + self._append_updates(AddSnapshotUpdate(snapshot=snapshot)) + self._append_requirements(AssertTableUUID(uuid=self._table.metadata.table_uuid)) + + return self + + def set_ref_snapshot( + self, + snapshot_id: int, + parent_snapshot_id: Optional[int], + ref_name: str, + type: str, + max_age_ref_ms: Optional[int] = None, + max_snapshot_age_ms: Optional[int] = None, + min_snapshots_to_keep: Optional[int] = None, + ) -> Transaction: + """Update a ref to a snapshot. + + Returns: + The transaction with the set-snapshot-ref staged + """ + self._append_updates( + SetSnapshotRefUpdate( + snapshot_id=snapshot_id, + parent_snapshot_id=parent_snapshot_id, + ref_name=ref_name, + type=type, + max_age_ref_ms=max_age_ref_ms, + max_snapshot_age_ms=max_snapshot_age_ms, + min_snapshots_to_keep=min_snapshots_to_keep, + ) + ) + + self._append_requirements(AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main")) + return self + def update_schema(self) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. @@ -609,7 +662,7 @@ class AssertRefSnapshotId(TableRequirement): """ type: Literal["assert-ref-snapshot-id"] = Field(default="assert-ref-snapshot-id") - ref: str + ref: str = Field(...) snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id") def validate(self, base_metadata: Optional[TableMetadata]) -> None: @@ -822,8 +875,8 @@ def location(self) -> str: def last_sequence_number(self) -> int: return self.metadata.last_sequence_number - def _next_sequence_number(self) -> int: - return INITIAL_SEQUENCE_NUMBER if self.format_version == 1 else self.last_sequence_number + 1 + def next_sequence_number(self) -> int: + return self.last_sequence_number + 1 if self.metadata.format_version > 1 else INITIAL_SEQUENCE_NUMBER def new_snapshot_id(self) -> int: """Generate a new snapshot-id that's not in use.""" @@ -856,6 +909,55 @@ def history(self) -> List[SnapshotLogEntry]: def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive) + def append(self, df: pa.Table) -> None: + """ + Append data to the table. + + Args: + df: The Arrow dataframe that will be appended to overwrite the table + """ + if len(self.spec().fields) > 0: + raise ValueError("Cannot write to partitioned tables") + + if len(self.sort_order().fields) > 0: + raise ValueError("Cannot write to tables with a sort-order") + + data_files = _dataframe_to_data_files(self, df=df) + merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self) + for data_file in data_files: + merge.append_data_file(data_file) + + merge.commit() + + def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None: + """ + Overwrite all the data in the table. 
+ + Args: + df: The Arrow dataframe that will be used to overwrite the table + overwrite_filter: ALWAYS_TRUE when you overwrite all the data, + or a boolean expression in case of a partial overwrite + """ + if overwrite_filter != AlwaysTrue(): + raise NotImplementedError("Cannot overwrite a subset of a table") + + if len(self.spec().fields) > 0: + raise ValueError("Cannot write to partitioned tables") + + if len(self.sort_order().fields) > 0: + raise ValueError("Cannot write to tables with a sort-order") + + data_files = _dataframe_to_data_files(self, df=df) + merge = _MergingSnapshotProducer( + operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND, + table=self, + ) + + for data_file in data_files: + merge.append_data_file(data_file) + + merge.commit() + def refs(self) -> Dict[str, SnapshotRef]: """Return the snapshot references in the table.""" return self.metadata.refs @@ -1068,7 +1170,7 @@ def _min_data_file_sequence_number(manifests: List[ManifestFile]) -> int: return INITIAL_SEQUENCE_NUMBER -def _match_deletes_to_datafile(data_entry: ManifestEntry, positional_delete_entries: SortedList[ManifestEntry]) -> Set[DataFile]: +def _match_deletes_to_data_file(data_entry: ManifestEntry, positional_delete_entries: SortedList[ManifestEntry]) -> Set[DataFile]: """Check if the delete file is relevant for the data file. Using the column metrics to see if the filename is in the lower and upper bound. @@ -1212,7 +1314,7 @@ def plan_files(self) -> Iterable[FileScanTask]: return [ FileScanTask( data_entry.data_file, - delete_files=_match_deletes_to_datafile( + delete_files=_match_deletes_to_data_file( data_entry, positional_delete_entries, ), @@ -1935,3 +2037,224 @@ def _generate_snapshot_id() -> int: snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1 return snapshot_id + + +@dataclass(frozen=True) +class WriteTask: + write_uuid: uuid.UUID + task_id: int + df: pa.Table + sort_order_id: Optional[int] = None + + # Later to be extended with partition information + + def generate_data_file_filename(self, extension: str) -> str: + # Mimics the behavior in the Java API: + # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101 + return f"00000-{self.task_id}-{self.write_uuid}.{extension}" + + +def _new_manifest_path(location: str, num: int, commit_uuid: uuid.UUID) -> str: + return f'{location}/metadata/{commit_uuid}-m{num}.avro' + + +def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int, commit_uuid: uuid.UUID) -> str: + # Mimics the behavior in Java: + # https://github.com/apache/iceberg/blob/c862b9177af8e2d83122220764a056f3b96fd00c/core/src/main/java/org/apache/iceberg/SnapshotProducer.java#L491 + return f'{location}/metadata/snap-{snapshot_id}-{attempt}-{commit_uuid}.avro' + + +def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]: + from pyiceberg.io.pyarrow import write_file + + if len(table.spec().fields) > 0: + raise ValueError("Cannot write to partitioned tables") + + if len(table.sort_order().fields) > 0: + raise ValueError("Cannot write to tables with a sort-order") + + write_uuid = uuid.uuid4() + counter = itertools.count(0) + + # This is an iter, so we don't have to materialize everything every time + # This will be more relevant when we start doing partitioned writes + yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)])) + + +class _MergingSnapshotProducer: + 
_operation: Operation + _table: Table + _snapshot_id: int + _parent_snapshot_id: Optional[int] + _added_data_files: List[DataFile] + _commit_uuid: uuid.UUID + + def __init__(self, operation: Operation, table: Table) -> None: + self._operation = operation + self._table = table + self._snapshot_id = table.new_snapshot_id() + # Since we only support the main branch for now + self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None + self._added_data_files = [] + self._commit_uuid = uuid.uuid4() + + def append_data_file(self, data_file: DataFile) -> _MergingSnapshotProducer: + self._added_data_files.append(data_file) + return self + + def _deleted_entries(self) -> List[ManifestEntry]: + """To determine if we need to record any deleted entries. + + With partial overwrites we have to use the predicate to evaluate + which entries are affected. + """ + if self._operation == Operation.OVERWRITE: + if self._parent_snapshot_id is not None: + previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id) + if previous_snapshot is None: + # This should never happen since you cannot overwrite an empty table + raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}") + + executor = ExecutorFactory.get_or_create() + + def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]: + return [ + ManifestEntry( + status=ManifestEntryStatus.DELETED, + snapshot_id=entry.snapshot_id, + data_sequence_number=entry.data_sequence_number, + file_sequence_number=entry.file_sequence_number, + data_file=entry.data_file, + ) + for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True) + if entry.data_file.content == DataFileContent.DATA + ] + + list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io)) + return list(chain(*list_of_entries)) + return [] + elif self._operation == Operation.APPEND: + return [] + else: + raise ValueError(f"Not implemented for: {self._operation}") + + def _manifests(self) -> List[ManifestFile]: + def _write_added_manifest() -> List[ManifestFile]: + if self._added_data_files: + output_file_location = _new_manifest_path(location=self._table.location(), num=0, commit_uuid=self._commit_uuid) + with write_manifest( + format_version=self._table.format_version, + spec=self._table.spec(), + schema=self._table.schema(), + output_file=self._table.io.new_output(output_file_location), + snapshot_id=self._snapshot_id, + ) as writer: + for data_file in self._added_data_files: + writer.add_entry( + ManifestEntry( + status=ManifestEntryStatus.ADDED, + snapshot_id=self._snapshot_id, + data_sequence_number=None, + file_sequence_number=None, + data_file=data_file, + ) + ) + return [writer.to_manifest_file()] + else: + return [] + + def _write_delete_manifest() -> List[ManifestFile]: + # Check if we need to mark the files as deleted + deleted_entries = self._deleted_entries() + if deleted_entries: + output_file_location = _new_manifest_path(location=self._table.location(), num=1, commit_uuid=self._commit_uuid) + with write_manifest( + format_version=self._table.format_version, + spec=self._table.spec(), + schema=self._table.schema(), + output_file=self._table.io.new_output(output_file_location), + snapshot_id=self._snapshot_id, + ) as writer: + for delete_entry in deleted_entries: + writer.add_entry(delete_entry) + return [writer.to_manifest_file()] + else: + return [] + + def _fetch_existing_manifests() -> List[ManifestFile]: + existing_manifests = [] + + # Add existing 
manifests + if self._operation == Operation.APPEND and self._parent_snapshot_id is not None: + # In case we want to append, just add the existing manifests + previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id) + + if previous_snapshot is None: + raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}") + + for manifest in previous_snapshot.manifests(io=self._table.io): + if ( + manifest.has_added_files() + or manifest.has_existing_files() + or manifest.added_snapshot_id == self._snapshot_id + ): + existing_manifests.append(manifest) + + return existing_manifests + + executor = ExecutorFactory.get_or_create() + + added_manifests = executor.submit(_write_added_manifest) + delete_manifests = executor.submit(_write_delete_manifest) + existing_manifests = executor.submit(_fetch_existing_manifests) + + return added_manifests.result() + delete_manifests.result() + existing_manifests.result() + + def _summary(self) -> Summary: + ssc = SnapshotSummaryCollector() + + for data_file in self._added_data_files: + ssc.add_file(data_file=data_file) + + previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id) if self._parent_snapshot_id is not None else None + + return update_snapshot_summaries( + summary=Summary(operation=self._operation, **ssc.build()), + previous_summary=previous_snapshot.summary if previous_snapshot is not None else None, + truncate_full_table=self._operation == Operation.OVERWRITE, + ) + + def commit(self) -> Snapshot: + new_manifests = self._manifests() + next_sequence_number = self._table.next_sequence_number() + + summary = self._summary() + + manifest_list_file_path = _generate_manifest_list_path( + location=self._table.location(), snapshot_id=self._snapshot_id, attempt=0, commit_uuid=self._commit_uuid + ) + with write_manifest_list( + format_version=self._table.metadata.format_version, + output_file=self._table.io.new_output(manifest_list_file_path), + snapshot_id=self._snapshot_id, + parent_snapshot_id=self._parent_snapshot_id, + sequence_number=next_sequence_number, + ) as writer: + writer.add_manifests(new_manifests) + + snapshot = Snapshot( + snapshot_id=self._snapshot_id, + parent_snapshot_id=self._parent_snapshot_id, + manifest_list=manifest_list_file_path, + sequence_number=next_sequence_number, + summary=summary, + schema_id=self._table.schema().schema_id, + ) + + with self._table.transaction() as tx: + tx.add_snapshot(snapshot=snapshot) + tx.set_ref_snapshot( + snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch" + ) + + return snapshot diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index e22c95f8ee..a2f15d4405 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import time from enum import Enum from typing import ( Any, @@ -138,7 +139,7 @@ class Snapshot(IcebergBaseModel): snapshot_id: int = Field(alias="snapshot-id") parent_snapshot_id: Optional[int] = Field(alias="parent-snapshot-id", default=None) sequence_number: Optional[int] = Field(alias="sequence-number", default=None) - timestamp_ms: int = Field(alias="timestamp-ms") + timestamp_ms: int = Field(alias="timestamp-ms", default_factory=lambda: int(time.time() * 1000)) manifest_list: Optional[str] = Field( alias="manifest-list", description="Location of the snapshot's manifest list file", default=None ) @@ -277,23 +278,30 @@ def _truncate_table_summary(summary: Summary, previous_summary: Mapping[str, str }: summary[prop] = '0' - if value := previous_summary.get(TOTAL_DATA_FILES): - summary[DELETED_DATA_FILES] = value - if value := previous_summary.get(TOTAL_DELETE_FILES): - summary[REMOVED_DELETE_FILES] = value - if value := previous_summary.get(TOTAL_RECORDS): - summary[DELETED_RECORDS] = value - if value := previous_summary.get(TOTAL_FILE_SIZE): - summary[REMOVED_FILE_SIZE] = value - if value := previous_summary.get(TOTAL_POSITION_DELETES): - summary[REMOVED_POSITION_DELETES] = value - if value := previous_summary.get(TOTAL_EQUALITY_DELETES): - summary[REMOVED_EQUALITY_DELETES] = value + def get_prop(prop: str) -> int: + value = previous_summary.get(prop) or '0' + try: + return int(value) + except ValueError as e: + raise ValueError(f"Could not parse summary property {prop} to an int: {value}") from e + + if value := get_prop(TOTAL_DATA_FILES): + summary[DELETED_DATA_FILES] = str(value) + if value := get_prop(TOTAL_DELETE_FILES): + summary[REMOVED_DELETE_FILES] = str(value) + if value := get_prop(TOTAL_RECORDS): + summary[DELETED_RECORDS] = str(value) + if value := get_prop(TOTAL_FILE_SIZE): + summary[REMOVED_FILE_SIZE] = str(value) + if value := get_prop(TOTAL_POSITION_DELETES): + summary[REMOVED_POSITION_DELETES] = str(value) + if value := get_prop(TOTAL_EQUALITY_DELETES): + summary[REMOVED_EQUALITY_DELETES] = str(value) return summary -def _update_snapshot_summaries( +def update_snapshot_summaries( summary: Summary, previous_summary: Optional[Mapping[str, str]] = None, truncate_full_table: bool = False ) -> Summary: if summary.operation not in {Operation.APPEND, Operation.OVERWRITE}: diff --git a/pyproject.toml b/pyproject.toml index 63299413d5..8766dfe281 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ requests-mock = "1.11.0" moto = { version = "^4.2.13", extras = ["server"] } typing-extensions = "4.9.0" pytest-mock = "3.12.0" +pyspark = "3.4.2" cython = "3.0.8" [[tool.mypy.overrides]] diff --git a/tests/integration/test_catalogs.py b/tests/integration/test_reads.py similarity index 100% rename from tests/integration/test_catalogs.py rename to tests/integration/test_reads.py diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py new file mode 100644 index 0000000000..f8317e481d --- /dev/null +++ b/tests/integration/test_writes.py @@ -0,0 +1,387 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name +import uuid +from datetime import date, datetime + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession + +from pyiceberg.catalog import Catalog, load_catalog +from pyiceberg.exceptions import NamespaceAlreadyExistsError, NoSuchTableError +from pyiceberg.schema import Schema +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DoubleType, + FixedType, + FloatType, + IntegerType, + LongType, + NestedField, + StringType, + TimestampType, + TimestamptzType, +) + + +@pytest.fixture() +def catalog() -> Catalog: + catalog = load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, + ) + + try: + catalog.create_namespace("default") + except NamespaceAlreadyExistsError: + pass + + return catalog + + +TEST_DATA_WITH_NULL = { + 'bool': [False, None, True], + 'string': ['a', None, 'z'], + # Go over the 16 bytes to kick in truncation + 'string_long': ['a' * 22, None, 'z' * 22], + 'int': [1, None, 9], + 'long': [1, None, 9], + 'float': [0.0, None, 0.9], + 'double': [0.0, None, 0.9], + 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], + # Not supported by Spark + # 'time': [time(1, 22, 0), None, time(19, 25, 0)], + # Not natively supported by Arrow + # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes], + 'binary': [b'\01', None, b'\22'], + 'fixed': [ + uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, + None, + uuid.UUID('11111111-1111-1111-1111-111111111111').bytes, + ], +} + +TABLE_SCHEMA = Schema( + NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="string", field_type=StringType(), required=False), + NestedField(field_id=3, name="string_long", field_type=StringType(), required=False), + NestedField(field_id=4, name="int", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="long", field_type=LongType(), required=False), + NestedField(field_id=6, name="float", field_type=FloatType(), required=False), + NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), + NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), + NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), + NestedField(field_id=10, name="date", field_type=DateType(), required=False), + # NestedField(field_id=11, name="time", field_type=TimeType(), required=False), + # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False), + NestedField(field_id=12, name="binary", field_type=BinaryType(), required=False), + NestedField(field_id=13, name="fixed", field_type=FixedType(16), required=False), +) + + +@pytest.fixture(scope="session") +def 
session_catalog() -> Catalog: + return load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, + ) + + +@pytest.fixture(scope="session") +def arrow_table_with_null() -> pa.Table: + """PyArrow table with all kinds of columns""" + pa_schema = pa.schema([ + ("bool", pa.bool_()), + ("string", pa.string()), + ("string_long", pa.string()), + ("int", pa.int32()), + ("long", pa.int64()), + ("float", pa.float32()), + ("double", pa.float64()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ("date", pa.date32()), + # Not supported by Spark + # ("time", pa.time64("us")), + # Not natively supported by Arrow + # ("uuid", pa.fixed(16)), + ("binary", pa.binary()), + ("fixed", pa.binary(16)), + ]) + return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) + + +@pytest.fixture(scope="session", autouse=True) +def table_v1_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_v1_with_null" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'}) + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v1_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_v1_appended_with_null" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'}) + + for _ in range(2): + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v2_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_v2_with_null" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '2'}) + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_v2_appended_with_null" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '2'}) + + for _ in range(2): + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_v1_v2_appended_with_null" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': 
'1'}) + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + with tbl.transaction() as tx: + tx.upgrade_table_version(format_version=2) + + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session") +def spark() -> SparkSession: + import os + + os.environ["PYSPARK_SUBMIT_ARGS"] = ( + "--packages org.apache.iceberg:iceberg-spark-runtime-3.4_2.12:1.4.0,org.apache.iceberg:iceberg-aws-bundle:1.4.0 pyspark-shell" + ) + os.environ["AWS_REGION"] = "us-east-1" + os.environ["AWS_ACCESS_KEY_ID"] = "admin" + os.environ["AWS_SECRET_ACCESS_KEY"] = "password" + + spark = ( + SparkSession.builder.appName("PyIceberg integration test") + .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") + .config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog") + .config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog") + .config("spark.sql.catalog.integration.uri", "http://localhost:8181") + .config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO") + .config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/") + .config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000") + .config("spark.sql.catalog.integration.s3.path-style-access", "true") + .config("spark.sql.defaultCatalog", "integration") + .getOrCreate() + ) + + return spark + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_count(spark: SparkSession, format_version: int) -> None: + df = spark.table(f"default.arrow_table_v{format_version}_with_null") + assert df.count() == 3, "Expected 3 rows" + + +@pytest.mark.integration +@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys()) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_null(spark: SparkSession, col: str, format_version: int) -> None: + identifier = f"default.arrow_table_v{format_version}_with_null" + df = spark.table(identifier) + assert df.where(f"{col} is null").count() == 1, f"Expected 1 row for {col}" + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 rows for {col}" + + +@pytest.mark.integration +@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys()) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_appended_null(spark: SparkSession, col: str, format_version: int) -> None: + identifier = f"default.arrow_table_v{format_version}_appended_with_null" + df = spark.table(identifier) + assert df.where(f"{col} is null").count() == 2, f"Expected 1 row for {col}" + assert df.where(f"{col} is not null").count() == 4, f"Expected 2 rows for {col}" + + +@pytest.mark.integration +@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys()) +def test_query_filter_v1_v2_append_null(spark: SparkSession, col: str) -> None: + identifier = "default.arrow_table_v1_v2_appended_with_null" + df = spark.table(identifier) + assert df.where(f"{col} is null").count() == 2, f"Expected 1 row for {col}" + assert df.where(f"{col} is not null").count() == 4, f"Expected 2 rows for {col}" + + +@pytest.mark.integration +def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_summaries" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = 
session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'}) + + tbl.append(arrow_table_with_null) + tbl.append(arrow_table_with_null) + tbl.overwrite(arrow_table_with_null) + + rows = spark.sql( + f""" + SELECT operation, summary + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + + operations = [row.operation for row in rows] + assert operations == ['append', 'append', 'overwrite'] + + summaries = [row.summary for row in rows] + + assert summaries[0] == { + 'added-data-files': '1', + 'added-files-size': '5283', + 'added-records': '3', + 'total-data-files': '1', + 'total-delete-files': '0', + 'total-equality-deletes': '0', + 'total-files-size': '5283', + 'total-position-deletes': '0', + 'total-records': '3', + } + + assert summaries[1] == { + 'added-data-files': '1', + 'added-files-size': '5283', + 'added-records': '3', + 'total-data-files': '2', + 'total-delete-files': '0', + 'total-equality-deletes': '0', + 'total-files-size': '10566', + 'total-position-deletes': '0', + 'total-records': '6', + } + + assert summaries[2] == { + 'added-data-files': '1', + 'added-files-size': '5283', + 'added-records': '3', + 'deleted-data-files': '2', + 'deleted-records': '6', + 'removed-files-size': '10566', + 'total-data-files': '1', + 'total-delete-files': '0', + 'total-equality-deletes': '0', + 'total-files-size': '5283', + 'total-position-deletes': '0', + 'total-records': '3', + } + + +@pytest.mark.integration +def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_data_files" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'}) + + tbl.overwrite(arrow_table_with_null) + # should produce a DELETE entry + tbl.overwrite(arrow_table_with_null) + # Since we don't rewrite, this should produce a new manifest with an ADDED entry + tbl.append(arrow_table_with_null) + + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [1, 1, 0, 1, 1] + assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0] + assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0] diff --git a/tests/io/test_pyarrow_stats.py b/tests/io/test_pyarrow_stats.py index 6f00061174..01b844a43e 100644 --- a/tests/io/test_pyarrow_stats.py +++ b/tests/io/test_pyarrow_stats.py @@ -81,7 +81,7 @@ class TestStruct: y: Optional[float] -def construct_test_table() -> Tuple[Any, Any, Union[TableMetadataV1, TableMetadataV2]]: +def construct_test_table() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]: table_metadata = { "format-version": 2, "location": "s3://bucket/test/location", @@ -172,7 +172,7 @@ def construct_test_table() -> Tuple[Any, Any, Union[TableMetadataV1, TableMetada with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer: writer.write_table(table) - return f.getvalue(), metadata_collector[0], table_metadata + return metadata_collector[0], table_metadata def get_current_schema( @@ -182,45 +182,27 @@ def get_current_schema( def test_record_count() -> None: - (file_bytes, metadata, table_metadata) = construct_test_table() + metadata, table_metadata = construct_test_table() schema = 
get_current_schema(table_metadata)
     datafile = DataFile()
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
 
     assert datafile.record_count == 4
 
 
-def test_file_size() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
-
-    schema = get_current_schema(table_metadata)
-    datafile = DataFile()
-    fill_parquet_file_metadata(
-        datafile,
-        metadata,
-        len(file_bytes),
-        compute_statistics_plan(schema, table_metadata.properties),
-        parquet_path_to_id_mapping(schema),
-    )
-
-    assert datafile.file_size_in_bytes == len(file_bytes)
-
-
 def test_value_counts() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -236,14 +218,13 @@ def test_value_counts() -> None:
 
 
 def test_column_sizes() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -258,14 +239,13 @@ def test_column_sizes() -> None:
 
 
 def test_null_and_nan_counts() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -287,14 +267,13 @@ def test_null_and_nan_counts() -> None:
 
 
 def test_bounds() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -332,7 +311,7 @@ def test_metrics_mode_parsing() -> None:
 
 
 def test_metrics_mode_none() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
@@ -340,7 +319,6 @@ def test_metrics_mode_none() -> None:
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -353,7 +331,7 @@ def test_metrics_mode_none() -> None:
 
 
 def test_metrics_mode_counts() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
@@ -361,7 +339,6 @@ def test_metrics_mode_counts() -> None:
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -374,7 +351,7 @@ def test_metrics_mode_counts() -> None:
 
 
 def test_metrics_mode_full() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
@@ -382,7 +359,6 @@ def test_metrics_mode_full() -> None:
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -401,7 +377,7 @@ def test_metrics_mode_full() -> None:
 
 
 def test_metrics_mode_non_default_trunc() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
@@ -409,7 +385,6 @@ def test_metrics_mode_non_default_trunc() -> None:
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -428,7 +403,7 @@ def test_metrics_mode_non_default_trunc() -> None:
 
 
 def test_column_metrics_mode() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
@@ -437,7 +412,6 @@ def test_column_metrics_mode() -> None:
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -455,7 +429,7 @@ def test_column_metrics_mode() -> None:
     assert 1 not in datafile.upper_bounds
 
 
-def construct_test_table_primitive_types() -> Tuple[Any, Any, Union[TableMetadataV1, TableMetadataV2]]:
+def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
     table_metadata = {
         "format-version": 2,
         "location": "s3://bucket/test/location",
@@ -527,11 +501,11 @@ def construct_test_table_primitive_types() -> Tuple[Any, Any, Union[TableMetadat
     with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer:
         writer.write_table(table)
 
-    return f.getvalue(), metadata_collector[0], table_metadata
+    return metadata_collector[0], table_metadata
 
 
 def test_metrics_primitive_types() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table_primitive_types()
+    metadata, table_metadata = construct_test_table_primitive_types()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
@@ -539,7 +513,6 @@ def test_metrics_primitive_types() -> None:
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -579,7 +552,7 @@ def test_metrics_primitive_types() -> None:
     assert datafile.upper_bounds[12] == b"wp"
 
 
-def construct_test_table_invalid_upper_bound() -> Tuple[Any, Any, Union[TableMetadataV1, TableMetadataV2]]:
+def construct_test_table_invalid_upper_bound() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
     table_metadata = {
         "format-version": 2,
         "location": "s3://bucket/test/location",
@@ -627,11 +600,11 @@ def construct_test_table_invalid_upper_bound() -> Tuple[Any, Any, Union[TableMet
     with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer:
         writer.write_table(table)
 
-    return f.getvalue(), metadata_collector[0], table_metadata
+    return metadata_collector[0], table_metadata
 
 
 def test_metrics_invalid_upper_bound() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table_invalid_upper_bound()
+    metadata, table_metadata = construct_test_table_invalid_upper_bound()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
@@ -639,7 +612,6 @@ def test_metrics_invalid_upper_bound() -> None:
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
@@ -660,14 +632,13 @@ def test_metrics_invalid_upper_bound() -> None:
 
 
 def test_offsets() -> None:
-    (file_bytes, metadata, table_metadata) = construct_test_table()
+    metadata, table_metadata = construct_test_table()
 
     schema = get_current_schema(table_metadata)
     datafile = DataFile()
     fill_parquet_file_metadata(
         datafile,
         metadata,
-        len(file_bytes),
         compute_statistics_plan(schema, table_metadata.properties),
         parquet_path_to_id_mapping(schema),
     )
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index d3bbe418c4..efee43b192 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -59,7 +59,7 @@
     UpdateSchema,
     _apply_table_update,
     _generate_snapshot_id,
-    _match_deletes_to_datafile,
+    _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
     update_table_metadata,
 )
@@ -358,7 +358,7 @@ def test_match_deletes_to_datafile() -> None:
             upper_bounds={},
         ),
     )
-    assert _match_deletes_to_datafile(
+    assert _match_deletes_to_data_file(
         data_entry,
         SortedList(iterable=[delete_entry_1, delete_entry_2], key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER),
     ) == {
@@ -415,7 +415,7 @@ def test_match_deletes_to_datafile_duplicate_number() -> None:
             upper_bounds={},
         ),
     )
-    assert _match_deletes_to_datafile(
+    assert _match_deletes_to_data_file(
         data_entry,
         SortedList(iterable=[delete_entry_1, delete_entry_2], key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER),
     ) == {
diff --git a/tests/table/test_snapshots.py b/tests/table/test_snapshots.py
index 124c513022..3591847ad6 100644
--- a/tests/table/test_snapshots.py
+++ b/tests/table/test_snapshots.py
@@ -18,7 +18,7 @@
 import pytest
 
 from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile
-from pyiceberg.table.snapshots import Operation, Snapshot, SnapshotSummaryCollector, Summary, _update_snapshot_summaries
+from pyiceberg.table.snapshots import Operation, Snapshot, SnapshotSummaryCollector, Summary, update_snapshot_summaries
 
 
 @pytest.fixture
@@ -161,7 +161,7 @@ def test_snapshot_summary_collector(data_file: DataFile) -> None:
 
 
 def test_merge_snapshot_summaries_empty() -> None:
-    assert _update_snapshot_summaries(Summary(Operation.APPEND)) == Summary(
+    assert update_snapshot_summaries(Summary(Operation.APPEND)) == Summary(
         operation=Operation.APPEND,
         **{
             'total-data-files': '0',
@@ -175,7 +175,7 @@ def test_merge_snapshot_summaries_empty() -> None:
 
 
 def test_merge_snapshot_summaries_new_summary() -> None:
-    actual = _update_snapshot_summaries(
+    actual = update_snapshot_summaries(
         summary=Summary(
             operation=Operation.APPEND,
             **{
@@ -211,7 +211,7 @@ def test_merge_snapshot_summaries_new_summary() -> None:
 
 
 def test_merge_snapshot_summaries_overwrite_summary() -> None:
-    actual = _update_snapshot_summaries(
+    actual = update_snapshot_summaries(
         summary=Summary(
             operation=Operation.OVERWRITE,
             **{
@@ -260,17 +260,17 @@ def test_merge_snapshot_summaries_overwrite_summary() -> None:
 
 def test_invalid_operation() -> None:
     with pytest.raises(ValueError) as e:
-        _update_snapshot_summaries(summary=Summary(Operation.REPLACE))
+        update_snapshot_summaries(summary=Summary(Operation.REPLACE))
     assert "Operation not implemented: Operation.REPLACE" in str(e.value)
 
     with pytest.raises(ValueError) as e:
-        _update_snapshot_summaries(summary=Summary(Operation.DELETE))
+        update_snapshot_summaries(summary=Summary(Operation.DELETE))
     assert "Operation not implemented: Operation.DELETE" in str(e.value)
 
 
 def test_invalid_type() -> None:
     with pytest.raises(ValueError) as e:
-        _update_snapshot_summaries(
+        update_snapshot_summaries(
             summary=Summary(
                 operation=Operation.OVERWRITE,
                 **{
diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py
index 08906b68ad..6ef11a47ea 100644
--- a/tests/utils/test_manifest.py
+++ b/tests/utils/test_manifest.py
@@ -55,14 +55,19 @@ def _verify_metadata_with_fastavro(avro_file: str, expected_metadata: Dict[str,
 
 def test_read_manifest_entry(generated_manifest_entry_file: str) -> None:
     manifest = ManifestFile(
-        manifest_path=generated_manifest_entry_file, manifest_length=0, partition_spec_id=0, sequence_number=None, partitions=[]
+        manifest_path=generated_manifest_entry_file,
+        manifest_length=0,
+        partition_spec_id=0,
+        added_snapshot_id=0,
+        sequence_number=0,
+        partitions=[],
     )
     manifest_entries = manifest.fetch_manifest_entry(PyArrowFileIO())
     manifest_entry = manifest_entries[0]
 
     assert manifest_entry.status == ManifestEntryStatus.ADDED
     assert manifest_entry.snapshot_id == 8744736658442914487
-    assert manifest_entry.data_sequence_number is None
+    assert manifest_entry.data_sequence_number == 0
     assert isinstance(manifest_entry.data_file, DataFile)
 
     data_file = manifest_entry.data_file