PyArrow: Don't enforce the schema #902

Merged · 14 commits · Jul 11, 2024
27 changes: 6 additions & 21 deletions pyiceberg/io/pyarrow.py
@@ -1025,9 +1025,6 @@ def _task_to_record_batches(

    fragment_scanner = ds.Scanner.from_fragment(
        fragment=fragment,
-        # We always use large types in memory as it uses larger offsets
-        # That can chunk more row values into the buffers
-        schema=_pyarrow_schema_ensure_large_types(physical_schema),
        # This will push down the query to Arrow.
        # But in case there are positional deletes, we have to apply them first
        filter=pyarrow_filter if not positional_deletes else None,
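For context on the comment being removed above: Arrow's `string` type stores 32-bit offsets, capping a single array at roughly 2 GiB of character data before it must chunk, while `large_string` stores 64-bit offsets. A minimal pyarrow sketch of the distinction (illustration only, not PyIceberg API):

```python
import pyarrow as pa

# string -> int32 offsets (~2 GiB limit per array before chunking);
# large_string -> int64 offsets, so far more values fit in one buffer.
small = pa.array(["a", "b"], type=pa.string())
large = small.cast(pa.large_string())
print(small.type, large.type)  # string large_string
```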
@@ -1066,7 +1063,7 @@ def _task_to_table(
    batches = _task_to_record_batches(
        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
    )
-    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
+    return pa.Table.from_batches(batches)


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -1170,7 +1167,7 @@ def project_table(
    if len(tables) < 1:
        return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))

-    result = pa.concat_tables(tables)
+    result = pa.concat_tables(tables, promote_options="permissive")
    if limit is not None:
        return result.slice(0, limit)
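The `promote_options="permissive"` argument is what lets tables whose files carried different physical string widths concatenate cleanly. A minimal sketch of the behavior, assuming pyarrow >= 14 (where `promote_options` was introduced):

```python
import pyarrow as pa

small = pa.table({"foo": pa.array(["a"], type=pa.string())})
large = pa.table({"foo": pa.array(["b"], type=pa.large_string())})

# Permissive promotion widens string to large_string instead of
# raising a schema-mismatch error on concatenation.
combined = pa.concat_tables([small, large], promote_options="permissive")
print(combined.schema)  # foo: large_string
```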
@@ -1249,15 +1246,9 @@ def project_batches(


def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
    # We could re-use some of these visitors
    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

-    arrays = []
-    fields = []
-    for pos, field in enumerate(requested_schema.fields):
-        array = struct_array.field(pos)
-        arrays.append(array)
-        fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_struct_array(struct_array)
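`pa.RecordBatch.from_struct_array` replaces the hand-rolled column loop: each field of the struct array becomes a batch column, keeping whatever physical types the projection produced. A quick pyarrow illustration:

```python
import pyarrow as pa

struct = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}])
batch = pa.RecordBatch.from_struct_array(struct)
print(batch.schema)  # x: int64, y: string
```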


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
@@ -1268,14 +1259,8 @@ def __init__(self, file_schema: Schema):

    def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
        file_field = self.file_schema.find_field(field.field_id)
-        if field.field_type.is_primitive:
-            if field.field_type != file_field.field_type:
-                return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
-            elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
-                # if file_field and field_type (e.g. String) are the same
-                # but the pyarrow type of the array is different from the expected type
-                # (e.g. string vs larger_string), we want to cast the array to the larger type
-                return values.cast(target_type)
+        if field.field_type.is_primitive and field.field_type != file_field.field_type:
+            return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
        return values
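After this change, `_cast_if_needed` casts only when the Iceberg schema itself was promoted (e.g. `int` to `long`) and no longer forces string or binary arrays to their large variants. A rough sketch of the remaining promotion path, using plain pyarrow types to stand in for what `promote` plus `schema_to_pyarrow` resolve to (the values here are hypothetical, for illustration only):

```python
import pyarrow as pa

# A file written with int32 whose table column was later promoted to long:
values = pa.array([1, 2, 3], type=pa.int32())
promoted = values.cast(pa.int64())  # the cast _cast_if_needed performs
print(promoted.type)  # int64
```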

    def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
81 changes: 73 additions & 8 deletions tests/integration/test_add_files.py
@@ -17,7 +17,7 @@
# pylint:disable=redefined-outer-name

from datetime import date
-from typing import Iterator, Optional
+from typing import Iterator

import pyarrow as pa
import pyarrow.parquet as pq
@@ -26,7 +26,8 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
-from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.io import FileIO
+from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform
@@ -104,23 +105,32 @@
)


+def _write_parquet(io: FileIO, file_path: str, arrow_schema: pa.Schema, arrow_table: pa.Table) -> None:
+    fo = io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=arrow_schema) as writer:
+            writer.write_table(arrow_table)


def _create_table(
-    session_catalog: Catalog, identifier: str, format_version: int, partition_spec: Optional[PartitionSpec] = None
+    session_catalog: Catalog,
+    identifier: str,
+    format_version: int,
+    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+    schema: Schema = TABLE_SCHEMA,
) -> Table:
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

-    tbl = session_catalog.create_table(
+    return session_catalog.create_table(
        identifier=identifier,
-        schema=TABLE_SCHEMA,
+        schema=schema,
        properties={"format-version": str(format_version)},
-        partition_spec=partition_spec if partition_spec else PartitionSpec(),
+        partition_spec=partition_spec,
    )

-    return tbl


@pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
@@ -448,3 +458,58 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat

assert "snapshot_prop_a" in summary
assert summary["snapshot_prop_a"] == "test_prop_a"


+@pytest.mark.integration
+def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.unpartitioned_with_large_types{format_version}"
+
+    iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=True))
+    arrow_schema = pa.schema([
+        pa.field("foo", pa.string(), nullable=False),
+    ])
+    arrow_schema_large = pa.schema([
+        pa.field("foo", pa.large_string(), nullable=False),
+    ])
+
+    tbl = _create_table(session_catalog, identifier, format_version, schema=iceberg_schema)
+
+    file_path = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-0.parquet"
+    _write_parquet(
+        tbl.io,
+        file_path,
+        arrow_schema,
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "normal",
+                }
+            ],
+            schema=arrow_schema,
+        ),
+    )
+
+    tbl.add_files([file_path])
+
+    table_schema = tbl.scan().to_arrow().schema
+    assert table_schema == arrow_schema
+
+    file_path_large = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-1.parquet"
+    _write_parquet(
+        tbl.io,
+        file_path_large,
+        arrow_schema_large,
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "normal",
+                }
+            ],
+            schema=arrow_schema_large,
+        ),
+    )
+
+    tbl.add_files([file_path_large])
+
+    table_schema = tbl.scan().to_arrow().schema
+    assert table_schema == arrow_schema_large
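The last assertion holds because the scan now reads each file with its native types and unifies them permissively when concatenating, so once a `large_string` file is added the combined schema widens to `large_string`. Assuming pyarrow >= 14, the unification step behaves like:

```python
import pyarrow as pa

unified = pa.unify_schemas(
    [pa.schema([("foo", pa.string())]), pa.schema([("foo", pa.large_string())])],
    promote_options="permissive",
)
print(unified)  # foo: large_string
```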
14 changes: 7 additions & 7 deletions tests/io/test_pyarrow.py
@@ -1002,10 +1002,10 @@ def test_read_map(schema_map: Schema, file_map: str) -> None:

    assert (
        repr(result_table.schema)
-        == """properties: map<large_string, large_string>
-  child 0, entries: struct<key: large_string not null, value: large_string not null> not null
-      child 0, key: large_string not null
-      child 1, value: large_string not null"""
+        == """properties: map<string, string>
+  child 0, entries: struct<key: string not null, value: string not null> not null
+      child 0, key: string not null
+      child 1, value: string not null"""
    )


@@ -1279,9 +1279,9 @@ def test_projection_maps_of_structs(schema_map_of_structs: Schema, file_map_of_s
    assert actual.as_py() == expected
    assert (
        repr(result_table.schema)
-        == """locations: map<large_string, struct<latitude: double not null, longitude: double not null, altitude: double>>
-  child 0, entries: struct<key: large_string not null, value: struct<latitude: double not null, longitude: double not null, altitude: double> not null> not null
-      child 0, key: large_string not null
+        == """locations: map<string, struct<latitude: double not null, longitude: double not null, altitude: double>>
+  child 0, entries: struct<key: string not null, value: struct<latitude: double not null, longitude: double not null, altitude: double> not null> not null
+      child 0, key: string not null
       child 1, value: struct<latitude: double not null, longitude: double not null, altitude: double> not null
           child 0, latitude: double not null
           child 1, longitude: double not null