Skip to content

Commit

Permalink
fix(python): sort before schema comparison (#2209)
Browse files Browse the repository at this point in the history
# Description
Supersedes PR #1854.
Thanks to @PierreDubrulle for pointing out the issue.

# Related Issue(s)
- closes #1853
  • Loading branch information
ion-elgreco authored Feb 25, 2024
1 parent 7357098 commit 56cfd62
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
10 changes: 7 additions & 3 deletions python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def write_deltalake(
partition_by = [partition_by]

if isinstance(schema, DeltaSchema):
schema = schema.to_pyarrow()
schema = schema.to_pyarrow(as_large_types=True)

if isinstance(data, RecordBatchReader):
data = convert_pyarrow_recordbatchreader(data, large_dtypes)
Expand Down Expand Up @@ -320,9 +320,13 @@ def write_deltalake(
# We need to write against the latest table version
filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))

def sort_arrow_schema(schema: pa.Schema) -> pa.Schema:
    """Return a new schema holding the fields of ``schema`` in sorted order.

    Fields are ordered by name, with the string form of the field type as a
    tie-breaker. Sorting both sides before comparing makes the schema check
    insensitive to column order.

    Only the fields are handed to ``pa.schema``; schema-level metadata is not
    passed on to the returned schema.

    Args:
        schema: The Arrow schema whose fields should be reordered.

    Returns:
        A schema built from the same fields, sorted by ``(name, str(type))``.
    """
    # Iterating a schema yields its fields, so no explicit iter() is needed.
    sorted_cols = sorted(schema, key=lambda field: (field.name, str(field.type)))
    return pa.schema(sorted_cols)

if table: # already exists
if schema != table.schema().to_pyarrow(
as_large_types=large_dtypes
if sort_arrow_schema(schema) != sort_arrow_schema(
table.schema().to_pyarrow(as_large_types=large_dtypes)
) and not (mode == "overwrite" and overwrite_schema):
raise ValueError(
"Schema of data does not match table schema\n"
Expand Down
34 changes: 33 additions & 1 deletion python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def test_update_schema_rust_writer(existing_table: DeltaTable):
overwrite_schema=False,
engine="rust",
)
# TODO(ion): Remove this once we add schema overwrite support
write_deltalake(
existing_table,
new_data,
Expand Down Expand Up @@ -1273,3 +1272,36 @@ def test_write_stats_empty_rowgroups(tmp_path: pathlib.Path):
dt.to_pyarrow_dataset().to_table(filter=(pc.field("data") == "B")).shape[0]
== 33792
)


@pytest.mark.parametrize("engine", ["pyarrow", "rust"])
def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine):
    """Append two batches with the same columns in a different order.

    The second append must be accepted (table reaches version 1) and reading
    the table back in the reordered column order must return all 20 rows.
    """
    original_order = ["foo", "bar", "baz"]
    reordered = ["baz", "bar", "foo"]

    def build_table(num_rows, column_order):
        # Same constant values for every batch; only row count and column
        # order vary between the tables used in this test.
        values = {
            "foo": pa.array(["B"] * num_rows),
            "bar": pa.array([1] * num_rows),
            "baz": pa.array([2.0] * num_rows),
        }
        return pa.table({name: values[name] for name in column_order})

    # First write creates the table, second write appends with shuffled columns.
    for column_order in (original_order, reordered):
        write_deltalake(
            tmp_path,
            build_table(10, column_order),
            mode="append",
            engine=engine,
        )

    delta_table = DeltaTable(tmp_path)
    assert delta_table.version() == 1

    expected = build_table(20, reordered)

    assert delta_table.to_pyarrow_table(columns=["baz", "bar", "foo"]) == expected

0 comments on commit 56cfd62

Please sign in to comment.