From b8c5bb77c5ea436aeced17676aa30d09c1224ed9 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Fri, 21 Jun 2024 09:44:24 -0400 Subject: [PATCH] Support `Table.to_arrow_batch_reader` (#786) * _task_to_table to _task_to_record_batches * to_arrow_batches * tests * fix * fix * deletes * batch reader * merge main * adopt review feedback --- mkdocs/docs/api.md | 9 ++ pyiceberg/io/pyarrow.py | 155 ++++++++++++++++++++++++-------- pyiceberg/table/__init__.py | 18 ++++ tests/integration/test_reads.py | 126 ++++++++++++++++++++++++++ 4 files changed, 269 insertions(+), 39 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 6bbd9abea1..54f4a20c57 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -1003,6 +1003,15 @@ tpep_dropoff_datetime: [[2021-04-01 00:47:59.000000,...,2021-05-01 00:14:47.0000 This will only pull in the files that that might contain matching rows. +One can also return a PyArrow RecordBatchReader, if reading one record batch at a time is preferred: + +```python +table.scan( + row_filter=GreaterThanOrEqual("trip_distance", 10.0), + selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"), +).to_arrow_batch_reader() +``` + ### Pandas diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 935b78cece..e6490ae156 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr } -def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array: +def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array: if len(positional_deletes) == 1: all_chunks = positional_deletes[0] else: all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes])) - return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False) + return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index) def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema: @@ -995,7 +995,7 @@ def _field_id(self, field: pa.Field) -> int: return -1 -def _task_to_table( +def _task_to_record_batches( fs: FileSystem, task: FileScanTask, bound_row_filter: BooleanExpression, @@ -1003,9 +1003,8 @@ def _task_to_table( projected_field_ids: Set[int], positional_deletes: Optional[List[ChunkedArray]], case_sensitive: bool, - limit: Optional[int] = None, name_mapping: Optional[NameMapping] = None, -) -> Optional[pa.Table]: +) -> Iterator[pa.RecordBatch]: _, _, path = PyArrowFileIO.parse_location(task.file.file_path) arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with fs.open_input_file(path) as fin: @@ -1035,36 +1034,39 @@ def _task_to_table( columns=[col.name for col in file_project_schema.columns], ) - if positional_deletes: - # Create the mask of indices that we're interested in - indices = _combine_positional_deletes(positional_deletes, fragment.count_rows()) - - if limit: - if pyarrow_filter is not None: - # In case of the filter, we don't exactly know how many rows - # we need to fetch upfront, can be optimized in the future: - # https://github.com/apache/arrow/issues/35301 - arrow_table = fragment_scanner.take(indices) - arrow_table = arrow_table.filter(pyarrow_filter) - arrow_table = arrow_table.slice(0, limit) - else: - arrow_table = fragment_scanner.take(indices[0:limit]) - else: - arrow_table = fragment_scanner.take(indices) + current_index = 0 + batches = fragment_scanner.to_batches() + for batch in batches: + if positional_deletes: + # Create the mask of indices that we're interested in + indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch)) + batch = batch.take(indices) # Apply the user filter if pyarrow_filter is not None: + # we need to switch back and forth between RecordBatch and Table + # as Expression filter isn't yet supported in RecordBatch + # https://github.com/apache/arrow/issues/39220 + arrow_table = pa.Table.from_batches([batch]) arrow_table = arrow_table.filter(pyarrow_filter) - else: - # If there are no deletes, we can just take the head - # and the user-filter is already applied - if limit: - arrow_table = fragment_scanner.head(limit) - else: - arrow_table = fragment_scanner.to_table() + batch = arrow_table.to_batches()[0] + yield to_requested_schema(projected_schema, file_project_schema, batch) + current_index += len(batch) - if len(arrow_table) < 1: - return None - return to_requested_schema(projected_schema, file_project_schema, arrow_table) + +def _task_to_table( + fs: FileSystem, + task: FileScanTask, + bound_row_filter: BooleanExpression, + projected_schema: Schema, + projected_field_ids: Set[int], + positional_deletes: Optional[List[ChunkedArray]], + case_sensitive: bool, + name_mapping: Optional[NameMapping] = None, +) -> pa.Table: + batches = _task_to_record_batches( + fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping + ) + return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: @@ -1143,7 +1145,6 @@ def project_table( projected_field_ids, deletes_per_file.get(task.file.file_path), case_sensitive, - limit, table_metadata.name_mapping(), ) for task in tasks @@ -1177,8 +1178,78 @@ def project_table( return result -def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table: - struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) +def project_batches( + tasks: Iterable[FileScanTask], + table_metadata: TableMetadata, + io: FileIO, + row_filter: BooleanExpression, + projected_schema: Schema, + case_sensitive: bool = True, + limit: Optional[int] = None, +) -> Iterator[pa.RecordBatch]: + """Resolve the right columns based on the identifier. + + Args: + tasks (Iterable[FileScanTask]): A URI or a path to a local file. + table_metadata (TableMetadata): The table metadata of the table that's being queried + io (FileIO): A FileIO to open streams to the object store + row_filter (BooleanExpression): The expression for filtering rows. + projected_schema (Schema): The output schema. + case_sensitive (bool): Case sensitivity when looking up column names. + limit (Optional[int]): Limit the number of records. + + Raises: + ResolveError: When an incompatible query is done. + """ + scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location) + if isinstance(io, PyArrowFileIO): + fs = io.fs_by_scheme(scheme, netloc) + else: + try: + from pyiceberg.io.fsspec import FsspecFileIO + + if isinstance(io, FsspecFileIO): + from pyarrow.fs import PyFileSystem + + fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme))) + else: + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") + except ModuleNotFoundError as e: + # When FsSpec is not installed + raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e + + bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive) + + projected_field_ids = { + id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType)) + }.union(extract_field_ids(bound_row_filter)) + + deletes_per_file = _read_all_delete_files(fs, tasks) + + total_row_count = 0 + + for task in tasks: + batches = _task_to_record_batches( + fs, + task, + bound_row_filter, + projected_schema, + projected_field_ids, + deletes_per_file.get(task.file.file_path), + case_sensitive, + table_metadata.name_mapping(), + ) + for batch in batches: + if limit is not None: + if total_row_count + len(batch) >= limit: + yield batch.slice(0, limit - total_row_count) + break + yield batch + total_row_count += len(batch) + + +def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch: + struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) arrays = [] fields = [] @@ -1186,7 +1257,7 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa array = struct_array.field(pos) arrays.append(array) fields.append(pa.field(field.name, array.type, field.optional)) - return pa.Table.from_arrays(arrays, schema=pa.schema(fields)) + return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields)) class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): @@ -1293,8 +1364,10 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st if isinstance(partner_struct, pa.StructArray): return partner_struct.field(name) - elif isinstance(partner_struct, pa.Table): - return partner_struct.column(name).combine_chunks() + elif isinstance(partner_struct, pa.RecordBatch): + return partner_struct.column(name) + else: + raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") return None @@ -1831,7 +1904,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT def write_parquet(task: WriteTask) -> DataFile: table_schema = task.schema - arrow_table = pa.Table.from_batches(task.record_batches) + # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly # otherwise use the original schema if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema: @@ -1839,7 +1912,11 @@ def write_parquet(task: WriteTask) -> DataFile: else: file_schema = table_schema - arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table) + batches = [ + to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch) + for batch in task.record_batches + ] + arrow_table = pa.Table.from_batches(batches) file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' fo = io.new_output(file_path) with fo.create(overwrite=True) as fos: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 9a10fc6bf5..c78e005cac 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1878,6 +1878,24 @@ def to_arrow(self) -> pa.Table: limit=self.limit, ) + def to_arrow_batch_reader(self) -> pa.RecordBatchReader: + import pyarrow as pa + + from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow + + return pa.RecordBatchReader.from_batches( + schema_to_pyarrow(self.projection()), + project_batches( + self.plan_files(), + self.table_metadata, + self.io, + self.row_filter, + self.projection(), + case_sensitive=self.case_sensitive, + limit=self.limit, + ), + ) + def to_pandas(self, **kwargs: Any) -> pd.DataFrame: return self.to_arrow().to_pandas(**kwargs) diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 80a6f18632..078abf406a 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -21,6 +21,7 @@ import uuid from urllib.parse import urlparse +import pyarrow as pa import pyarrow.parquet as pq import pytest from hive_metastore.ttypes import LockRequest, LockResponse, LockState, UnlockRequest @@ -174,6 +175,47 @@ def test_pyarrow_not_nan_count(catalog: Catalog) -> None: assert len(not_nan) == 2 +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_nan(catalog: Catalog) -> None: + table_test_null_nan = catalog.load_table("default.test_null_nan") + arrow_batch_reader = table_test_null_nan.scan( + row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric") + ).to_arrow_batch_reader() + assert isinstance(arrow_batch_reader, pa.RecordBatchReader) + arrow_table = arrow_batch_reader.read_all() + assert len(arrow_table) == 1 + assert arrow_table["idx"][0].as_py() == 1 + assert math.isnan(arrow_table["col_numeric"][0].as_py()) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_nan_rewritten(catalog: Catalog) -> None: + table_test_null_nan_rewritten = catalog.load_table("default.test_null_nan_rewritten") + arrow_batch_reader = table_test_null_nan_rewritten.scan( + row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric") + ).to_arrow_batch_reader() + assert isinstance(arrow_batch_reader, pa.RecordBatchReader) + arrow_table = arrow_batch_reader.read_all() + assert len(arrow_table) == 1 + assert arrow_table["idx"][0].as_py() == 1 + assert math.isnan(arrow_table["col_numeric"][0].as_py()) + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +@pytest.mark.skip(reason="Fixing issues with NaN's: https://github.com/apache/arrow/issues/34162") +def test_pyarrow_batches_not_nan_count(catalog: Catalog) -> None: + table_test_null_nan = catalog.load_table("default.test_null_nan") + arrow_batch_reader = table_test_null_nan.scan( + row_filter=NotNaN("col_numeric"), selected_fields=("idx",) + ).to_arrow_batch_reader() + assert isinstance(arrow_batch_reader, pa.RecordBatchReader) + arrow_table = arrow_batch_reader.read_all() + assert len(arrow_table) == 2 + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_duckdb_nan(catalog: Catalog) -> None: @@ -354,6 +396,90 @@ def test_pyarrow_deletes_double(catalog: Catalog) -> None: assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10] +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_deletes(catalog: Catalog) -> None: + # number, letter + # (1, 'a'), + # (2, 'b'), + # (3, 'c'), + # (4, 'd'), + # (5, 'e'), + # (6, 'f'), + # (7, 'g'), + # (8, 'h'), + # (9, 'i'), <- deleted + # (10, 'j'), + # (11, 'k'), + # (12, 'l') + test_positional_mor_deletes = catalog.load_table("default.test_positional_mor_deletes") + arrow_table = test_positional_mor_deletes.scan().to_arrow_batch_reader().read_all() + assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12] + + # Checking the filter + arrow_table = ( + test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k"))) + .to_arrow_batch_reader() + .read_all() + ) + assert arrow_table["number"].to_pylist() == [5, 6, 7, 8, 10] + + # Testing the combination of a filter and a limit + arrow_table = ( + test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1) + .to_arrow_batch_reader() + .read_all() + ) + assert arrow_table["number"].to_pylist() == [5] + + # Testing the slicing of indices + arrow_table = test_positional_mor_deletes.scan(limit=3).to_arrow_batch_reader().read_all() + assert arrow_table["number"].to_pylist() == [1, 2, 3] + + +@pytest.mark.integration +@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) +def test_pyarrow_batches_deletes_double(catalog: Catalog) -> None: + # number, letter + # (1, 'a'), + # (2, 'b'), + # (3, 'c'), + # (4, 'd'), + # (5, 'e'), + # (6, 'f'), <- second delete + # (7, 'g'), + # (8, 'h'), + # (9, 'i'), <- first delete + # (10, 'j'), + # (11, 'k'), + # (12, 'l') + test_positional_mor_double_deletes = catalog.load_table("default.test_positional_mor_double_deletes") + arrow_table = test_positional_mor_double_deletes.scan().to_arrow_batch_reader().read_all() + assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10, 11, 12] + + # Checking the filter + arrow_table = ( + test_positional_mor_double_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k"))) + .to_arrow_batch_reader() + .read_all() + ) + assert arrow_table["number"].to_pylist() == [5, 7, 8, 10] + + # Testing the combination of a filter and a limit + arrow_table = ( + test_positional_mor_double_deletes.scan( + row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1 + ) + .to_arrow_batch_reader() + .read_all() + ) + assert arrow_table["number"].to_pylist() == [5] + + # Testing the slicing of indices + arrow_table = test_positional_mor_double_deletes.scan(limit=8).to_arrow_batch_reader().read_all() + assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10] + + @pytest.mark.integration @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")]) def test_partitioned_tables(catalog: Catalog) -> None: