Allow writing dataframes that are either a subset of table schema or in arbitrary order #829

Closed
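For context, a minimal sketch of the write pattern this PR is meant to allow (the catalog name, table identifier, and columns below are hypothetical, not taken from the PR): a dataframe that omits an optional column, or lists its columns in a different order than the table schema, should be accepted by `append` and `overwrite`.

    import pyarrow as pa
    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")            # assumes a configured catalog named "default"
    tbl = catalog.load_table("default.events")   # hypothetical table: id (long), name (string), ts (timestamp), all optional

    # Subset of the table schema: the `ts` column is omitted entirely.
    df_subset = pa.table({
        "id": pa.array([1, 2], pa.int64()),
        "name": pa.array(["a", "b"], pa.string()),
    })
    tbl.append(df_subset)

    # Same columns, but in a different order than the table schema.
    df_reordered = pa.table({
        "name": pa.array(["c"], pa.string()),
        "id": pa.array([3], pa.int64()),
    })
    tbl.append(df_reordered)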
pyiceberg/io/pyarrow.py: 22 changes (15 additions, 7 deletions)
@@ -2034,16 +2034,18 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[

def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.
Check if the `table_schema` is compatible with `other_schema` in terms of the Iceberg Schema representation.

Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
The schemas are compatible if:
- All fields in `other_schema` are present in `table_schema`. (other_schema <= table_schema)
- All required fields in `table_schema` are present in `other_schema`.

Raises:
ValueError: If the schemas are not compatible.
"""
name_mapping = table_schema.name_mapping
try:
task_schema = pyarrow_to_schema(
other_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
@@ -2053,7 +2055,10 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
fields_missing_from_table = {field for field in other_schema.fields if field not in table_schema.fields}
Review comment from the PR author: this doesn't work for nested structs, need a better solution.
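To make the limitation in the comment above concrete, a rough sketch (field names and IDs invented for illustration): when a struct column arrives with only a subset of its nested fields, the converted struct field no longer equals the table's struct field as a whole, so the flat set comparison would treat it as a field missing from the table and raise, even though only an optional nested field was dropped.

    import pyarrow as pa
    from pyiceberg.schema import Schema
    from pyiceberg.types import NestedField, StringType, StructType

    table_schema = Schema(
        NestedField(1, "person", StructType(
            NestedField(2, "name", StringType(), required=False),
            NestedField(3, "email", StringType(), required=False),
        ), required=False),
    )

    # Arrow schema that prunes the nested `email` field.
    arrow_schema = pa.schema([
        pa.field("person", pa.struct([pa.field("name", pa.string(), nullable=True)]), nullable=True),
    ])

    # After pyarrow_to_schema, the pruned `person` struct compares unequal to the
    # table's `person` field, so it would land in `fields_missing_from_table` and
    # the check would reject what is arguably a legal projection.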

required_fields_in_table = {field for field in table_schema.fields if field.required}
missing_required_fields_in_other = {field for field in required_fields_in_table if field not in other_schema.fields}
if fields_missing_from_table or missing_required_fields_in_other:
from rich.console import Console
from rich.table import Table as RichTable

@@ -2066,7 +2071,7 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rhs = other_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")
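A quick sketch of the two docstring rules against the set-based check above, using a made-up two-column schema (this mirrors the new tests further down): dropping an optional column passes, while dropping a required column raises.

    import pyarrow as pa
    from pyiceberg.io.pyarrow import _check_schema_compatible
    from pyiceberg.schema import Schema
    from pyiceberg.types import IntegerType, NestedField, StringType

    table_schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "note", StringType(), required=False),
    )

    # Subset that keeps the required `id` field: compatible, no exception expected.
    _check_schema_compatible(table_schema, pa.schema([pa.field("id", pa.int32(), nullable=False)]))

    # Subset that drops the required `id` field: expected to raise ValueError.
    try:
        _check_schema_compatible(table_schema, pa.schema([pa.field("note", pa.string(), nullable=True)]))
    except ValueError:
        pass  # required field missing from the dataframe schema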
@@ -2177,17 +2182,20 @@ def _dataframe_to_data_files(
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)

# projects schema to match the pyarrow table
write_schema = pyarrow_to_schema(df.schema, name_mapping=table_metadata.schema().name_mapping)

if table_metadata.spec().is_unpartitioned():
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=write_schema)
for batches in bin_pack_arrow_table(df, target_file_size)
]),
)
else:
partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
partitions = _determine_partitions(spec=table_metadata.spec(), schema=write_schema, arrow_table=df)
yield from write_file(
io=io,
table_metadata=table_metadata,
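The `write_schema` introduced above reuses `pyarrow_to_schema` with the table's name mapping, so the schema handed to `WriteTask` and `_determine_partitions` follows the dataframe's column order rather than the table's. A rough sketch with an invented two-column table:

    import pyarrow as pa
    from pyiceberg.io.pyarrow import pyarrow_to_schema
    from pyiceberg.schema import Schema
    from pyiceberg.types import IntegerType, NestedField, StringType

    table_schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(2, "name", StringType(), required=False),
    )

    # Dataframe columns in the opposite order of the table schema.
    df = pa.table({
        "name": pa.array(["a"], pa.string()),
        "id": pa.array([1], pa.int32()),
    })

    write_schema = pyarrow_to_schema(df.schema, name_mapping=table_schema.name_mapping)
    # Field IDs are still resolved from the table via the name mapping, but the
    # field order now matches the dataframe: name (field 2) first, then id (field 1).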
pyiceberg/table/__init__.py: 8 changes (0 additions, 8 deletions)
@@ -484,10 +484,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
Review comment from the PR author: @syun64 I want to get your take on this part. Due to the timestamp change, do you know if the df needs to be cast? There are a couple of different parts involved in the write path: in particular, we need to look at the table schema, the df schema, and the df itself, as well as deal with bin-packing and other transformations.

Review comment from the PR author: Happy to extract this conversation into an issue, and to continue the discussion from #786 (comment).

Review comment from a collaborator, replying to the question above: I have a PR open to try to fix this behavior: #910. I think it's almost ready to merge 😄

table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)
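For reference on the question in this thread, a small sketch of the cast under discussion (not the PR's resolution, which moved to #910; the column name and values are invented): an Arrow table carrying nanosecond timestamps can be downcast to the microsecond precision Iceberg stores with a plain schema cast, which is roughly what the removed `df.cast(table_arrow_schema)` call relied on.

    import pyarrow as pa

    df = pa.table({"ts": pa.array([1_700_000_000_000_000_000], pa.timestamp("ns"))})
    target = pa.schema([pa.field("ts", pa.timestamp("us"), nullable=True)])

    df_us = df.cast(target)  # ns -> us downcast, analogous to casting against the table's Arrow schema
    assert df_us.schema.field("ts").type == pa.timestamp("us")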

manifest_merge_enabled = PropertyUtil.property_as_bool(
self.table_metadata.properties,
@@ -545,10 +541,6 @@ def overwrite(
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)

tests/integration/test_writes/test_writes.py: 22 changes (20 additions, 2 deletions)
@@ -963,9 +963,10 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
assert len(tbl.scan().to_arrow()) == 22


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
identifier = "default.table_append_subset_of_schema"
def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
identifier = "default.test_table_write_subset_of_schema"
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null])
arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0])
assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns)
@@ -975,6 +976,23 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
identifier = "default.test_table_write_out_of_order_schema"
# rotate the schema fields by 1
fields = list(arrow_table_with_null.schema)
rotated_fields = fields[1:] + fields[:1]
rotated_schema = pa.schema(rotated_fields)
assert arrow_table_with_null.schema != rotated_schema
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema)

tbl.overwrite(arrow_table_with_null)
tbl.append(arrow_table_with_null)
# overwrite and then append should produce twice the data
assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None:
tests/io/test_pyarrow.py: 21 changes (20 additions, 1 deletion)
@@ -1799,6 +1799,25 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_compatible(table_schema_simple: Schema) -> None:
try:
_check_schema_compatible(table_schema_simple, table_schema_simple.as_arrow())
except Exception:
pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")


def test_schema_projection(table_schema_simple: Schema) -> None:
# remove optional `baz` field from `table_schema_simple`
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.int32(), nullable=False),
))
try:
_check_schema_compatible(table_schema_simple, other_schema)
except Exception:
pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")


def test_schema_downcast(table_schema_simple: Schema) -> None:
# large_string type is compatible with string type
other_schema = pa.schema((
@@ -1810,7 +1829,7 @@ def test_schema_downcast(table_schema_simple: Schema) -> None:
try:
_check_schema_compatible(table_schema_simple, other_schema)
except Exception:
pytest.fail("Unexpected Exception raised when calling `_check_schema`")
pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")


def test_partition_for_demo() -> None: