Support Table.to_arrow_batch_reader (#786)
* _task_to_table to _task_to_record_batches

* to_arrow_batches

* tests

* fix

* fix

* deletes

* batch reader

* merge main

* adopt review feedback
sungwy authored Jun 21, 2024
1 parent 2182060 commit b8c5bb7
Showing 4 changed files with 269 additions and 39 deletions.
9 changes: 9 additions & 0 deletions mkdocs/docs/api.md
@@ -1003,6 +1003,15 @@ tpep_dropoff_datetime: [[2021-04-01 00:47:59.000000,...,2021-05-01 00:14:47.0000

This will only pull in the files that might contain matching rows.

One can also return a PyArrow RecordBatchReader if reading one record batch at a time is preferred:

```python
table.scan(
    row_filter=GreaterThanOrEqual("trip_distance", 10.0),
    selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
).to_arrow_batch_reader()
```
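
A RecordBatchReader yields batches lazily, so rows can be processed without materializing the whole table at once. A minimal consumption sketch, reusing the scan above (`total_rows` is just an illustrative name):

```python
reader = table.scan(
    row_filter=GreaterThanOrEqual("trip_distance", 10.0),
    selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
).to_arrow_batch_reader()

total_rows = 0
for batch in reader:  # each element is a pyarrow.RecordBatch
    total_rows += batch.num_rows
```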

### Pandas

<!-- prettier-ignore-start -->
155 changes: 116 additions & 39 deletions pyiceberg/io/pyarrow.py
@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr
}


def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
if len(positional_deletes) == 1:
all_chunks = positional_deletes[0]
else:
all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)
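
The reworked helper now takes a half-open range of global row positions for a single batch and returns batch-local indices of the rows to keep. A rough standalone sketch of that arithmetic, with made-up delete positions and batch bounds (not part of the commit):

```python
import numpy as np
import pyarrow as pa

# Hypothetical positional deletes at global row positions 5 and 7.
positional_deletes = [pa.chunked_array([pa.array([5, 7], type=pa.int64())])]
start_index, end_index = 4, 8  # this batch covers global rows 4..7

all_chunks = positional_deletes[0]
keep_global = np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False)
keep_local = np.subtract(keep_global, start_index)
# keep_global == [4, 6]; keep_local == [0, 2], the indices within the batch to keep
```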


def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
@@ -995,17 +995,16 @@ def _field_id(self, field: pa.Field) -> int:
return -1


def _task_to_table(
def _task_to_record_batches(
fs: FileSystem,
task: FileScanTask,
bound_row_filter: BooleanExpression,
projected_schema: Schema,
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
limit: Optional[int] = None,
name_mapping: Optional[NameMapping] = None,
) -> Optional[pa.Table]:
) -> Iterator[pa.RecordBatch]:
_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with fs.open_input_file(path) as fin:
@@ -1035,36 +1034,39 @@ def _task_to_table(
columns=[col.name for col in file_project_schema.columns],
)

if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())

if limit:
if pyarrow_filter is not None:
# In case of the filter, we don't exactly know how many rows
# we need to fetch upfront, can be optimized in the future:
# https://github.com/apache/arrow/issues/35301
arrow_table = fragment_scanner.take(indices)
arrow_table = arrow_table.filter(pyarrow_filter)
arrow_table = arrow_table.slice(0, limit)
else:
arrow_table = fragment_scanner.take(indices[0:limit])
else:
arrow_table = fragment_scanner.take(indices)
current_index = 0
batches = fragment_scanner.to_batches()
for batch in batches:
if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
batch = batch.take(indices)
# Apply the user filter
if pyarrow_filter is not None:
# we need to switch back and forth between RecordBatch and Table
# as Expression filter isn't yet supported in RecordBatch
# https://github.com/apache/arrow/issues/39220
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
else:
# If there are no deletes, we can just take the head
# and the user-filter is already applied
if limit:
arrow_table = fragment_scanner.head(limit)
else:
arrow_table = fragment_scanner.to_table()
batch = arrow_table.to_batches()[0]
yield to_requested_schema(projected_schema, file_project_schema, batch)
current_index += len(batch)

if len(arrow_table) < 1:
return None
return to_requested_schema(projected_schema, file_project_schema, arrow_table)

def _task_to_table(
fs: FileSystem,
task: FileScanTask,
bound_row_filter: BooleanExpression,
projected_schema: Schema,
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
) -> pa.Table:
batches = _task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
)
return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
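
The wrapper passes an explicit Arrow schema to `pa.Table.from_batches`, presumably so that a task whose generator yields no batches still produces an empty table with the projected columns. A small, generic PyArrow illustration of why that matters (field names are made up):

```python
import pyarrow as pa

schema = pa.schema([("VendorID", pa.int64()), ("trip_distance", pa.float64())])

# With an explicit schema, an empty iterator still yields a valid zero-row table.
empty = pa.Table.from_batches([], schema=schema)
assert empty.num_rows == 0 and empty.schema == schema

# Without one, PyArrow has nothing to infer the schema from and raises.
try:
    pa.Table.from_batches([])
except ValueError:
    pass
```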


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -1143,7 +1145,6 @@ def project_table(
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
limit,
table_metadata.name_mapping(),
)
for task in tasks
@@ -1177,16 +1178,86 @@ def project_table(
return result


def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
def project_batches(
tasks: Iterable[FileScanTask],
table_metadata: TableMetadata,
io: FileIO,
row_filter: BooleanExpression,
projected_schema: Schema,
case_sensitive: bool = True,
limit: Optional[int] = None,
) -> Iterator[pa.RecordBatch]:
"""Resolve the right columns based on the identifier.
Args:
tasks (Iterable[FileScanTask]): The file scan tasks to read record batches from.
table_metadata (TableMetadata): The table metadata of the table that's being queried
io (FileIO): A FileIO to open streams to the object store
row_filter (BooleanExpression): The expression for filtering rows.
projected_schema (Schema): The output schema.
case_sensitive (bool): Case sensitivity when looking up column names.
limit (Optional[int]): Limit the number of records.
Raises:
ResolveError: When an incompatible query is done.
"""
scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
if isinstance(io, PyArrowFileIO):
fs = io.fs_by_scheme(scheme, netloc)
else:
try:
from pyiceberg.io.fsspec import FsspecFileIO

if isinstance(io, FsspecFileIO):
from pyarrow.fs import PyFileSystem

fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
else:
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
except ModuleNotFoundError as e:
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e

bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)

projected_field_ids = {
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
}.union(extract_field_ids(bound_row_filter))

deletes_per_file = _read_all_delete_files(fs, tasks)

total_row_count = 0

for task in tasks:
batches = _task_to_record_batches(
fs,
task,
bound_row_filter,
projected_schema,
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
table_metadata.name_mapping(),
)
for batch in batches:
if limit is not None:
if total_row_count + len(batch) >= limit:
yield batch.slice(0, limit - total_row_count)
return
yield batch
total_row_count += len(batch)
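
The limit handling slices only the batch that crosses the threshold. A tiny standalone sketch of the same arithmetic with made-up batch sizes (a plain loop stands in for the generator):

```python
import pyarrow as pa

batches = [pa.RecordBatch.from_pydict({"x": list(range(i, i + 4))}) for i in (0, 4, 8)]
limit, total_row_count, taken = 10, 0, []

for batch in batches:
    if total_row_count + len(batch) >= limit:
        taken.append(batch.slice(0, limit - total_row_count))
        break  # stop once the limit is reached
    taken.append(batch)
    total_row_count += len(batch)

assert [len(b) for b in taken] == [4, 4, 2]  # 10 rows in total
```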


def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

arrays = []
fields = []
for pos, field in enumerate(requested_schema.fields):
array = struct_array.field(pos)
arrays.append(array)
fields.append(pa.field(field.name, array.type, field.optional))
return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
@@ -1293,8 +1364,10 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st

if isinstance(partner_struct, pa.StructArray):
return partner_struct.field(name)
elif isinstance(partner_struct, pa.Table):
return partner_struct.column(name).combine_chunks()
elif isinstance(partner_struct, pa.RecordBatch):
return partner_struct.column(name)
else:
raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}")

return None
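
`field_partner` now distinguishes a top-level partner (`pa.RecordBatch`, columns fetched by name) from a nested one (`pa.StructArray`, children fetched with `.field()`). A tiny illustration of the two access paths (values are made up):

```python
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"VendorID": [1, 2]})
struct = pa.StructArray.from_arrays([pa.array([1, 2])], names=["VendorID"])

assert batch.column("VendorID").to_pylist() == [1, 2]   # top-level: RecordBatch column by name
assert struct.field("VendorID").to_pylist() == [1, 2]   # nested: StructArray child by name
```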

@@ -1831,15 +1904,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT

def write_parquet(task: WriteTask) -> DataFile:
table_schema = task.schema
arrow_table = pa.Table.from_batches(task.record_batches)

# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
file_schema = sanitized_schema
else:
file_schema = table_schema

arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
batches = [
to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
18 changes: 18 additions & 0 deletions pyiceberg/table/__init__.py
@@ -1878,6 +1878,24 @@ def to_arrow(self) -> pa.Table:
limit=self.limit,
)

def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
import pyarrow as pa

from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow

return pa.RecordBatchReader.from_batches(
schema_to_pyarrow(self.projection()),
project_batches(
self.plan_files(),
self.table_metadata,
self.io,
self.row_filter,
self.projection(),
case_sensitive=self.case_sensitive,
limit=self.limit,
),
)
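
Because the reader is lazy, callers can stream a scan without holding the full table in memory. A hedged sketch of writing the batches straight to a Parquet file (`tbl` and the output path are assumptions, not part of this change):

```python
import pyarrow.parquet as pq

reader = tbl.scan().to_arrow_batch_reader()  # tbl: a loaded pyiceberg Table (assumed)
with pq.ParquetWriter("/tmp/scan_dump.parquet", schema=reader.schema) as writer:
    for batch in reader:
        writer.write_batch(batch)
```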

def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
return self.to_arrow().to_pandas(**kwargs)
