Skip to content

Commit

Permalink
feat(python): handle PyCapsule interface objects in write_deltalake (#…
Browse files Browse the repository at this point in the history
…2534)

# Description

Adds support for the [Arrow PyCapsule
interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).

Since pyarrow is already a required dependency, this takes the minimal
route of converting pycapsule interface objects into pyarrow objects.
This requires pyarrow 15 or higher for the stream conversion
(apache/arrow#39217).

This doesn't modify the existing hard-coded support for pyarrow and
pandas inputs; those paths are handled before the PyCapsule checks.

# Related Issue(s)

- closes #2376

# Documentation

---------

Co-authored-by: Ion Koutsouris <[email protected]>
  • Loading branch information
kylebarron and ion-elgreco authored Jul 18, 2024
1 parent 64bca17 commit 640ee6e
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
List,
Mapping,
Optional,
Protocol,
Tuple,
Union,
overload,
Expand All @@ -34,7 +35,7 @@
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.fs as pa_fs
from pyarrow.lib import RecordBatchReader
from pyarrow import RecordBatchReader

from ._internal import DeltaDataChecker as _DeltaDataChecker
from ._internal import batch_distinct
Expand Down Expand Up @@ -70,6 +71,17 @@
DEFAULT_DATA_SKIPPING_NUM_INDEX_COLS = 32


class ArrowStreamExportable(Protocol):
    """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.

    Any object that defines ``__arrow_c_stream__`` satisfies this protocol
    structurally (no inheritance required), which lets ``write_deltalake``
    accept Arrow-compatible data from libraries other than pyarrow. Such
    objects are consumed via ``pyarrow.RecordBatchReader.from_stream``,
    which requires pyarrow 15 or later.

    https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
    """

    # Returns a PyCapsule wrapping an Arrow C stream. ``requested_schema``,
    # if provided, is a schema capsule the producer may use to cast the
    # exported data; producers are free to ignore it.
    def __arrow_c_stream__(
        self, requested_schema: Optional[object] = None
    ) -> object: ...


@dataclass
class AddAction:
path: str
Expand All @@ -90,6 +102,7 @@ def write_deltalake(
pa.RecordBatch,
Iterable[pa.RecordBatch],
RecordBatchReader,
ArrowStreamExportable,
],
*,
schema: Optional[Union[pa.Schema, DeltaSchema]] = ...,
Expand Down Expand Up @@ -123,6 +136,7 @@ def write_deltalake(
pa.RecordBatch,
Iterable[pa.RecordBatch],
RecordBatchReader,
ArrowStreamExportable,
],
*,
schema: Optional[Union[pa.Schema, DeltaSchema]] = ...,
Expand Down Expand Up @@ -150,6 +164,7 @@ def write_deltalake(
pa.RecordBatch,
Iterable[pa.RecordBatch],
RecordBatchReader,
ArrowStreamExportable,
],
*,
schema: Optional[Union[pa.Schema, DeltaSchema]] = ...,
Expand Down Expand Up @@ -177,6 +192,7 @@ def write_deltalake(
pa.RecordBatch,
Iterable[pa.RecordBatch],
RecordBatchReader,
ArrowStreamExportable,
],
*,
schema: Optional[Union[pa.Schema, DeltaSchema]] = None,
Expand Down Expand Up @@ -285,12 +301,26 @@ def write_deltalake(
data = convert_pyarrow_table(
pa.Table.from_pandas(data), large_dtypes=large_dtypes
)
elif hasattr(data, "__arrow_c_array__"):
data = convert_pyarrow_recordbatch(
pa.record_batch(data), # type:ignore[attr-defined]
large_dtypes,
)
elif hasattr(data, "__arrow_c_stream__"):
if not hasattr(RecordBatchReader, "from_stream"):
raise ValueError(
"pyarrow 15 or later required to read stream via pycapsule interface"
)

data = convert_pyarrow_recordbatchreader(
RecordBatchReader.from_stream(data), large_dtypes
)
elif isinstance(data, Iterable):
if schema is None:
raise ValueError("You must provide schema if data is Iterable")
else:
raise TypeError(
f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame are valid inputs for source."
f"{type(data).__name__} is not a valid input. Only PyArrow RecordBatchReader, RecordBatch, Iterable[RecordBatch], Table, Dataset or Pandas DataFrame or objects implementing the Arrow PyCapsule Interface are valid inputs for source."
)

if schema is None:
Expand Down Expand Up @@ -443,7 +473,7 @@ def visitor(written_file: Any) -> None:
raise DeltaProtocolError(
"This table's min_writer_version is "
f"{table_protocol.min_writer_version}, "
f"""but this method only supports version 2 or 7 with at max these features {SUPPORTED_WRITER_FEATURES} enabled.
f"""but this method only supports version 2 or 7 with at max these features {SUPPORTED_WRITER_FEATURES} enabled.
Try engine='rust' instead which supports more features and writer versions."""
)
if (
Expand Down

0 comments on commit 640ee6e

Please sign in to comment.