From 1629d28c0596948c30d1fc517959cfc9be7d912a Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 1 Jun 2024 02:53:17 +0000 Subject: [PATCH 1/9] _task_to_table to _task_to_record_batches --- pyiceberg/io/pyarrow.py | 56 +++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 71925c27cd..f2751ecc24 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr } -def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array: +def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array: if len(positional_deletes) == 1: all_chunks = positional_deletes[0] else: all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes])) - return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False) + return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index) def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema: @@ -967,7 +967,7 @@ def _field_id(self, field: pa.Field) -> int: return -1 -def _task_to_table( +def _task_to_record_batches( fs: FileSystem, task: FileScanTask, bound_row_filter: BooleanExpression, @@ -975,9 +975,8 @@ def _task_to_table( projected_field_ids: Set[int], positional_deletes: Optional[List[ChunkedArray]], case_sensitive: bool, - limit: Optional[int] = None, name_mapping: Optional[NameMapping] = None, -) -> Optional[pa.Table]: +) -> Iterator[pa.RecordBatch]: _, _, path = PyArrowFileIO.parse_location(task.file.file_path) arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with fs.open_input_file(path) as fin: @@ -1005,36 +1004,27 @@ def _task_to_table( columns=[col.name for col in file_project_schema.columns], ) - if positional_deletes: - # Create the mask of indices that we're interested in - indices = _combine_positional_deletes(positional_deletes, fragment.count_rows()) - - if limit: - if pyarrow_filter is not None: - # In case of the filter, we don't exactly know how many rows - # we need to fetch upfront, can be optimized in the future: - # https://github.com/apache/arrow/issues/35301 - arrow_table = fragment_scanner.take(indices) - arrow_table = arrow_table.filter(pyarrow_filter) - arrow_table = arrow_table.slice(0, limit) - else: - arrow_table = fragment_scanner.take(indices[0:limit]) - else: - arrow_table = fragment_scanner.take(indices) + current_index = 0 + batches = fragment_scanner.to_batches() + for batch in batches: + if positional_deletes: + # Create the mask of indices that we're interested in + indices = _combine_positional_deletes(positional_deletes, current_index, len(batch)) + + batch = batch.take(indices) # Apply the user filter if pyarrow_filter is not None: + # we need to switch back and forth between RecordBatch and Table + # as Expression filter isn't yet supported in RecordBatch + # https://github.com/apache/arrow/issues/39220 + arrow_table = pa.Table.from_batches([batch]) arrow_table = arrow_table.filter(pyarrow_filter) - else: - # If there are no deletes, we can just take the head - # and the user-filter is already applied - if limit: - arrow_table = fragment_scanner.head(limit) + arrow_batches = arrow_table.to_batches() + for arrow_batch in 
arrow_batches: + yield to_requested_schema(projected_schema, file_project_schema, arrow_table) else: - arrow_table = fragment_scanner.to_table() - - if len(arrow_table) < 1: - return None - return to_requested_schema(projected_schema, file_project_schema, arrow_table) + yield to_requested_schema(projected_schema, file_project_schema, arrow_table) + current_index += len(batch) def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: @@ -1147,7 +1137,7 @@ def project_table( return result -def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table: +def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch: struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) arrays = [] @@ -1156,7 +1146,7 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa array = struct_array.field(pos) arrays.append(array) fields.append(pa.field(field.name, array.type, field.optional)) - return pa.Table.from_arrays(arrays, schema=pa.schema(fields)) + return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields)) class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): From f604b15a1bdbc953be94e8f2b822bf281d3e583f Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sun, 2 Jun 2024 02:28:11 +0000 Subject: [PATCH 2/9] to_arrow_batches --- pyiceberg/io/pyarrow.py | 118 +++++++++++++++++++++++++++++++----- pyiceberg/table/__init__.py | 13 ++++ 2 files changed, 116 insertions(+), 15 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index f2751ecc24..49a0987e5b 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1009,9 +1009,11 @@ def _task_to_record_batches( for batch in batches: if positional_deletes: # Create the mask of indices that we're interested in - indices = _combine_positional_deletes(positional_deletes, current_index, len(batch)) - + indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch)) + print(f"DEBUG: {indices=} {current_index=} {len(batch)=}") + print(f"{batch=}") batch = batch.take(indices) + print(f"{batch=}") # Apply the user filter if pyarrow_filter is not None: # we need to switch back and forth between RecordBatch and Table @@ -1019,14 +1021,27 @@ def _task_to_record_batches( # https://github.com/apache/arrow/issues/39220 arrow_table = pa.Table.from_batches([batch]) arrow_table = arrow_table.filter(pyarrow_filter) - arrow_batches = arrow_table.to_batches() - for arrow_batch in arrow_batches: - yield to_requested_schema(projected_schema, file_project_schema, arrow_table) - else: - yield to_requested_schema(projected_schema, file_project_schema, arrow_table) + batch = arrow_table.to_batches()[0] + yield to_requested_schema(projected_schema, file_project_schema, batch) current_index += len(batch) +def _task_to_table( + fs: FileSystem, + task: FileScanTask, + bound_row_filter: BooleanExpression, + projected_schema: Schema, + projected_field_ids: Set[int], + positional_deletes: Optional[List[ChunkedArray]], + case_sensitive: bool, + name_mapping: Optional[NameMapping] = None, +) -> pa.Table: + batches = _task_to_record_batches( + fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping + ) + return pa.Table.from_batches(batches, 
schema=schema_to_pyarrow(projected_schema)) + + def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: deletes_per_file: Dict[str, List[ChunkedArray]] = {} unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) @@ -1103,7 +1118,6 @@ def project_table( projected_field_ids, deletes_per_file.get(task.file.file_path), case_sensitive, - limit, table_metadata.name_mapping(), ) for task in tasks @@ -1137,8 +1151,78 @@ def project_table( return result -def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.RecordBatch) -> pa.RecordBatch: - struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) +def project_batches( + tasks: Iterable[FileScanTask], + table_metadata: TableMetadata, + io: FileIO, + row_filter: BooleanExpression, + projected_schema: Schema, + case_sensitive: bool = True, + limit: Optional[int] = None, +) -> Iterator[pa.ReordBatch]: + """Resolve the right columns based on the identifier. + + Args: + tasks (Iterable[FileScanTask]): A URI or a path to a local file. + table_metadata (TableMetadata): The table metadata of the table that's being queried + io (FileIO): A FileIO to open streams to the object store + row_filter (BooleanExpression): The expression for filtering rows. + projected_schema (Schema): The output schema. + case_sensitive (bool): Case sensitivity when looking up column names. + limit (Optional[int]): Limit the number of records. + + Raises: + ResolveError: When an incompatible query is done. + """ + scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location) + if isinstance(io, PyArrowFileIO): + fs = io.fs_by_scheme(scheme, netloc) + else: + try: + from pyiceberg.io.fsspec import FsspecFileIO + + if isinstance(io, FsspecFileIO): + from pyarrow.fs import PyFileSystem + + fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme))) + else: + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") + except ModuleNotFoundError as e: + # When FsSpec is not installed + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e + + bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive) + + projected_field_ids = { + id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType)) + }.union(extract_field_ids(bound_row_filter)) + + deletes_per_file = _read_all_delete_files(fs, tasks) + + total_row_count = 0 + + for task in tasks: + batches = _task_to_record_batches( + fs, + task, + bound_row_filter, + projected_schema, + projected_field_ids, + deletes_per_file.get(task.file.file_path), + case_sensitive, + table_metadata.name_mapping(), + ) + for batch in batches: + if limit is not None: + if total_row_count + len(batch) >= limit: + yield batch.take(limit - total_row_count) + break + yield batch + total_row_count += len(batch) + + +def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch: + struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) arrays = [] fields = [] @@ -1247,8 +1331,8 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st if isinstance(partner_struct, pa.StructArray): return partner_struct.field(name) - elif isinstance(partner_struct, pa.Table): - return partner_struct.column(name).combine_chunks() + elif 
isinstance(partner_struct, pa.RecordBatch): + return partner_struct.column(name) return None @@ -1785,15 +1869,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT def write_parquet(task: WriteTask) -> DataFile: table_schema = task.schema - arrow_table = pa.Table.from_batches(task.record_batches) + # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly # otherwise use the original schema if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema: file_schema = sanitized_schema - arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table) + batches = [ + to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch) + for batch in task.record_batches + ] else: file_schema = table_schema - + batches = task.record_batches + arrow_table = pa.Table.from_batches(batches) file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' fo = io.new_output(file_path) with fo.create(overwrite=True) as fos: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index f160ab2441..5a0425924e 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1763,6 +1763,19 @@ def to_arrow(self) -> pa.Table: limit=self.limit, ) + def to_arrow_batches(self) -> pa.Table: + from pyiceberg.io.pyarrow import project_batches + + return project_batches( + self.plan_files(), + self.table_metadata, + self.io, + self.row_filter, + self.projection(), + case_sensitive=self.case_sensitive, + limit=self.limit, + ) + def to_pandas(self, **kwargs: Any) -> pd.DataFrame: return self.to_arrow().to_pandas(**kwargs) From 83e09d6bb3efe1e20e307c39ad7762f77c7a726b Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 3 Jun 2024 21:02:25 +0000 Subject: [PATCH 3/9] tests --- pyiceberg/io/pyarrow.py | 12 ++++++--- tests/integration/test_reads.py | 47 +++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 49a0987e5b..79798cfa13 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1010,10 +1010,7 @@ def _task_to_record_batches( if positional_deletes: # Create the mask of indices that we're interested in indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch)) - print(f"DEBUG: {indices=} {current_index=} {len(batch)=}") - print(f"{batch=}") batch = batch.take(indices) - print(f"{batch=}") # Apply the user filter if pyarrow_filter is not None: # we need to switch back and forth between RecordBatch and Table @@ -1039,7 +1036,14 @@ def _task_to_table( batches = _task_to_record_batches( fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping ) - return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema)) + # https://github.com/apache/iceberg-python/issues/791 + # schema_to_pyarrow does not always match the physical_schema of the fragment + # Hence we only use it to infer the pyarrow schema when the table is guaranteed to be empty + list_of_batches = list(batches) + if len(list_of_batches) > 0 and len(list_of_batches[0]) > 0: + pa.Table.from_batches(list_of_batches) + else: + return pa.Table.from_batches(list_of_batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) def _read_all_delete_files(fs: FileSystem, tasks: 
Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 80a6f18632..56de74338e 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -19,8 +19,10 @@ import math import time import uuid +from typing import Iterator from urllib.parse import urlparse +import pyarrow as pa import pyarrow.parquet as pq import pytest from hive_metastore.ttypes import LockRequest, LockResponse, LockState, UnlockRequest @@ -174,6 +176,51 @@ def test_pyarrow_not_nan_count(catalog: Catalog) -> None: assert len(not_nan) == 2 +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_nan(catalog: Catalog) -> None: + table_test_null_nan = catalog.load_table("default.table_test_null_nan") + arrow_batches = table_test_null_nan.scan( + row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric") + ).to_arrow_batches() + assert isinstance(arrow_batches, Iterator) + list_of_batches = list(arrow_batches) + assert len(list_of_batches) == 1 + arrow_table = pa.Table.from_batches(list_of_batches) + assert len(arrow_table) == 1 + assert arrow_table["idx"][0].as_py() == 1 + assert math.isnan(arrow_table["col_numeric"][0].as_py()) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_nan_rewritten(catalog: Catalog) -> None: + table_test_null_nan_rewritten = catalog.load_table("default.test_null_nan_rewritten") + arrow_batches = table_test_null_nan_rewritten.scan( + row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric") + ).to_arrow_batches() + assert isinstance(arrow_batches, Iterator) + list_of_batches = list(arrow_batches) + assert len(list_of_batches) == 1 + arrow_table = pa.Table.from_batches(list_of_batches) + assert len(arrow_table) == 1 + assert arrow_table["idx"][0].as_py() == 1 + assert math.isnan(arrow_table["col_numeric"][0].as_py()) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +@pytest.mark.skip(reason="Fixing issues with NaN's: https://github.com/apache/arrow/issues/34162") +def test_pyarrow_batches_not_nan_count(catalog: Catalog) -> None: + table_test_null_nan = catalog.load_table("default.test_batches_null_nan") + arrow_batches = table_test_null_nan.scan(row_filter=NotNaN("col_numeric"), selected_fields=("idx",)).to_arrow_batches() + assert isinstance(arrow_batches, Iterator) + list_of_batches = list(arrow_batches) + assert len(list_of_batches) == 1 + arrow_table = pa.Table.from_batches(list_of_batches) + assert len(arrow_table) == 2 + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_duckdb_nan(catalog: Catalog) -> None: From 7a4d7d275c4d5b6f986b9352b5f9ca12a0873c71 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 3 Jun 2024 21:25:07 +0000 Subject: [PATCH 4/9] fix --- pyiceberg/io/pyarrow.py | 2 +- tests/integration/test_reads.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 79798cfa13..cc92a348e5 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1041,7 +1041,7 @@ def 
_task_to_table( # Hence we only use it to infer the pyarrow schema when the table is guaranteed to be empty list_of_batches = list(batches) if len(list_of_batches) > 0 and len(list_of_batches[0]) > 0: - pa.Table.from_batches(list_of_batches) + return pa.Table.from_batches(list_of_batches) else: return pa.Table.from_batches(list_of_batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 56de74338e..97f255a8e9 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -179,7 +179,7 @@ def test_pyarrow_not_nan_count(catalog: Catalog) -> None: @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_pyarrow_batches_nan(catalog: Catalog) -> None: - table_test_null_nan = catalog.load_table("default.table_test_null_nan") + table_test_null_nan = catalog.load_table("default.test_null_nan") arrow_batches = table_test_null_nan.scan( row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric") ).to_arrow_batches() From 63d2b78c16a06612601570eb9c42bc1c9dbb2114 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 3 Jun 2024 21:38:05 +0000 Subject: [PATCH 5/9] fix --- tests/integration/test_reads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 97f255a8e9..4b126cabf7 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -212,7 +212,7 @@ def test_pyarrow_batches_nan_rewritten(catalog: Catalog) -> None: @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) @pytest.mark.skip(reason="Fixing issues with NaN's: https://github.com/apache/arrow/issues/34162") def test_pyarrow_batches_not_nan_count(catalog: Catalog) -> None: - table_test_null_nan = catalog.load_table("default.test_batches_null_nan") + table_test_null_nan = catalog.load_table("default.test_null_nan") arrow_batches = table_test_null_nan.scan(row_filter=NotNaN("col_numeric"), selected_fields=("idx",)).to_arrow_batches() assert isinstance(arrow_batches, Iterator) list_of_batches = list(arrow_batches) From 6b95390b3970fa1385d15883f2118cdb75787438 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:03:26 +0000 Subject: [PATCH 6/9] deletes --- pyiceberg/io/pyarrow.py | 2 +- pyiceberg/table/__init__.py | 3 +- tests/integration/test_reads.py | 82 +++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 2 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index cc92a348e5..fe44b4d453 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1219,7 +1219,7 @@ def project_batches( for batch in batches: if limit is not None: if total_row_count + len(batch) >= limit: - yield batch.take(limit - total_row_count) + yield batch.slice(0, limit - total_row_count) break yield batch total_row_count += len(batch) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5a0425924e..452cdb57e9 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -33,6 +33,7 @@ Dict, Generic, Iterable, + Iterator, List, Literal, Optional, @@ -1763,7 +1764,7 @@ def to_arrow(self) -> pa.Table: limit=self.limit, ) - def to_arrow_batches(self) -> pa.Table: + def 
to_arrow_batches(self) -> Iterator[pa.RecordBatch]: from pyiceberg.io.pyarrow import project_batches return project_batches( diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 4b126cabf7..baf3a47c38 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -401,6 +401,88 @@ def test_pyarrow_deletes_double(catalog: Catalog) -> None: assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10] +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_deletes(catalog: Catalog) -> None: + # number, letter + # (1, 'a'), + # (2, 'b'), + # (3, 'c'), + # (4, 'd'), + # (5, 'e'), + # (6, 'f'), + # (7, 'g'), + # (8, 'h'), + # (9, 'i'), <- deleted + # (10, 'j'), + # (11, 'k'), + # (12, 'l') + test_positional_mor_deletes = catalog.load_table("default.test_positional_mor_deletes") + arrow_table = pa.Table.from_batches(test_positional_mor_deletes.scan().to_arrow_batches()) + assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12] + + # Checking the filter + arrow_table = pa.Table.from_batches( + test_positional_mor_deletes.scan( + row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")) + ).to_arrow_batches() + ) + assert arrow_table["number"].to_pylist() == [5, 6, 7, 8, 10] + + # Testing the combination of a filter and a limit + arrow_table = pa.Table.from_batches( + test_positional_mor_deletes.scan( + row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1 + ).to_arrow_batches() + ) + assert arrow_table["number"].to_pylist() == [5] + + # Testing the slicing of indices + arrow_table = pa.Table.from_batches(test_positional_mor_deletes.scan(limit=3).to_arrow_batches()) + assert arrow_table["number"].to_pylist() == [1, 2, 3] + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_deletes_double(catalog: Catalog) -> None: + # number, letter + # (1, 'a'), + # (2, 'b'), + # (3, 'c'), + # (4, 'd'), + # (5, 'e'), + # (6, 'f'), <- second delete + # (7, 'g'), + # (8, 'h'), + # (9, 'i'), <- first delete + # (10, 'j'), + # (11, 'k'), + # (12, 'l') + test_positional_mor_double_deletes = catalog.load_table("default.test_positional_mor_double_deletes") + arrow_table = pa.Table.from_batches(test_positional_mor_double_deletes.scan().to_arrow_batches()) + assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10, 11, 12] + + # Checking the filter + arrow_table = pa.Table.from_batches( + test_positional_mor_double_deletes.scan( + row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")) + ).to_arrow_batches() + ) + assert arrow_table["number"].to_pylist() == [5, 7, 8, 10] + + # Testing the combination of a filter and a limit + arrow_table = pa.Table.from_batches( + test_positional_mor_double_deletes.scan( + row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1 + ).to_arrow_batches() + ) + assert arrow_table["number"].to_pylist() == [5] + + # Testing the slicing of indices + arrow_table = pa.Table.from_batches(test_positional_mor_double_deletes.scan(limit=8).to_arrow_batches()) + assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10] + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")]) def test_partitioned_tables(catalog: Catalog) -> None: From 25704558c1342b5c1cdb64d48b5eb7eb3f863c93 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sun, 9 Jun 2024 17:10:19 +0000 Subject: [PATCH 7/9] batch reader --- mkdocs/docs/api.md | 9 +++++ pyiceberg/table/__init__.py | 30 ++++++++------ tests/integration/test_reads.py | 69 ++++++++++++++++----------------- 3 files changed, 60 insertions(+), 48 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 70b5fd62eb..dd5b0e9da8 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -981,6 +981,15 @@ tpep_dropoff_datetime: [[2021-04-01 00:47:59.000000,...,2021-05-01 00:14:47.0000 This will only pull in the files that that might contain matching rows. +One can also return a PyArrow RecordBatchReader, if reading one record batch at a time is preferred: + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader() +``` + ### Pandas diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 452cdb57e9..fad422a0ea 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -33,7 +33,6 @@ Dict, Generic, Iterable, - Iterator, List, Literal, Optional, @@ -1764,18 +1763,25 @@ def to_arrow(self) -> pa.Table: limit=self.limit, ) - def to_arrow_batches(self) -> Iterator[pa.RecordBatch]: - from pyiceberg.io.pyarrow import project_batches - - return project_batches( - self.plan_files(), - self.table_metadata, - self.io, - self.row_filter, - self.projection(), - case_sensitive=self.case_sensitive, - limit=self.limit, + def to_arrow_batch_reader(self) -> pa.RecordBatchReader: + from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow + import pyarrow as pa + + reader = pa.RecordBatchReader.from_batches( + schema_to_pyarrow(self.projection()), + project_batches( + self.plan_files(), + self.table_metadata, + self.io, + self.row_filter, + self.projection(), + case_sensitive=self.case_sensitive, + limit=self.limit, + ), ) + # Cast the reader to its projected schema its projected schema for consistency + # https://github.com/apache/iceberg-python/issues/791 + return reader.cast(reader.schema) def to_pandas(self, **kwargs: Any) -> pd.DataFrame: return self.to_arrow().to_pandas(**kwargs) diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index baf3a47c38..078abf406a 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -19,7 +19,6 @@ import math import time import uuid -from typing import Iterator from urllib.parse import urlparse import pyarrow as pa @@ -180,13 +179,11 @@ def test_pyarrow_not_nan_count(catalog: Catalog) -> None: @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_pyarrow_batches_nan(catalog: Catalog) -> None: table_test_null_nan = catalog.load_table("default.test_null_nan") - arrow_batches = table_test_null_nan.scan( + arrow_batch_reader = table_test_null_nan.scan( row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric") - ).to_arrow_batches() - assert isinstance(arrow_batches, Iterator) - list_of_batches = list(arrow_batches) - assert len(list_of_batches) == 1 - arrow_table = pa.Table.from_batches(list_of_batches) + ).to_arrow_batch_reader() + assert isinstance(arrow_batch_reader, pa.RecordBatchReader) + arrow_table = 
arrow_batch_reader.read_all() assert len(arrow_table) == 1 assert arrow_table["idx"][0].as_py() == 1 assert math.isnan(arrow_table["col_numeric"][0].as_py()) @@ -196,13 +193,11 @@ def test_pyarrow_batches_nan(catalog: Catalog) -> None: @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_pyarrow_batches_nan_rewritten(catalog: Catalog) -> None: table_test_null_nan_rewritten = catalog.load_table("default.test_null_nan_rewritten") - arrow_batches = table_test_null_nan_rewritten.scan( + arrow_batch_reader = table_test_null_nan_rewritten.scan( row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric") - ).to_arrow_batches() - assert isinstance(arrow_batches, Iterator) - list_of_batches = list(arrow_batches) - assert len(list_of_batches) == 1 - arrow_table = pa.Table.from_batches(list_of_batches) + ).to_arrow_batch_reader() + assert isinstance(arrow_batch_reader, pa.RecordBatchReader) + arrow_table = arrow_batch_reader.read_all() assert len(arrow_table) == 1 assert arrow_table["idx"][0].as_py() == 1 assert math.isnan(arrow_table["col_numeric"][0].as_py()) @@ -213,11 +208,11 @@ def test_pyarrow_batches_nan_rewritten(catalog: Catalog) -> None: @pytest.mark.skip(reason="Fixing issues with NaN's: https://github.com/apache/arrow/issues/34162") def test_pyarrow_batches_not_nan_count(catalog: Catalog) -> None: table_test_null_nan = catalog.load_table("default.test_null_nan") - arrow_batches = table_test_null_nan.scan(row_filter=NotNaN("col_numeric"), selected_fields=("idx",)).to_arrow_batches() - assert isinstance(arrow_batches, Iterator) - list_of_batches = list(arrow_batches) - assert len(list_of_batches) == 1 - arrow_table = pa.Table.from_batches(list_of_batches) + arrow_batch_reader = table_test_null_nan.scan( + row_filter=NotNaN("col_numeric"), selected_fields=("idx",) + ).to_arrow_batch_reader() + assert isinstance(arrow_batch_reader, pa.RecordBatchReader) + arrow_table = arrow_batch_reader.read_all() assert len(arrow_table) == 2 @@ -418,27 +413,27 @@ def test_pyarrow_batches_deletes(catalog: Catalog) -> None: # (11, 'k'), # (12, 'l') test_positional_mor_deletes = catalog.load_table("default.test_positional_mor_deletes") - arrow_table = pa.Table.from_batches(test_positional_mor_deletes.scan().to_arrow_batches()) + arrow_table = test_positional_mor_deletes.scan().to_arrow_batch_reader().read_all() assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12] # Checking the filter - arrow_table = pa.Table.from_batches( - test_positional_mor_deletes.scan( - row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")) - ).to_arrow_batches() + arrow_table = ( + test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k"))) + .to_arrow_batch_reader() + .read_all() ) assert arrow_table["number"].to_pylist() == [5, 6, 7, 8, 10] # Testing the combination of a filter and a limit - arrow_table = pa.Table.from_batches( - test_positional_mor_deletes.scan( - row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1 - ).to_arrow_batches() + arrow_table = ( + test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1) + .to_arrow_batch_reader() + .read_all() ) assert arrow_table["number"].to_pylist() == [5] # Testing the slicing of indices - arrow_table = pa.Table.from_batches(test_positional_mor_deletes.scan(limit=3).to_arrow_batches()) + arrow_table = 
test_positional_mor_deletes.scan(limit=3).to_arrow_batch_reader().read_all() assert arrow_table["number"].to_pylist() == [1, 2, 3] @@ -459,27 +454,29 @@ def test_pyarrow_batches_deletes_double(catalog: Catalog) -> None: # (11, 'k'), # (12, 'l') test_positional_mor_double_deletes = catalog.load_table("default.test_positional_mor_double_deletes") - arrow_table = pa.Table.from_batches(test_positional_mor_double_deletes.scan().to_arrow_batches()) + arrow_table = test_positional_mor_double_deletes.scan().to_arrow_batch_reader().read_all() assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10, 11, 12] # Checking the filter - arrow_table = pa.Table.from_batches( - test_positional_mor_double_deletes.scan( - row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")) - ).to_arrow_batches() + arrow_table = ( + test_positional_mor_double_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k"))) + .to_arrow_batch_reader() + .read_all() ) assert arrow_table["number"].to_pylist() == [5, 7, 8, 10] # Testing the combination of a filter and a limit - arrow_table = pa.Table.from_batches( + arrow_table = ( test_positional_mor_double_deletes.scan( row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1 - ).to_arrow_batches() + ) + .to_arrow_batch_reader() + .read_all() ) assert arrow_table["number"].to_pylist() == [5] # Testing the slicing of indices - arrow_table = pa.Table.from_batches(test_positional_mor_double_deletes.scan(limit=8).to_arrow_batches()) + arrow_table = test_positional_mor_double_deletes.scan(limit=8).to_arrow_batch_reader().read_all() assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10] From 905cc7a9fbae0f3d4e4d7f96111e11c5ecd89a25 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:26:09 +0000 Subject: [PATCH 8/9] merge main --- pyiceberg/io/pyarrow.py | 6 +++--- pyiceberg/table/__init__.py | 10 ++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index a049584197..8face1a421 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1918,9 +1918,9 @@ def write_parquet(task: WriteTask) -> DataFile: file_schema = table_schema batches = [ - to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch) - for batch in task.record_batches - ] + to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch) + for batch in task.record_batches + ] arrow_table = pa.Table.from_batches(batches) file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' fo = io.new_output(file_path) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 1313c74793..4f9ea6d409 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1781,10 +1781,11 @@ def to_arrow(self) -> pa.Table: ) def to_arrow_batch_reader(self) -> pa.RecordBatchReader: - from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow import pyarrow as pa - - reader = pa.RecordBatchReader.from_batches( + + from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow + + return pa.RecordBatchReader.from_batches( schema_to_pyarrow(self.projection()), project_batches( self.plan_files(), @@ -1796,9 +1797,6 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: limit=self.limit, ), ) - # Cast the reader to its projected schema its projected schema for consistency - # 
https://github.com/apache/iceberg-python/issues/791 - return reader.cast(reader.schema) def to_pandas(self, **kwargs: Any) -> pd.DataFrame: return self.to_arrow().to_pandas(**kwargs) From 39a99c472829e4eced155213c78cd77d2d1b23e1 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Fri, 21 Jun 2024 12:57:40 +0000 Subject: [PATCH 9/9] adopt review feedback --- pyiceberg/io/pyarrow.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 8face1a421..e6490ae156 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1066,14 +1066,7 @@ def _task_to_table( batches = _task_to_record_batches( fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping ) - # https://github.com/apache/iceberg-python/issues/791 - # schema_to_pyarrow does not always match the physical_schema of the fragment - # Hence we only use it to infer the pyarrow schema when the table is guaranteed to be empty - list_of_batches = list(batches) - if len(list_of_batches) > 0 and len(list_of_batches[0]) > 0: - return pa.Table.from_batches(list_of_batches) - else: - return pa.Table.from_batches(list_of_batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) + return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: @@ -1193,7 +1186,7 @@ def project_batches( projected_schema: Schema, case_sensitive: bool = True, limit: Optional[int] = None, -) -> Iterator[pa.ReordBatch]: +) -> Iterator[pa.RecordBatch]: """Resolve the right columns based on the identifier. Args: @@ -1373,6 +1366,8 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st return partner_struct.field(name) elif isinstance(partner_struct, pa.RecordBatch): return partner_struct.column(name) + else: + raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") return None
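
---

For reference, a minimal sketch of consuming the `to_arrow_batch_reader()` API introduced by this series, assuming a configured `default` catalog and the taxi table used in the docs example (names are illustrative only; the reader itself is the `pa.RecordBatchReader` built from `project_batches` above):

```python
import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import GreaterThanOrEqual

# Hypothetical catalog and table names, matching the mkdocs example in patch 7.
catalog = load_catalog("default")
table = catalog.load_table("default.taxi_dataset")

# to_arrow_batch_reader() returns a pa.RecordBatchReader, so rows can be
# consumed one RecordBatch at a time instead of materializing a full table.
reader = table.scan(
    row_filter=GreaterThanOrEqual("trip_distance", 10.0),
    selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
).to_arrow_batch_reader()

assert isinstance(reader, pa.RecordBatchReader)

total_rows = 0
for batch in reader:  # a RecordBatchReader is iterable over pa.RecordBatch
    total_rows += batch.num_rows  # process each batch incrementally here
print(total_rows)
```

Calling `reader.read_all()` instead of iterating collapses the stream back into a `pa.Table`, which is what the integration tests in this series do to compare results against the existing `to_arrow()` path.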