feat(python/adbc_driver_manager): export handles through python Arrow Capsule interface (#1346)

Addresses #70

This PR adds the Arrow PyCapsule dunder methods to the Handle classes of the
low-level interface (this alone already enables using the low-level interface
with the capsule protocol and without pyarrow).

Secondly, the places that accept data (e.g. ingest/bind) now also accept
objects that implement these dunder methods, in addition to the hardcoded
support for pyarrow.
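
For illustration, a minimal sketch of the new ingestion path from the DBAPI layer. The SQLite driver, the table name, and the `ArrayWrapper` helper are assumptions for the sketch (the wrapper mirrors the one added to the tests), and pyarrow >= 14.0 is assumed for the `__arrow_c_array__` dunder on `RecordBatch`:

```python
import pyarrow
import adbc_driver_sqlite.dbapi  # any ADBC driver works; SQLite is assumed here


class ArrayWrapper:
    """Stand-in for any non-pyarrow object exposing __arrow_c_array__."""

    def __init__(self, array):
        self.array = array

    def __arrow_c_array__(self, requested_schema=None):
        return self.array.__arrow_c_array__(requested_schema=requested_schema)


batch = pyarrow.record_batch([[1, 2], ["foo", "bar"]], names=["ints", "strs"])

with adbc_driver_sqlite.dbapi.connect() as conn:
    cur = conn.cursor()
    # adbc_ingest() now dispatches on __arrow_c_array__ instead of requiring
    # a pyarrow.RecordBatch instance.
    cur.adbc_ingest("example_table", ArrayWrapper(batch), mode="create")
    conn.commit()
    cur.close()
```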

---------

Co-authored-by: David Li <[email protected]>
jorisvandenbossche and lidavidm authored Dec 13, 2023
1 parent fd705e0 commit 9544887
Showing 5 changed files with 237 additions and 33 deletions.
9 changes: 7 additions & 2 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -22,10 +22,15 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t

cdef extern from "adbc.h" nogil:
# C ABI

ctypedef void (*CArrowSchemaRelease)(void*)
ctypedef void (*CArrowArrayRelease)(void*)

cdef struct CArrowSchema"ArrowSchema":
pass
CArrowSchemaRelease release

cdef struct CArrowArray"ArrowArray":
pass
CArrowArrayRelease release

ctypedef int (*CArrowArrayStreamGetLastError)(void*)
ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)
110 changes: 97 additions & 13 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -24,10 +24,15 @@ import threading
import typing
from typing import List, Tuple

cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.pycapsule cimport (
PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
)
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.string cimport memset
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
from libcpp.vector cimport vector as c_vector

if typing.TYPE_CHECKING:
@@ -304,9 +309,29 @@ cdef class _AdbcHandle:
f"with open {self._child_type}")


cdef void pycapsule_schema_deleter(object capsule) noexcept:
cdef CArrowSchema* allocated = <CArrowSchema*>PyCapsule_GetPointer(
capsule, "arrow_schema"
)
if allocated.release != NULL:
allocated.release(allocated)
free(allocated)


cdef void pycapsule_stream_deleter(object capsule) noexcept:
cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
capsule, "arrow_array_stream"
)
if allocated.release != NULL:
allocated.release(allocated)
free(allocated)


cdef class ArrowSchemaHandle:
"""
A wrapper for an allocated ArrowSchema.
This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowSchema schema
@@ -316,23 +341,42 @@ cdef class ArrowSchemaHandle:
"""The address of the ArrowSchema."""
return <uintptr_t> &self.schema

def __arrow_c_schema__(self) -> object:
"""Consume this object to get a PyCapsule."""
# Reference:
# https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
cdef CArrowSchema* allocated = <CArrowSchema*> malloc(sizeof(CArrowSchema))
allocated.release = NULL
capsule = PyCapsule_New(
<void*>allocated, "arrow_schema", &pycapsule_schema_deleter,
)
memcpy(allocated, &self.schema, sizeof(CArrowSchema))
self.schema.release = NULL
return capsule


cdef class ArrowArrayHandle:
"""
A wrapper for an allocated ArrowArray.

This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArray array

@property
def address(self) -> int:
"""The address of the ArrowArray."""
"""
The address of the ArrowArray.
"""
return <uintptr_t> &self.array


cdef class ArrowArrayStreamHandle:
"""
A wrapper for an allocated ArrowArrayStream.

This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArrayStream stream
@@ -342,6 +386,21 @@ cdef class ArrowArrayStreamHandle:
"""The address of the ArrowArrayStream."""
return <uintptr_t> &self.stream

def __arrow_c_stream__(self, requested_schema=None) -> object:
"""Consume this object to get a PyCapsule."""
if requested_schema is not None:
raise NotImplementedError("requested_schema")

cdef CArrowArrayStream* allocated = \
<CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
allocated.release = NULL
capsule = PyCapsule_New(
<void*>allocated, "arrow_array_stream", &pycapsule_stream_deleter,
)
memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
self.stream.release = NULL
return capsule


class GetObjectsDepth(enum.IntEnum):
ALL = ADBC_OBJECT_DEPTH_ALL
@@ -1000,32 +1059,47 @@ cdef class AdbcStatement(_AdbcHandle):

connection._open_child()

def bind(self, data, schema) -> None:
def bind(self, data, schema=None) -> None:
"""
Bind an ArrowArray to this statement.

Parameters
----------
data : int or ArrowArrayHandle
schema : int or ArrowSchemaHandle
data : PyCapsule or int or ArrowArrayHandle
schema : PyCapsule or int or ArrowSchemaHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArray* c_array
cdef CArrowSchema* c_schema

if isinstance(data, ArrowArrayHandle):
if hasattr(data, "__arrow_c_array__") and not isinstance(data, ArrowArrayHandle):
if schema is not None:
raise ValueError(
"Can not provide a schema when passing Arrow-compatible "
"data that implements the Arrow PyCapsule Protocol"
)
schema, data = data.__arrow_c_array__()

if PyCapsule_CheckExact(data):
c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
elif isinstance(data, ArrowArrayHandle):
c_array = &(<ArrowArrayHandle> data).array
elif isinstance(data, int):
c_array = <CArrowArray*> data
else:
raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}")

if isinstance(schema, ArrowSchemaHandle):
raise TypeError(
"data must be Arrow-compatible data (implementing the Arrow PyCapsule "
f"Protocol), a PyCapsule, int or ArrowArrayHandle, not {type(data)}"
)

if PyCapsule_CheckExact(schema):
c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
elif isinstance(schema, ArrowSchemaHandle):
c_schema = &(<ArrowSchemaHandle> schema).schema
elif isinstance(schema, int):
c_schema = <CArrowSchema*> schema
else:
raise TypeError(f"schema must be int or ArrowSchemaHandle, "
raise TypeError("schema must be a PyCapsule, int or ArrowSchemaHandle, "
f"not {type(schema)}")

with nogil:
@@ -1042,17 +1116,27 @@

Parameters
----------
stream : int or ArrowArrayStreamHandle
stream : PyCapsule or int or ArrowArrayStreamHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArrayStream* c_stream

if isinstance(stream, ArrowArrayStreamHandle):
if (
hasattr(stream, "__arrow_c_stream__")
and not isinstance(stream, ArrowArrayStreamHandle)
):
stream = stream.__arrow_c_stream__()

if PyCapsule_CheckExact(stream):
c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
stream, "arrow_array_stream"
)
elif isinstance(stream, ArrowArrayStreamHandle):
c_stream = &(<ArrowArrayStreamHandle> stream).stream
elif isinstance(stream, int):
c_stream = <CArrowArrayStream*> stream
else:
raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
raise TypeError(f"data must be a PyCapsule, int or ArrowArrayStreamHandle, "
f"not {type(stream)}")

with nogil:
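
Conversely, the new `__arrow_c_stream__`/`__arrow_c_schema__` dunders let the low-level handles be handed directly to any capsule-aware consumer. A rough sketch, assuming the SQLite driver is installed and pyarrow >= 14.0 (whose `RecordBatchReader.from_stream` accepts objects implementing `__arrow_c_stream__`):

```python
import pyarrow
import adbc_driver_sqlite  # assumed installed; its connect() returns a low-level AdbcDatabase

from adbc_driver_manager import AdbcConnection, AdbcStatement

db = adbc_driver_sqlite.connect()
conn = AdbcConnection(db)
stmt = AdbcStatement(conn)
stmt.set_sql_query("SELECT 1 AS one")
stream_handle, _rowcount = stmt.execute_query()

# The ArrowArrayStreamHandle now implements __arrow_c_stream__, so it can be
# consumed directly instead of going through the raw .address integer.
reader = pyarrow.RecordBatchReader.from_stream(stream_handle)
print(reader.read_all())

stmt.close()
conn.close()
db.close()
```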
51 changes: 33 additions & 18 deletions python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,21 @@ def close(self):
self._closed = True

def _bind(self, parameters) -> None:
if isinstance(parameters, pyarrow.RecordBatch):
if hasattr(parameters, "__arrow_c_array__"):
self._stmt.bind(parameters)
elif hasattr(parameters, "__arrow_c_stream__"):
self._stmt.bind_stream(parameters)
elif isinstance(parameters, pyarrow.RecordBatch):
arr_handle = _lib.ArrowArrayHandle()
sch_handle = _lib.ArrowSchemaHandle()
parameters._export_to_c(arr_handle.address, sch_handle.address)
self._stmt.bind(arr_handle, sch_handle)
return
if isinstance(parameters, pyarrow.Table):
parameters = parameters.to_reader()
stream_handle = _lib.ArrowArrayStreamHandle()
parameters._export_to_c(stream_handle.address)
self._stmt.bind_stream(stream_handle)
else:
if isinstance(parameters, pyarrow.Table):
parameters = parameters.to_reader()
stream_handle = _lib.ArrowArrayStreamHandle()
parameters._export_to_c(stream_handle.address)
self._stmt.bind_stream(stream_handle)

def _prepare_execute(self, operation, parameters=None) -> None:
self._results = None
@@ -639,9 +643,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
# Not all drivers support it
pass

if isinstance(
parameters, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
):
if _is_arrow_data(parameters):
self._bind(parameters)
elif parameters:
rb = pyarrow.record_batch(
@@ -668,7 +670,6 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None:
self._prepare_execute(operation, parameters)
handle, self._rowcount = self._stmt.execute_query()
self._results = _RowIterator(
# pyarrow.RecordBatchReader._import_from_c(handle.address)
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
)

@@ -683,7 +684,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
operation : bytes or str
The query to execute. Pass SQL queries as strings,
(serialized) Substrait plans as bytes.
parameters
seq_of_parameters
Parameters to bind. Can be a list of Python sequences, or
an Arrow record batch, table, or record batch reader. If
None, then the query will be executed once, else it will
@@ -695,10 +696,7 @@
self._stmt.set_sql_query(operation)
self._stmt.prepare()

if isinstance(
seq_of_parameters,
(pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
):
if _is_arrow_data(seq_of_parameters):
arrow_parameters = seq_of_parameters
elif seq_of_parameters:
arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -806,7 +804,10 @@ def adbc_ingest(
table_name
The table to insert into.
data
The Arrow data to insert.
The Arrow data to insert. This can be a pyarrow RecordBatch, Table
or RecordBatchReader, or any Arrow-compatible data that implements
the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
or ``__arrow_c_stream__`` method).
mode
How to deal with existing data:
@@ -878,7 +879,11 @@ def adbc_ingest(
except NotSupportedError:
pass

if isinstance(data, pyarrow.RecordBatch):
if hasattr(data, "__arrow_c_array__"):
self._stmt.bind(data)
elif hasattr(data, "__arrow_c_stream__"):
self._stmt.bind_stream(data)
elif isinstance(data, pyarrow.RecordBatch):
array = _lib.ArrowArrayHandle()
schema = _lib.ArrowSchemaHandle()
data._export_to_c(array.address, schema.address)
@@ -1151,3 +1156,13 @@ def _warn_unclosed(name):
category=ResourceWarning,
stacklevel=2,
)


def _is_arrow_data(data):
return (
hasattr(data, "__arrow_c_array__")
or hasattr(data, "__arrow_c_stream__")
or isinstance(
data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
)
)
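
A similar sketch for the parameter-binding path that `_is_arrow_data` and `_bind` now take. The `StreamWrapper` helper mirrors the one in the tests, and the SQLite driver is again an assumption:

```python
import pyarrow
import adbc_driver_sqlite.dbapi  # assumed installed


class StreamWrapper:
    """Stand-in for any non-pyarrow object exposing __arrow_c_stream__."""

    def __init__(self, table):
        self.table = table

    def __arrow_c_stream__(self, requested_schema=None):
        return self.table.__arrow_c_stream__(requested_schema=requested_schema)


params = StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"]))

with adbc_driver_sqlite.dbapi.connect() as conn:
    cur = conn.cursor()
    # _bind() detects __arrow_c_stream__ and routes to bind_stream() instead
    # of requiring a pyarrow Table or RecordBatchReader.
    cur.execute("SELECT ? + 1, ?", parameters=params)
    print(cur.fetchall())
    cur.close()
```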
28 changes: 28 additions & 0 deletions python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
assert sqlite.adbc_get_table_types() == ["table", "view"]


class ArrayWrapper:
def __init__(self, array):
self.array = array

def __arrow_c_array__(self, requested_schema=None):
return self.array.__arrow_c_array__(requested_schema=requested_schema)


class StreamWrapper:
def __init__(self, stream):
self.stream = stream

def __arrow_c_stream__(self, requested_schema=None):
return self.stream.__arrow_c_stream__(requested_schema=requested_schema)


@pytest.mark.parametrize(
"data",
[
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
lambda: pyarrow.table(
[[1, 2], ["foo", ""]], names=["ints", "strs"]
).to_reader(),
lambda: ArrayWrapper(
pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
),
lambda: StreamWrapper(
pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
),
],
)
@pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
(1.0, 2),
pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
pyarrow.table([[1.0], [2]], names=["float", "int"]),
ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
],
)
def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
ArrayWrapper(
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
),
StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
((x, y) for x, y in ((1, "a"), (3, None))),
],
)