Allow writing pa.Table that are either a subset of table schema or in arbitrary order, and support type promotion on write #921
@@ -120,6 +120,7 @@ | |
Schema, | ||
SchemaVisitorPerPrimitiveType, | ||
SchemaWithPartnerVisitor, | ||
assign_fresh_schema_ids, | ||
pre_order_visit, | ||
promote, | ||
prune_columns, | ||
|
@@ -1450,14 +1451,17 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st | |
except ValueError: | ||
return None | ||
|
||
if isinstance(partner_struct, pa.StructArray): | ||
return partner_struct.field(name) | ||
elif isinstance(partner_struct, pa.Table): | ||
return partner_struct.column(name).combine_chunks() | ||
elif isinstance(partner_struct, pa.RecordBatch): | ||
return partner_struct.column(name) | ||
else: | ||
raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") | ||
try: | ||
if isinstance(partner_struct, pa.StructArray): | ||
return partner_struct.field(name) | ||
elif isinstance(partner_struct, pa.Table): | ||
return partner_struct.column(name).combine_chunks() | ||
elif isinstance(partner_struct, pa.RecordBatch): | ||
return partner_struct.column(name) | ||
else: | ||
raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}") | ||
except KeyError: | ||
return None | ||
|
||
return None | ||
|
||
|
@@ -2079,36 +2083,63 @@ def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, down | |
Raises: | ||
ValueError: If the schemas are not compatible. | ||
""" | ||
name_mapping = table_schema.name_mapping | ||
try: | ||
task_schema = pyarrow_to_schema( | ||
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us | ||
) | ||
except ValueError as e: | ||
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) | ||
additional_names = set(other_schema.column_names) - set(table_schema.column_names) | ||
raise ValueError( | ||
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." | ||
) from e | ||
|
||
if table_schema.as_struct() != task_schema.as_struct(): | ||
from rich.console import Console | ||
from rich.table import Table as RichTable | ||
task_schema = assign_fresh_schema_ids( | ||
Review comment: Nit: This naming is still from when we only used it at the read-path; probably we should make it more generic. Maybe …
Review comment: I'm debating myself if the API is the most extensible here. I think we should re-use …
||
_pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us) | ||
) | ||
|
||
console = Console(record=True) | ||
extra_fields = task_schema.field_names - table_schema.field_names | ||
missing_fields = table_schema.field_names - task_schema.field_names | ||
fields_in_both = task_schema.field_names.intersection(table_schema.field_names) | ||
|
||
from rich.console import Console | ||
from rich.table import Table as RichTable | ||
|
||
console = Console(record=True) | ||
|
||
rich_table = RichTable(show_header=True, header_style="bold") | ||
rich_table.add_column("Field Name") | ||
rich_table.add_column("Category") | ||
rich_table.add_column("Table field") | ||
rich_table.add_column("Dataframe field") | ||
|
||
def print_nullability(required: bool) -> str: | ||
return "required" if required else "optional" | ||
|
||
for field_name in fields_in_both: | ||
Review comment: Just want to check my understanding: this works for nested fields because nested fields are "flattened" by … For example: a …
Review comment: That's consistent with my understanding.
Review comment: This is a good point; this could be solved using the …
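A small illustration of the flattening discussed above, assuming field_names exposes the same dotted paths that find_field accepts (the schema below is hypothetical):

from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, StructType

# Hypothetical schema with a nested struct field.
schema = Schema(
    NestedField(
        field_id=1,
        name="person",
        field_type=StructType(
            NestedField(field_id=2, name="name", field_type=StringType(), required=False),
        ),
        required=False,
    ),
)

# Nested fields are addressable by dotted path, so set comparisons over the
# flattened names (e.g. {"person", "person.name"}) also cover nested fields.
print(schema.find_field("person.name"))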
||
lhs = table_schema.find_field(field_name) | ||
rhs = task_schema.find_field(field_name) | ||
# Check nullability | ||
if lhs.required != rhs.required: | ||
|
||
rich_table.add_row( | ||
field_name, | ||
"Nullability", | ||
f"{print_nullability(lhs.required)} {str(lhs.field_type)}", | ||
f"{print_nullability(rhs.required)} {str(rhs.field_type)}", | ||
) | ||
# Check if type is consistent | ||
if any( | ||
(isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)) | ||
for container_type in {StructType, MapType, ListType} | ||
|
||
): | ||
continue | ||
elif lhs.field_type != rhs.field_type: | ||
rich_table.add_row( | ||
field_name, | ||
"Type", | ||
f"{print_nullability(lhs.required)} {str(lhs.field_type)}", | ||
f"{print_nullability(rhs.required)} {str(rhs.field_type)}", | ||
) | ||
Review comment: I am thinking if we can be less restrictive on type. If the rhs's type can be promoted to lhs's type, the case may still be considered as compatible:

elif lhs.field_type != rhs.field_type:
    try:
        promote(rhs.field_type, lhs.field_type)
    except ResolveError:
        rich_table.add_row(
            field_name,
            "Type",
            f"{print_nullability(lhs.required)} {str(lhs.field_type)}",
            f"{print_nullability(rhs.required)} {str(rhs.field_type)}",
        )

Example test case:

def test_schema_uuid() -> None:
    table_schema = Schema(
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=UUIDType(), required=False),
        schema_id=1,
        identifier_field_ids=[2],
    )
    other_schema = pa.schema((
        pa.field("foo", pa.large_string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
        pa.field("baz", pa.binary(16), nullable=True),
    ))
    _check_schema_compatible(table_schema, other_schema)

    other_schema_fail = pa.schema((
        pa.field("foo", pa.large_string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
        pa.field("baz", pa.binary(15), nullable=True),
    ))
    with pytest.raises(ValueError):
        _check_schema_compatible(table_schema, other_schema_fail)

This could be a possible solution for #855, and should also cover writing pa.int32() (IntegerType) to LongType.
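A quick sketch of the promotion rule this suggestion relies on, assuming pyiceberg's promote / ResolveError semantics; the types are chosen purely for illustration:

from pyiceberg.exceptions import ResolveError
from pyiceberg.schema import promote
from pyiceberg.types import IntegerType, LongType, StringType

# int -> long is an allowed promotion, so such a mismatch would be accepted.
print(promote(IntegerType(), LongType()))

# string -> int is not promotable, so that mismatch row would still be reported.
try:
    promote(StringType(), IntegerType())
except ResolveError as exc:
    print(f"not compatible: {exc}")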
Review comment: Hi @HonahX, I think that's a great suggestion! Thank you for pointing that out. I think it'll actually be a very simple change that addresses my question above: …

Review comment: Hi @HonahX - I tried this out, and I think we may benefit from scoping this out of this PR and investing some more time to figure out the correct way to support type promotions on write. The exception I'm getting is as follows: …

And this is because the … I think this is going to take a bit of work to ensure that we are using the schema that actually represents the data types within the Arrow dataframe, because we also have to create a Schema representation of the PyArrow schema that has field_ids consistent with the Iceberg table schema, because the … I'd like to continue this discussion out of scope of this release, but I think we will have to decide on one of the following two approaches: …

Long story short, our underlying piping currently doesn't yet support …

Review comment: I think it would be good to get this to work as well. It should be pretty easy by just first upcasting the buffers before writing.
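For reference, a minimal sketch of the "upcast the buffers before writing" idea using plain PyArrow casts; the column name and types are made up, and this is not the PR's implementation:

import pyarrow as pa

# Data arrives as int32, but suppose the Iceberg table column is a long (int64).
data = pa.table({"bar": pa.array([1, 2, 3], type=pa.int32())})
target = pa.schema([pa.field("bar", pa.int64())])

# Table.cast rewrites the underlying buffers to the target Arrow types,
# which is essentially upcasting before handing the data to the writer.
upcast = data.cast(target)
print(upcast.schema)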
||
|
||
rich_table = RichTable(show_header=True, header_style="bold") | ||
rich_table.add_column("") | ||
rich_table.add_column("Table field") | ||
rich_table.add_column("Dataframe field") | ||
for field_name in extra_fields: | ||
rhs = task_schema.find_field(field_name) | ||
rich_table.add_row(field_name, "Extra Fields", "", f"{print_nullability(rhs.required)} {str(rhs.field_type)}") | ||
|
||
for lhs in table_schema.fields: | ||
try: | ||
rhs = task_schema.find_field(lhs.field_id) | ||
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) | ||
except ValueError: | ||
rich_table.add_row("❌", str(lhs), "Missing") | ||
for field_name in missing_fields: | ||
lhs = table_schema.find_field(field_name) | ||
if lhs.required: | ||
rich_table.add_row(field_name, "Missing Fields", f"{print_nullability(lhs.required)} {str(lhs.field_type)}", "") | ||
|
||
if rich_table.row_count: | ||
console.print(rich_table) | ||
raise ValueError(f"Mismatch in fields:\n{console.export_text()}") | ||
|
||
|
@@ -501,14 +501,11 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog | |
) | ||
|
||
expected = """Mismatch in fields: | ||
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ | ||
┃ ┃ Table field ┃ Dataframe field ┃ | ||
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩ | ||
│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │ | ||
| ✅ │ 2: bar: optional string │ 2: bar: optional string │ | ||
│ ❌ │ 3: baz: optional int │ 3: baz: optional string │ | ||
│ ✅ │ 4: qux: optional date │ 4: qux: optional date │ | ||
└────┴──────────────────────────┴──────────────────────────┘ | ||
┏━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it just me, or is the left easier to read? 😅 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opted for this approach because I wanted to group the Extra Fields in the dataframe also into the table. But if we are taking the approach of using the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the new way since it tells me exactly which field to focus on and the reason its not compatible |
||
┃ Field Name ┃ Category ┃ Table field ┃ Dataframe field ┃ | ||
┡━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ | ||
│ baz │ Type │ optional int │ optional string │ | ||
└────────────┴──────────┴──────────────┴─────────────────┘ | ||
""" | ||
|
||
with pytest.raises(ValueError, match=expected): | ||
|
@@ -964,18 +964,38 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None: | |
assert len(tbl.scan().to_arrow()) == 22 | ||
|
||
|
||
@pytest.mark.integration | ||
@pytest.mark.parametrize("format_version", [1, 2]) | ||
def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: | ||
identifier = "default.table_append_subset_of_schema" | ||
def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: | ||
identifier = "default.test_table_write_subset_of_schema" | ||
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) | ||
arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0]) | ||
print(arrow_table_without_some_columns.schema) | ||
print(arrow_table_with_null.schema) | ||
|
||
assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns) | ||
tbl.overwrite(arrow_table_without_some_columns) | ||
tbl.append(arrow_table_without_some_columns) | ||
# overwrite and then append should produce twice the data | ||
assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2 | ||
|
||
|
||
@pytest.mark.integration | ||
@pytest.mark.parametrize("format_version", [1, 2]) | ||
def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: | ||
Review comment: Thank you @kevinjqliu for writing up these tests! Ported them over from your PR.
||
identifier = "default.test_table_write_out_of_order_schema" | ||
# rotate the schema fields by 1 | ||
fields = list(arrow_table_with_null.schema) | ||
rotated_fields = fields[1:] + fields[:1] | ||
rotated_schema = pa.schema(rotated_fields) | ||
assert arrow_table_with_null.schema != rotated_schema | ||
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema) | ||
|
||
tbl.overwrite(arrow_table_with_null) | ||
tbl.append(arrow_table_with_null) | ||
# overwrite and then append should produce twice the data | ||
assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2 | ||
|
||
|
||
@pytest.mark.integration | ||
@pytest.mark.parametrize("format_version", [1, 2]) | ||
def test_write_all_timestamp_precision( | ||
|
Review comment: This change is necessary to support writing dataframes / recordbatches with a subset of the schema. Otherwise, the ArrowAccessor throws a KeyError. This way, we return a None and the ArrowProjectionVisitor is responsible for checking if the field is nullable, and can be filled in with a null array.
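To make the described mechanism concrete, here is a rough, hypothetical sketch of the null-backfill behavior; this is not the actual ArrowProjectionVisitor code, and the helper name is made up:

from typing import Optional

import pyarrow as pa

def fill_missing_optional(partner: Optional[pa.Array], target_type: pa.DataType, length: int, required: bool) -> pa.Array:
    # Hypothetical helper: when the partner lookup returned None, an optional
    # field is backfilled with an all-null array of the target type.
    if partner is not None:
        return partner
    if required:
        raise ValueError("required field is missing from the Arrow data")
    return pa.nulls(length, type=target_type)

# The dataframe lacks the optional column entirely, so it becomes all-null.
print(fill_missing_optional(None, pa.int32(), length=3, required=False))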
Review comment: Is this change responsible for schema projection / writing a subset of the schema? Do you mind expanding on the mechanism behind how this works? I'm curious.

Review comment: Yes, that's right - the ArrowProjectionVisitor is responsible for detecting that the field_partner is None and then checking if the table field is also optional before filling it in with a null array. This change is necessary so that the ArrowAccessor doesn't throw an exception if the field can't be found in the arrow component, and enables ArrowProjectionVisitor to make use of a code pathway it wasn't able to make use of before: iceberg-python/pyiceberg/io/pyarrow.py, lines 1388 to 1395 in b11cdb5.

Review comment: Above we have the file_schema that should correspond with the partner_struct. I expect that when looking up the field-id, it should already return None.

Review comment: Yeah, as I pointed out in this comment: #921 (comment), I think write_parquet is using the Table Schema, instead of the Schema corresponding to the data types of the PyArrow construct. I will take that to mean that this isn't intended, and that making sure we use the Schema corresponding to the data types of the PyArrow construct is what we intend to do here.

Review comment: Thanks for the context. This isn't intended, the schema should align with the data. I checked against the last commit, and it doesn't throw the KeyError anymore because of your fix. Thanks 👍

Review comment: Thank you for the suggestion - I've removed this try/except block in the latest update.