Support Table.to_arrow_batch_reader to return RecordBatchReader instead of a fully materialized Arrow Table #786

Merged · 10 commits · Jun 21, 2024
9 changes: 9 additions & 0 deletions mkdocs/docs/api.md
@@ -981,6 +981,15 @@ tpep_dropoff_datetime: [[2021-04-01 00:47:59.000000,...,2021-05-01 00:14:47.0000

This will only pull in the files that might contain matching rows.

One can also return a PyArrow RecordBatchReader, if reading one record batch at a time is preferred:

```python
table.scan(
row_filter=GreaterThanOrEqual("trip_distance", 10.0),
selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
).to_arrow_batch_reader()
```
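
A minimal sketch of consuming the reader (assuming the `table` object and imports from the examples above): iterating a `pyarrow.RecordBatchReader` yields one `RecordBatch` at a time, so only a single batch is held in memory.

```python
reader = table.scan(
    row_filter=GreaterThanOrEqual("trip_distance", 10.0),
    selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
).to_arrow_batch_reader()

for batch in reader:
    # each batch is a pyarrow.RecordBatch covering only a slice of the scan result
    print(batch.num_rows)
```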

### Pandas

<!-- prettier-ignore-start -->
155 changes: 116 additions & 39 deletions pyiceberg/io/pyarrow.py
@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedAr
}


def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
if len(positional_deletes) == 1:
all_chunks = positional_deletes[0]
else:
all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)
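
As a standalone illustration of the new start/end-index semantics (hypothetical inputs, not part of the diff): with deletes at file positions 3 and 5 and a batch covering file rows [4, 8), the expression returns batch-local indices of the rows to keep.

```python
import numpy as np
import pyarrow as pa

# Hypothetical inputs: positions 3 and 5 are deleted; the current batch spans file rows [4, 8)
positional_deletes = [pa.chunked_array([[3, 5]])]
start_index, end_index = 4, 8

all_chunks = positional_deletes[0]
# Same expression as the new return statement above
keep = np.subtract(
    np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False),
    start_index,
)
print(keep)  # [0 2 3]: batch-local indices of the surviving rows (file rows 4, 6, 7)
```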


def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
@@ -995,17 +995,16 @@ def _field_id(self, field: pa.Field) -> int:
return -1


def _task_to_table(
def _task_to_record_batches(
fs: FileSystem,
task: FileScanTask,
bound_row_filter: BooleanExpression,
projected_schema: Schema,
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
limit: Optional[int] = None,
name_mapping: Optional[NameMapping] = None,
) -> Optional[pa.Table]:
) -> Iterator[pa.RecordBatch]:
_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
with fs.open_input_file(path) as fin:
@@ -1035,36 +1034,39 @@ def _task_to_table(
columns=[col.name for col in file_project_schema.columns],
)

if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())

if limit:
if pyarrow_filter is not None:
# In case of the filter, we don't exactly know how many rows
# we need to fetch upfront, can be optimized in the future:
# https://github.com/apache/arrow/issues/35301
arrow_table = fragment_scanner.take(indices)
arrow_table = arrow_table.filter(pyarrow_filter)
arrow_table = arrow_table.slice(0, limit)
else:
arrow_table = fragment_scanner.take(indices[0:limit])
else:
arrow_table = fragment_scanner.take(indices)
current_index = 0
batches = fragment_scanner.to_batches()

corleyma: As I review this, it occurs to me that it might be useful to expose options related to batching/read ahead, etc., that pyarrow accepts when constructing the scanner. See the pyarrow docs for more details.

Specifically, I think setting batch_size is probably something that ought to be tunable, since the memory pressure will be a function of batch size and the number and types of columns in the table.

sungwy (Collaborator, Author): That's a great suggestion @corleyma. I'll adopt this feedback when I make the next round of changes.
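
A minimal sketch of what exposing that knob could look like (hypothetical helper, not code from this PR), assuming pyarrow's `Scanner.from_fragment`, which accepts a `batch_size` option:

```python
import pyarrow.dataset as ds

def scan_fragment_in_batches(fragment, schema, row_filter, columns, batch_size=131_072):
    # batch_size bounds the number of rows per RecordBatch, and therefore the
    # memory used per batch (together with the number and types of columns)
    scanner = ds.Scanner.from_fragment(
        fragment=fragment,
        schema=schema,
        filter=row_filter,
        columns=columns,
        batch_size=batch_size,
    )
    return scanner.to_batches()
```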

for batch in batches:
if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
batch = batch.take(indices)
# Apply the user filter
if pyarrow_filter is not None:
# we need to switch back and forth between RecordBatch and Table
# as Expression filter isn't yet supported in RecordBatch
# https://github.com/apache/arrow/issues/39220
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
else:
# If there are no deletes, we can just take the head
# and the user-filter is already applied
if limit:
arrow_table = fragment_scanner.head(limit)
else:
arrow_table = fragment_scanner.to_table()
batch = arrow_table.to_batches()[0]
yield to_requested_schema(projected_schema, file_project_schema, batch)
current_index += len(batch)

if len(arrow_table) < 1:
return None
return to_requested_schema(projected_schema, file_project_schema, arrow_table)

def _task_to_table(
fs: FileSystem,
task: FileScanTask,
bound_row_filter: BooleanExpression,
projected_schema: Schema,
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
) -> pa.Table:
batches = _task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
)
return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -1143,7 +1145,6 @@ def project_table(
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
limit,
table_metadata.name_mapping(),
)
for task in tasks
@@ -1177,16 +1178,86 @@ def project_table(
return result


def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
def project_batches(
tasks: Iterable[FileScanTask],
table_metadata: TableMetadata,
io: FileIO,
row_filter: BooleanExpression,
projected_schema: Schema,
case_sensitive: bool = True,
limit: Optional[int] = None,
) -> Iterator[pa.RecordBatch]:
"""Resolve the right columns based on the identifier.

Args:
tasks (Iterable[FileScanTask]): The file scan tasks pointing to the data files to read.
table_metadata (TableMetadata): The table metadata of the table that's being queried
io (FileIO): A FileIO to open streams to the object store
row_filter (BooleanExpression): The expression for filtering rows.
projected_schema (Schema): The output schema.
case_sensitive (bool): Case sensitivity when looking up column names.
limit (Optional[int]): Limit the number of records.

Raises:
ResolveError: When an incompatible query is done.
"""
scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
if isinstance(io, PyArrowFileIO):
fs = io.fs_by_scheme(scheme, netloc)
else:
try:
from pyiceberg.io.fsspec import FsspecFileIO

if isinstance(io, FsspecFileIO):
from pyarrow.fs import PyFileSystem

fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
else:
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
except ModuleNotFoundError as e:
# When FsSpec is not installed
raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e

bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)

projected_field_ids = {
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
}.union(extract_field_ids(bound_row_filter))

deletes_per_file = _read_all_delete_files(fs, tasks)

total_row_count = 0

for task in tasks:
batches = _task_to_record_batches(
fs,
task,
bound_row_filter,
projected_schema,
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
table_metadata.name_mapping(),
)
for batch in batches:
if limit is not None:
if total_row_count + len(batch) >= limit:
yield batch.slice(0, limit - total_row_count)
break
yield batch
total_row_count += len(batch)


def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

arrays = []
fields = []
for pos, field in enumerate(requested_schema.fields):
array = struct_array.field(pos)
arrays.append(array)
fields.append(pa.field(field.name, array.type, field.optional))
return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
@@ -1293,8 +1364,10 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st

if isinstance(partner_struct, pa.StructArray):
return partner_struct.field(name)
elif isinstance(partner_struct, pa.Table):
return partner_struct.column(name).combine_chunks()
elif isinstance(partner_struct, pa.RecordBatch):
return partner_struct.column(name)
else:
raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}")

return None

@@ -1831,15 +1904,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT

def write_parquet(task: WriteTask) -> DataFile:
table_schema = task.schema
arrow_table = pa.Table.from_batches(task.record_batches)

# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
file_schema = sanitized_schema
else:
file_schema = table_schema

arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
batches = [
to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)

corleyma: hmm, looking here, this forced materialization seems to preclude streaming writes, which you may want if e.g. upserting large amounts of data. ParquetWriter can be used for streaming writes, so this seems unnecessary?

corleyma (Jun 13, 2024): i.e., maybe we could do something like the following?:

```python
def sanitize_batches(batches: Iterator[RecordBatch], table_schema: Schema, sanitized_schema: Schema) -> Iterator[RecordBatch]:
    if sanitized_schema != table_schema:
        for batch in batches:
            yield to_requested_schema(requested_schema=sanitized_schema, file_schema=table_schema, batch=batch)
    else:
        yield from batches

def write_parquet(task: WriteTask) -> DataFile:
    table_schema = task.schema

    # Check if schema needs to be transformed
    sanitized_schema = sanitize_column_names(table_schema)

    file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
    fo = io.new_output(file_path)

    with fo.create(overwrite=True) as fos:
        with pq.ParquetWriter(fos, schema=sanitized_schema.as_arrow(), **parquet_writer_kwargs) as writer:
            for sanitized_batch in sanitize_batches(task.record_batches, table_schema, sanitized_schema):
                writer.write_table(pa.Table.from_batches([sanitized_batch]), row_group_size=row_group_size)
```

sungwy (Collaborator, Author): Yep, I totally agree. I wanted to focus this PR on introducing the reader first, and then work on a subsequent PR to incorporate batches into writes. This just maintains the existing functionality while making use of the refactored to_requested_schema.

kevinjqliu (Contributor): ah, so the change here is the order of operations. We want to call to_requested_schema on each batch first before creating a pyarrow Table from those batches.

I wonder if we can push to_requested_schema up the stack since we already bin-pack batches before passing to WriteTask:

`for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)`

Also, in #829 I wanted to introduce schema projection:
https://github.com/apache/iceberg-python/pull/829/files#diff-23e8153e0fd497a9212215bd2067068f3b56fa071770c7ef326db3d3d03cee9bR2904

We can keep write_parquet very simple and just handle the writing, and preprocess all the batch/table level operations together.

sungwy (Collaborator, Author): That's a great suggestion @kevinjqliu - good to see that our work is going to be converging naturally here.

I was hoping to focus on the new read API here and the necessary refactoring in the utility functions, and keep the changes to the write functions to a minimum.

I could incorporate these changes and continue the discussion on updating the write functions in a follow-up PR. I think there's much discussion worth continuing on that topic (can we avoid materializing an arrow table and write with record batches?).

kevinjqliu (Contributor): sgtm, ty!

sungwy (Collaborator, Author, Jun 19, 2024): Thank you for the review!

file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
18 changes: 18 additions & 0 deletions pyiceberg/table/__init__.py
@@ -1780,6 +1780,24 @@ def to_arrow(self) -> pa.Table:
limit=self.limit,
)

def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
import pyarrow as pa

from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow

return pa.RecordBatchReader.from_batches(
schema_to_pyarrow(self.projection()),
project_batches(
self.plan_files(),
self.table_metadata,
self.io,
self.row_filter,
self.projection(),
case_sensitive=self.case_sensitive,
limit=self.limit,
),
)
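
For context, a self-contained sketch of the pattern used here (toy schema and data, not pyiceberg code): `pa.RecordBatchReader.from_batches` wraps a lazy generator of record batches behind a streaming reader.

```python
import pyarrow as pa

schema = pa.schema([("id", pa.int64())])

def gen_batches():
    # batches are produced lazily, one at a time, as the reader is consumed
    for start in (0, 3):
        yield pa.RecordBatch.from_arrays(
            [pa.array(range(start, start + 3), type=pa.int64())], schema=schema
        )

reader = pa.RecordBatchReader.from_batches(schema, gen_batches())
print(reader.read_all().num_rows)  # 6
```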

def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
return self.to_arrow().to_pandas(**kwargs)
