feat: extend write_deltalake to accept Deltalake schema (#1922)
# Description
A second attempt to extend `write_deltalake` to accept either a PyArrow or a Deltalake schema (the previous PR was abandoned after some rebase issues).
Added a test.
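
For illustration, a minimal sketch of the new call shape. The table path and sample data below are made up for the example; `Schema.from_pyarrow` is the existing deltalake conversion helper also used in the added test:

```python
import pyarrow as pa

from deltalake import DeltaTable, Schema, write_deltalake

# Hypothetical in-memory data; any PyArrow table works.
data = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

# schema may now be a deltalake Schema as well as a pyarrow.Schema.
write_deltalake("./example_table", data, schema=Schema.from_pyarrow(data.schema))

# Mirrors the added test: the stored schema converts back to the original Arrow schema.
assert DeltaTable("./example_table").schema().to_pyarrow() == data.schema
```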

# Related Issue(s)
closes #1862

---------

Signed-off-by: Nikolay Ulmasov <[email protected]>
r3stl355 authored Nov 29, 2023
1 parent 461efb5 commit ceb8562
Showing 2 changed files with 21 additions and 6 deletions.
17 changes: 12 additions & 5 deletions python/deltalake/writer.py
@@ -21,6 +21,7 @@
 )
 from urllib.parse import unquote
 
+from deltalake import Schema
 from deltalake.fs import DeltaStorageHandler
 
 from ._util import encode_partition_value
@@ -81,7 +82,7 @@ def write_deltalake(
         RecordBatchReader,
     ],
     *,
-    schema: Optional[pa.Schema] = ...,
+    schema: Optional[Union[pa.Schema, Schema]] = ...,
     partition_by: Optional[Union[List[str], str]] = ...,
     filesystem: Optional[pa_fs.FileSystem] = None,
     mode: Literal["error", "append", "overwrite", "ignore"] = ...,
@@ -115,7 +116,7 @@ def write_deltalake(
         RecordBatchReader,
     ],
     *,
-    schema: Optional[pa.Schema] = ...,
+    schema: Optional[Union[pa.Schema, Schema]] = ...,
     partition_by: Optional[Union[List[str], str]] = ...,
     mode: Literal["error", "append", "overwrite", "ignore"] = ...,
     max_rows_per_group: int = ...,
@@ -142,7 +143,7 @@ def write_deltalake(
         RecordBatchReader,
     ],
     *,
-    schema: Optional[pa.Schema] = None,
+    schema: Optional[Union[pa.Schema, Schema]] = None,
     partition_by: Optional[Union[List[str], str]] = None,
     filesystem: Optional[pa_fs.FileSystem] = None,
     mode: Literal["error", "append", "overwrite", "ignore"] = "error",
@@ -244,6 +245,9 @@ def write_deltalake(
     if isinstance(partition_by, str):
         partition_by = [partition_by]
 
+    if isinstance(schema, Schema):
+        schema = schema.to_pyarrow()
+
     if isinstance(data, RecordBatchReader):
         data = convert_pyarrow_recordbatchreader(data, large_dtypes)
     elif isinstance(data, pa.RecordBatch):
@@ -336,16 +340,19 @@ def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType:
         return dtype
 
     if partition_by:
+        table_schema: pa.Schema = schema
         if PYARROW_MAJOR_VERSION < 12:
             partition_schema = pa.schema(
                 [
-                    pa.field(name, _large_to_normal_dtype(schema.field(name).type))
+                    pa.field(
+                        name, _large_to_normal_dtype(table_schema.field(name).type)
+                    )
                     for name in partition_by
                 ]
             )
         else:
             partition_schema = pa.schema(
-                [schema.field(name) for name in partition_by]
+                [table_schema.field(name) for name in partition_by]
             )
         partitioning = ds.partitioning(partition_schema, flavor="hive")
     else:
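
In short, the new branch normalises a deltalake `Schema` back to a `pyarrow.Schema` before the existing write path runs. A standalone sketch of that conversion (the field names here are made up; `Schema.from_pyarrow` and `to_pyarrow` are the existing deltalake conversion methods):

```python
import pyarrow as pa

from deltalake import Schema

arrow_schema = pa.schema([("id", pa.int64()), ("value", pa.float64())])

# write_deltalake now does the equivalent of this before writing:
delta_schema = Schema.from_pyarrow(arrow_schema)  # what a caller may pass in
restored = delta_schema.to_pyarrow()              # what the writer converts it to

# For simple primitive types like these the round trip is lossless,
# so the rest of the write path behaves as if a pyarrow.Schema was passed.
assert restored == arrow_schema
```
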
10 changes: 9 additions & 1 deletion python/tests/test_writer.py
@@ -16,7 +16,7 @@
 from pyarrow.dataset import ParquetFileFormat, ParquetReadOptions
 from pyarrow.lib import RecordBatchReader
 
-from deltalake import DeltaTable, write_deltalake
+from deltalake import DeltaTable, Schema, write_deltalake
 from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError
 from deltalake.table import ProtocolVersions
 from deltalake.writer import try_get_table_and_table_uri
@@ -1176,3 +1176,11 @@ def test_float_values(tmp_path: pathlib.Path):
     assert actions["min"].field("x2")[0].as_py() is None
     assert actions["max"].field("x2")[0].as_py() == 1.0
     assert actions["null_count"].field("x2")[0].as_py() == 1
+
+
+def test_with_deltalake_schema(tmp_path: pathlib.Path, sample_data: pa.Table):
+    write_deltalake(
+        tmp_path, sample_data, schema=Schema.from_pyarrow(sample_data.schema)
+    )
+    delta_table = DeltaTable(tmp_path)
+    assert delta_table.schema().to_pyarrow() == sample_data.schema
