Allow writing pa.Table that are either a subset of table schema or in arbitrary order, and support type promotion on write #921

Merged · 13 commits · Jul 17, 2024
99 changes: 65 additions & 34 deletions pyiceberg/io/pyarrow.py
@@ -120,6 +120,7 @@
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
assign_fresh_schema_ids,
pre_order_visit,
promote,
prune_columns,
@@ -1450,14 +1451,17 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st
except ValueError:
return None

if isinstance(partner_struct, pa.StructArray):
return partner_struct.field(name)
elif isinstance(partner_struct, pa.Table):
return partner_struct.column(name).combine_chunks()
elif isinstance(partner_struct, pa.RecordBatch):
return partner_struct.column(name)
else:
raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}")
try:
if isinstance(partner_struct, pa.StructArray):
return partner_struct.field(name)
elif isinstance(partner_struct, pa.Table):
return partner_struct.column(name).combine_chunks()
elif isinstance(partner_struct, pa.RecordBatch):
return partner_struct.column(name)
else:
raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}")
except KeyError:
@sungwy (Collaborator, Author) · Jul 12, 2024

This change is necessary to support writing dataframes / record batches with a subset of the schema. Otherwise, the ArrowAccessor throws a KeyError. This way, we return None instead, and the ArrowProjectionVisitor is responsible for checking whether the field is nullable so that it can be filled in with a null array.

Contributor

Is this change responsible for schema projection / writing a subset of the schema? Do you mind expanding on the mechanism behind how this works? I'm curious

Collaborator Author

Yes, that's right - the ArrowProjectionVisitor is responsible for detecting that the field_partner is None and then checking whether the table field is also optional before filling it in with a null array. This change is necessary so that the ArrowAccessor doesn't throw an exception if the field can't be found in the Arrow component, and it enables the ArrowProjectionVisitor to take a code path it couldn't reach before:

if field_array is not None:
    array = self._cast_if_needed(field, field_array)
    field_arrays.append(array)
    fields.append(self._construct_field(field, array.type))
elif field.optional:
    arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False)
    field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
    fields.append(self._construct_field(field, arrow_type))
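For illustration, a minimal sketch of the null-fill path above (not the visitor itself; the field type and row count are assumed):

import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.types import StringType

# Assumed example: an optional string field that is absent from the written
# dataframe gets materialised as a typed null array with one slot per row.
num_rows = 3
arrow_type = schema_to_pyarrow(StringType(), include_field_ids=False)
filler = pa.nulls(num_rows, type=arrow_type)
assert filler.null_count == num_rows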

Contributor

Above we have the file_schema that should correspond with the partner_struct. I expect that when looking up the field-id, it should already return None.

Collaborator Author

Yeah, as I pointed out in this comment: #921 (comment), I think write_parquet is using the table schema instead of the Schema corresponding to the data types of the PyArrow construct.

I'll take that to mean this isn't intended, and that making sure we use the Schema corresponding to the data types of the PyArrow construct is what we want to do here.

Contributor

Thanks for the context. This isn't intended; the schema should align with the data. I checked against the last commit, and it doesn't throw the KeyError anymore because of your fix. Thanks 👍

Collaborator Author

Thank you for the suggestion - I've removed this try/except block in the latest update.

return None

return None

@@ -2079,36 +2083,63 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down
Raises:
ValueError: If the schemas are not compatible.
"""
name_mapping = table_schema.name_mapping
try:
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
from rich.console import Console
from rich.table import Table as RichTable
task_schema = assign_fresh_schema_ids(
Contributor

Nit: this naming is still from when we only used it on the read path; we should probably make it more generic. Maybe provided_schema and requested_schema? Open to suggestions!

Contributor

I'm debating whether this API is the most extensible here. I think we should reuse _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) instead of reimplementing the logic in Arrow here. This nicely splits pyarrow_to_schema from _check_schema_compatible.

_pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
)

console = Console(record=True)
extra_fields = task_schema.field_names - table_schema.field_names
missing_fields = table_schema.field_names - task_schema.field_names
fields_in_both = task_schema.field_names.intersection(table_schema.field_names)

from rich.console import Console
from rich.table import Table as RichTable

console = Console(record=True)

rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("Field Name")
rich_table.add_column("Category")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")

def print_nullability(required: bool) -> str:
return "required" if required else "optional"

for field_name in fields_in_both:
Contributor

Just want to check my understanding: this works for nested fields because nested fields are "flattened" by .field_names and then fetched by .find_field.

For example: a df schema like

task_schema = pa.field(
    "person",
    pa.struct([
        pa.field("name", pa.string(), nullable=True),
    ]),
    nullable=True,
)

task_schema.field_names will produce {"person", "person.name"}.
task_schema.find_field("person") and task_schema.find_field("person.name") will fetch the corresponding fields

Collaborator Author

That's consistent with my understanding
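
To make the flattening concrete, a small sketch against the Iceberg Schema API, using a made-up nested schema that mirrors the "person" example:

from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, StructType

# Hypothetical nested schema mirroring the "person" example above.
schema = Schema(
    NestedField(
        field_id=1,
        name="person",
        field_type=StructType(
            NestedField(field_id=2, name="name", field_type=StringType(), required=False),
        ),
        required=False,
    ),
    schema_id=1,
)

# Nested names are flattened into dotted paths...
assert schema.field_names == {"person", "person.name"}
# ...and find_field resolves both the struct and its leaf by path.
assert isinstance(schema.find_field("person").field_type, StructType)
assert schema.find_field("person.name").field_type == StringType()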

Contributor

This is a good point; this could be solved using the SchemaWithPartnerVisitor.

lhs = table_schema.find_field(field_name)
rhs = task_schema.find_field(field_name)
# Check nullability
if lhs.required != rhs.required:
rich_table.add_row(
field_name,
"Nullability",
f"{print_nullability(lhs.required)} {str(lhs.field_type)}",
f"{print_nullability(rhs.required)} {str(rhs.field_type)}",
)
# Check if type is consistent
if any(
(isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
for container_type in {StructType, MapType, ListType}
):
continue
elif lhs.field_type != rhs.field_type:
rich_table.add_row(
field_name,
"Type",
f"{print_nullability(lhs.required)} {str(lhs.field_type)}",
f"{print_nullability(rhs.required)} {str(rhs.field_type)}",
)
@HonahX (Contributor) · Jul 12, 2024

I'm wondering whether we can be less restrictive on type. If the rhs's type can be promoted to the lhs's type, the case may still be considered compatible:

elif lhs.field_type != rhs.field_type:
    try:
        promote(rhs.field_type, lhs.field_type)
    except ResolveError:
        rich_table.add_row(
            field_name,
            "Type",
            f"{print_nullability(lhs.required)} {str(lhs.field_type)}",
            f"{print_nullability(rhs.required)} {str(rhs.field_type)}",
        )
Example test case:
def test_schema_uuid() -> None:
    table_schema = Schema(
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=UUIDType(), required=False),
        schema_id=1,
        identifier_field_ids=[2],
    )
    other_schema = pa.schema((
        pa.field("foo", pa.large_string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
        pa.field("baz", pa.binary(16), nullable=True),
    ))

    _check_schema_compatible(table_schema, other_schema)

    other_schema_fail = pa.schema((
        pa.field("foo", pa.large_string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
        pa.field("baz", pa.binary(15), nullable=True),
    ))

    with pytest.raises(ValueError):
        _check_schema_compatible(table_schema, other_schema_fail)

This could be a possible solution for #855, and should also cover writing pa.int32() (IntegerType) to LongType
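
For reference, a quick sketch of the promotion check this suggestion leans on, with illustrative types (promote raises ResolveError when no promotion rule applies):

from pyiceberg.exceptions import ResolveError
from pyiceberg.schema import promote
from pyiceberg.types import IntegerType, LongType, StringType

# int can be widened to long, so the write would be considered compatible.
assert promote(IntegerType(), LongType()) == LongType()

# string has no promotion rule to long, so the check would flag a mismatch.
try:
    promote(StringType(), LongType())
except ResolveError:
    print("string cannot be promoted to long")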

Collaborator Author

Hi @HonahX, I think that's a great suggestion! Thank you for pointing that out. I think it'll actually be a very simple change that addresses my question above:

Question: is it correct to compare both as Iceberg Schema? Or do we want to allow a more permissive range of pyarrow types to be supported for writes?

For example, do we want to support writing pa.int32() into LongType()? Maybe we could support this in a subsequent PR?

@sungwy (Collaborator, Author) · Jul 13, 2024

Hi @HonahX - I tried this out, and I think we may benefit from scoping this out of this PR and investing some more time to figure out the correct way to support type promotions on write. The exception I'm getting is as follows:

tests/integration/test_writes/utils.py:79: in _create_table
    tbl.append(d)
pyiceberg/table/__init__.py:1557: in append
    tx.append(df=df, snapshot_properties=snapshot_properties)
pyiceberg/table/__init__.py:503: in append
    for data_file in data_files:
pyiceberg/io/pyarrow.py:2252: in _dataframe_to_data_files
    yield from write_file(
/usr/local/python/3.10.13/lib/python3.10/concurrent/futures/_base.py:621: in result_iterator
    yield _result_or_cancel(fs.pop())
/usr/local/python/3.10.13/lib/python3.10/concurrent/futures/_base.py:319: in _result_or_cancel
    return fut.result(timeout)
/usr/local/python/3.10.13/lib/python3.10/concurrent/futures/_base.py:458: in result
    return self.__get_result()
/usr/local/python/3.10.13/lib/python3.10/concurrent/futures/_base.py:403: in __get_result
    raise self._exception
/usr/local/python/3.10.13/lib/python3.10/concurrent/futures/thread.py:58: in run
    result = self.fn(*self.args, **self.kwargs)
pyiceberg/io/pyarrow.py:2030: in write_parquet
    statistics = data_file_statistics_from_parquet_metadata(
pyiceberg/io/pyarrow.py:1963: in data_file_statistics_from_parquet_metadata
    col_aggs[field_id] = StatsAggregator(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <pyiceberg.io.pyarrow.StatsAggregator object at 0x787561b422c0>, iceberg_type = LongType(), physical_type_string = 'INT32', trunc_length = None

    def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: Optional[int] = None) -> None:
        self.current_min = None
        self.current_max = None
        self.trunc_length = trunc_length
    
        expected_physical_type = _primitive_to_physical(iceberg_type)
        if expected_physical_type != physical_type_string:
>           raise ValueError(
                f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
            )
E           ValueError: Unexpected physical type INT32 for long, expected INT64

pyiceberg/io/pyarrow.py:1556: ValueError

And this is because the file_schema that's passed to _to_requested_schema in the write_parquet function is just the Iceberg table schema, rather than a Schema representation of the PyArrow table's own data types. So when the types of the file_schema and the requested_schema are compared, we end up comparing the Iceberg table type (e.g. LongType) against itself instead of against the smaller PyArrow type in the dataframe (e.g. IntegerType).

I think this is going to take a bit of work, because we need a schema that actually represents the data types within the Arrow dataframe, and we also have to create a Schema representation of the PyArrow schema whose field_ids are consistent with the Iceberg table schema, since the ArrowProjectionVisitor uses field_ids for lookups against the file_schema.

I'd like to continue this discussion out of scope of this release, but I think we will have to decide on one of the following two approaches:

  1. Write with the compatible smaller parquet physical types (e.g. INT32 for a LongType) and fix the StatsAggregator to handle different physical types.
  2. Update the file_schema input to _to_requested_schema in write_parquet so that we upcast the Arrow data type and write the larger expected physical types into the parquet file.

Long story short, our underlying plumbing doesn't yet support promotion on write, and there's still some work left for us to get there.
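
As a rough sketch of option 2 (upcast_to_table_types is a hypothetical helper, not part of the PR): cast the Arrow columns to the table's expected Arrow types before handing the data to the Parquet writer, so the physical types line up with what the StatsAggregator expects:

import pyarrow as pa

def upcast_to_table_types(table: pa.Table, target_schema: pa.Schema) -> pa.Table:
    """Cast each column to the type the Iceberg table schema expects (assumed approach)."""
    columns = []
    for target_field in target_schema:
        column = table.column(target_field.name)
        if column.type != target_field.type:
            # e.g. int32 -> int64 so the Parquet physical type becomes INT64
            column = column.cast(target_field.type)
        columns.append(column)
    return pa.Table.from_arrays(columns, schema=target_schema)

# Usage sketch: an int32 column widened to the table's int64 before writing.
df = pa.table({"bar": pa.array([1, 2, 3], type=pa.int32())})
target = pa.schema([pa.field("bar", pa.int64(), nullable=False)])
print(upcast_to_table_types(df, target).schema)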

Contributor

I think it would be good to get this to work as well. It should be pretty easy by just first upcasting the buffers before writing.


rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")
for field_name in extra_fields:
rhs = task_schema.find_field(field_name)
rich_table.add_row(field_name, "Extra Fields", "", f"{print_nullability(rhs.required)} {str(rhs.field_type)}")

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")
for field_name in missing_fields:
lhs = table_schema.find_field(field_name)
if lhs.required:
rich_table.add_row(field_name, "Missing Fields", f"{print_nullability(lhs.required)} {str(lhs.field_type)}", "")

if rich_table.row_count:
console.print(rich_table)
raise ValueError(f"Mismatch in fields:\n{console.export_text()}")

5 changes: 5 additions & 0 deletions pyiceberg/schema.py
@@ -324,6 +324,11 @@ def field_ids(self) -> Set[int]:
"""Return the IDs of the current schema."""
return set(self._name_to_id.values())

@property
def field_names(self) -> Set[str]:
"""Return the Names of the current schema."""
return set(self._name_to_id.keys())

def _validate_identifier_field(self, field_id: int) -> None:
"""Validate that the field with the given ID is a valid identifier field.

7 changes: 6 additions & 1 deletion pyiceberg/table/__init__.py
@@ -73,7 +73,6 @@
manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
@@ -471,6 +470,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

@@ -528,6 +529,8 @@ def overwrite(
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

@@ -566,6 +569,8 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table

if (
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
== TableProperties.DELETE_MODE_MERGE_ON_READ
13 changes: 5 additions & 8 deletions tests/integration/test_add_files.py
@@ -501,14 +501,11 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog
)

expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
| ✅ │ 2: bar: optional string │ 2: bar: optional string │
│ ❌ │ 3: baz: optional int │ 3: baz: optional string │
│ ✅ │ 4: qux: optional date │ 4: qux: optional date │
└────┴──────────────────────────┴──────────────────────────┘
┏━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
Contributor

Is it just me, or is the left easier to read? 😅

Collaborator Author

I opted for this approach because I wanted to also group the extra fields in the dataframe into the table. But if we take the approach of using the name_mapping to generate the Iceberg Schema with consistent IDs after first checking that there are no extra fields, I think we can go back to the old way.

Contributor

I like the new way since it tells me exactly which field to focus on and the reason it's not compatible.

┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ baz │ Type │ optional int │ optional string │
└────────────┴──────────┴──────────────┴─────────────────┘
"""

with pytest.raises(ValueError, match=expected):
24 changes: 22 additions & 2 deletions tests/integration/test_writes/test_writes.py
@@ -964,18 +964,38 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
assert len(tbl.scan().to_arrow()) == 22


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
identifier = "default.table_append_subset_of_schema"
def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
identifier = "default.test_table_write_subset_of_schema"
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null])
arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0])
print(arrow_table_without_some_columns.schema)
print(arrow_table_with_null.schema)
assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns)
tbl.overwrite(arrow_table_without_some_columns)
tbl.append(arrow_table_without_some_columns)
# overwrite and then append should produce twice the data
assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
Collaborator Author

Thank you @kevinjqliu for writing up these tests! Ported them over from your PR

identifier = "default.test_table_write_out_of_order_schema"
# rotate the schema fields by 1
fields = list(arrow_table_with_null.schema)
rotated_fields = fields[1:] + fields[:1]
rotated_schema = pa.schema(rotated_fields)
assert arrow_table_with_null.schema != rotated_schema
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema)

tbl.overwrite(arrow_table_with_null)
tbl.append(arrow_table_with_null)
# overwrite and then append should produce twice the data
assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_write_all_timestamp_precision(