PyArrow: Don't enforce the schema #902

Merged · 14 commits · Jul 11, 2024
73 changes: 45 additions & 28 deletions pyiceberg/io/pyarrow.py
@@ -1047,8 +1047,10 @@ def _task_to_record_batches(

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
# We always use large types in memory as it uses larger offsets
# That can chunk more row values into the buffers
# With PyArrow 16.0.0 there is an issue with casting record-batches:
# https://github.com/apache/arrow/issues/41884
# https://github.com/apache/arrow/issues/43183
# Would be good to remove this later on
schema=_pyarrow_schema_ensure_large_types(physical_schema),
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
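
For context, "ensuring large types" means rewriting the Arrow schema so that variable-width columns use 64-bit offsets and can therefore hold more data per buffer. A minimal sketch of the idea (a simplified stand-in for _pyarrow_schema_ensure_large_types that only handles flat string/binary fields; the real helper presumably also walks nested types):

import pyarrow as pa

def ensure_large_types_sketch(schema: pa.Schema) -> pa.Schema:
    # Map 32-bit-offset types to their 64-bit-offset ("large") variants.
    def to_large(t: pa.DataType) -> pa.DataType:
        if pa.types.is_string(t):
            return pa.large_string()
        if pa.types.is_binary(t):
            return pa.large_binary()
        return t

    return pa.schema([pa.field(f.name, to_large(f.type), nullable=f.nullable) for f in schema])

# ensure_large_types_sketch(pa.schema([("foo", pa.string())])) -> foo: large_string
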
@@ -1084,11 +1086,17 @@ def _task_to_table(
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
) -> pa.Table:
batches = _task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
) -> Optional[pa.Table]:
batches = list(
_task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
)
)
return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))

if len(batches) > 0:
return pa.Table.from_batches(batches)
else:
return None


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -1192,7 +1200,7 @@ def project_table(
if len(tables) < 1:
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))

result = pa.concat_tables(tables)
result = pa.concat_tables(tables, promote_options="permissive")
if limit is not None:
return result.slice(0, limit)
@@ -1271,54 +1279,62 @@ def project_batches(


def to_requested_schema(
requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, downcast_ns_timestamp_to_us: bool = False
requested_schema: Schema,
file_schema: Schema,
batch: pa.RecordBatch,
downcast_ns_timestamp_to_us: bool = False,
include_field_ids: bool = False,
) -> pa.RecordBatch:
# We could re-use some of these visitors
struct_array = visit_with_partner(
requested_schema, batch, ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us), ArrowAccessor(file_schema)
requested_schema,
batch,
ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
ArrowAccessor(file_schema),
)

arrays = []
fields = []
for pos, field in enumerate(requested_schema.fields):
array = struct_array.field(pos)
arrays.append(array)
fields.append(pa.field(field.name, array.type, field.optional))
return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
return pa.RecordBatch.from_struct_array(struct_array)


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
file_schema: Schema
_include_field_ids: bool

def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False):
def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
self.file_schema = file_schema
self._include_field_ids = include_field_ids
self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)

if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
# if file_field and field_type (e.g. String) are the same
# but the pyarrow type of the array is different from the expected type
# (e.g. string vs large_string), we want to cast the array to the larger type
safe = True
return values.cast(
schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and target_type.unit == "us"
and pa.types.is_timestamp(values.type)
and values.type.unit == "ns"
):
safe = False
return values.cast(target_type, safe=safe)
return values.cast(target_type, safe=False)
return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
metadata = {}
if field.doc:
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
if self._include_field_ids:
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
Collaborator: Ah good catch on this one as well 👍

return pa.field(
name=field.name,
type=arrow_type,
nullable=field.optional,
metadata={DOC: field.doc} if field.doc is not None else None,
metadata=metadata,
)

def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
@@ -1960,14 +1976,15 @@ def write_parquet(task: WriteTask) -> DataFile:
file_schema=table_schema,
batch=batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
include_field_ids=True,
)
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer:
writer.write(arrow_table, row_group_size=row_group_size)
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=writer.writer.metadata,
3 changes: 2 additions & 1 deletion pyiceberg/table/__init__.py
@@ -2053,8 +2053,9 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:

from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow

target_schema = schema_to_pyarrow(self.projection())
Collaborator:
Here, we are making an opinionated decision on whether we are using large or small type as the pyarrow schema when reading the Iceberg table as a RecordBatchReader. Is there a reason why we don't want to do the same for the table API? I've noticed that we've changed the return type of the Table API to Optional[pa.Table] in order to avoid having to use schema_to_pyarrow.

Similarly, other libraries like polars use the approach of choosing one type over the other (large types in the case of polars).

>>> strings = pa.array(["a", "b"])
>>> pydict = {"strings": strings}
>>> pa.Table.from_pydict(pydict)
pyarrow.Table
strings: string
----
strings: [["a","b"]]
>>> pq.write_table(pa.Table.from_pydict(pydict), "strings.parquet")
>>> pldf = pl.read_parquet("strings.parquet", use_pyarrow=True)
>>> pldf.dtypes
[String]
>>> pldf.to_arrow()
pyarrow.Table
strings: large_string
----
strings: [["a","b"]]

Contributor (Author):
My preference would be to let Arrow decide. For Polars it is different because they are also the query engine. Casting the types will recompute the buffers, consuming additional memory/CPU, which I would rather avoid.

For the table, we first materialize all the batches in memory, so if one of them is large, it will automatically upcast, otherwise, it will keep the small types.
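
A small sketch of that behaviour, assuming a PyArrow version where concat_tables accepts promote_options (the column name "foo" is made up):

import pyarrow as pa

small = pa.table({"foo": pa.array(["a"], type=pa.string())})
large = pa.table({"foo": pa.array(["b"], type=pa.large_string())})

# Permissive promotion unifies the string and large_string columns into a
# common type instead of raising, so batches read from different files mix fine.
combined = pa.concat_tables([small, large], promote_options="permissive")
print(combined.schema)  # expected: "foo" promoted to a common (large) string type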

Collaborator:
I'm less versed in Parquet-to-Arrow buffer conversion, so please do check me if I'm not making much sense 🙂

But are we actually casting the types on read?

We make a decision on whether to read with large or small types when instantiating the fragment scanner, which loads the Parquet data into the Arrow buffers. The schema_to_pyarrow() calls for pa.Table, pa.RecordBatchReader, and in to_requested_schema after that all represent the table schema in the same (large or small) format, which shouldn't result in any additional casting or reassigning of buffers.

I think the only time we are casting the types is on write, where we may want to downcast it for forward compatibility. It looks like we have to choose a schema to use on write anyways, because using a schema for the ParquetWriter that isn't consistent with the schema within the dataframe results in an exception.
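
The exception mentioned here can be reproduced with plain PyArrow; a rough sketch (file path and data are made up):

import pyarrow as pa
import pyarrow.parquet as pq

writer_schema = pa.schema([pa.field("foo", pa.string())])
table = pa.table({"foo": pa.array(["a", "b"], type=pa.large_string())})

with pq.ParquetWriter("/tmp/example.parquet", schema=writer_schema) as writer:
    try:
        # The table carries a large_string column while the writer was created
        # with a plain string column, so PyArrow refuses to write it as-is.
        writer.write_table(table)
    except ValueError as exc:
        print(f"schema mismatch: {exc}")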

Contributor:
I think the only time we are casting the types is on write, where we may want to downcast it for forward compatibility.

+1 Currently, we use "large_*" types during write. I think it would be better if we could write the file based on the input PyArrow dataframe schema: if the dataframe uses string, we also write string.
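
In that spirit, the write path in this PR derives the Parquet schema from the in-memory data instead of forcing large types; a simplified sketch (write_batches and path are illustrative names):

import pyarrow as pa
import pyarrow.parquet as pq

def write_batches(path: str, batches: list[pa.RecordBatch]) -> None:
    # Let the in-memory data dictate the on-disk schema: if the batches carry
    # plain string columns we write string; if they carry large_string we
    # write large_string. Assumes at least one batch.
    arrow_table = pa.Table.from_batches(batches)
    with pq.ParquetWriter(path, schema=arrow_table.schema) as writer:
        writer.write(arrow_table)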

return pa.RecordBatchReader.from_batches(
schema_to_pyarrow(self.projection()),
target_schema,
project_batches(
self.plan_files(),
self.table_metadata,
80 changes: 72 additions & 8 deletions tests/integration/test_add_files.py
@@ -18,7 +18,7 @@

import os
from datetime import date
from typing import Iterator, Optional
from typing import Iterator

import pyarrow as pa
import pyarrow.parquet as pq
@@ -28,7 +28,8 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.io import FileIO
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform
@@ -107,23 +108,32 @@
)


def _write_parquet(io: FileIO, file_path: str, arrow_schema: pa.Schema, arrow_table: pa.Table) -> None:
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_schema) as writer:
writer.write_table(arrow_table)


def _create_table(
session_catalog: Catalog, identifier: str, format_version: int, partition_spec: Optional[PartitionSpec] = None
session_catalog: Catalog,
identifier: str,
format_version: int,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
schema: Schema = TABLE_SCHEMA,
) -> Table:
try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = session_catalog.create_table(
return session_catalog.create_table(
identifier=identifier,
schema=TABLE_SCHEMA,
schema=schema,
properties={"format-version": str(format_version)},
partition_spec=partition_spec if partition_spec else PartitionSpec(),
partition_spec=partition_spec,
)

return tbl


@pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
@@ -454,6 +464,60 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat


@pytest.mark.integration
def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.unpartitioned_with_large_types{format_version}"

iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=True))
arrow_schema = pa.schema([
pa.field("foo", pa.string(), nullable=False),
])
arrow_schema_large = pa.schema([
pa.field("foo", pa.large_string(), nullable=False),
])

tbl = _create_table(session_catalog, identifier, format_version, schema=iceberg_schema)

file_path = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-0.parquet"
_write_parquet(
tbl.io,
file_path,
arrow_schema,
pa.Table.from_pylist(
[
{
"foo": "normal",
}
],
schema=arrow_schema,
),
)

tbl.add_files([file_path])

table_schema = tbl.scan().to_arrow().schema
assert table_schema == arrow_schema_large

file_path_large = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-1.parquet"
_write_parquet(
tbl.io,
file_path_large,
arrow_schema_large,
pa.Table.from_pylist(
[
{
"foo": "normal",
}
],
schema=arrow_schema_large,
),
)

tbl.add_files([file_path_large])

table_schema = tbl.scan().to_arrow().schema
assert table_schema == arrow_schema_large


def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

2 changes: 1 addition & 1 deletion tests/integration/test_deletes.py
@@ -291,7 +291,7 @@ def test_partitioned_table_positional_deletes_sequence_number(spark: SparkSessio
assert snapshots[2].summary == Summary(
Operation.OVERWRITE,
**{
"added-files-size": "1145",
"added-files-size": snapshots[2].summary["total-files-size"],
"added-data-files": "1",
"added-records": "2",
"changed-partition-count": "1",
9 changes: 6 additions & 3 deletions tests/integration/test_inspect_table.py
Expand Up @@ -110,22 +110,25 @@ def test_inspect_snapshots(
for manifest_list in df["manifest_list"]:
assert manifest_list.as_py().startswith("s3://")

file_size = int(next(value for key, value in df["summary"][0].as_py() if key == "added-files-size"))
assert file_size > 0

# Append
assert df["summary"][0].as_py() == [
("added-files-size", "5459"),
("added-files-size", str(file_size)),
("added-data-files", "1"),
("added-records", "3"),
("total-data-files", "1"),
("total-delete-files", "0"),
("total-records", "3"),
("total-files-size", "5459"),
("total-files-size", str(file_size)),
("total-position-deletes", "0"),
("total-equality-deletes", "0"),
]

# Delete
assert df["summary"][1].as_py() == [
("removed-files-size", "5459"),
("removed-files-size", str(file_size)),
("deleted-data-files", "1"),
("deleted-records", "3"),
("total-data-files", "0"),
12 changes: 8 additions & 4 deletions tests/integration/test_writes/test_partitioned_writes.py
Expand Up @@ -252,28 +252,32 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro
assert operations == ["append", "append"]

summaries = [row.summary for row in rows]

file_size = int(summaries[0]["added-files-size"])
assert file_size > 0

assert summaries[0] == {
"changed-partition-count": "3",
"added-data-files": "3",
"added-files-size": "15029",
"added-files-size": str(file_size),
"added-records": "3",
"total-data-files": "3",
"total-delete-files": "0",
"total-equality-deletes": "0",
"total-files-size": "15029",
"total-files-size": str(file_size),
"total-position-deletes": "0",
"total-records": "3",
}

assert summaries[1] == {
"changed-partition-count": "3",
"added-data-files": "3",
"added-files-size": "15029",
"added-files-size": str(file_size),
"added-records": "3",
"total-data-files": "6",
"total-delete-files": "0",
"total-equality-deletes": "0",
"total-files-size": "30058",
"total-files-size": str(file_size * 2),
"total-position-deletes": "0",
"total-records": "6",
}