GH-39217: [Python] RecordBatchReader.from_stream constructor for objects implementing the Arrow PyCapsule protocol (#39218)

### Rationale for this change

In contrast to Array, RecordBatch, and Schema, the C Stream interface (mapping to RecordBatchReader) has no equivalent factory function that can accept any Arrow-compatible object and turn it into a pyarrow object through the PyCapsule Protocol.
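
For context, the corresponding pattern that already works for arrays, adapted from this PR's tests (a minimal sketch; `ArrayWrapper` is purely illustrative, and this assumes a pyarrow version with PyCapsule protocol support):

```python
import pyarrow as pa


class ArrayWrapper:
    """Illustrative third-party object exposing data via the PyCapsule protocol."""

    def __init__(self, data):
        self.data = data

    def __arrow_c_array__(self, requested_schema=None):
        # Delegate the export to the wrapped pyarrow array
        return self.data.__arrow_c_array__(requested_schema)


# pa.array() accepts any object with an __arrow_c_array__ method
arr = pa.array(ArrayWrapper(pa.array([1, 2, 3], type=pa.int64())))
assert arr.equals(pa.array([1, 2, 3], type=pa.int64()))
```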

For that reason, this proposes an explicit constructor class method for it: `RecordBatchReader.from_stream` (this is quite a generic name, so other name suggestions are certainly welcome).
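
A minimal usage sketch of the proposed constructor, mirroring the tests added below (any object with a `__arrow_c_stream__` method is accepted, including pyarrow's own `Table` and `RecordBatchReader`):

```python
import pyarrow as pa

batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=["a"])
table = pa.Table.from_batches([batch])

# Table implements __arrow_c_stream__, so it can be consumed directly
reader = pa.RecordBatchReader.from_stream(table)
assert reader.read_all() == table

# Optionally pass a schema to request a cast; the producer raises
# NotImplementedError if it cannot honor the request
reader = pa.RecordBatchReader.from_stream(table, schema=batch.schema)
```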

### Are these changes tested?
TODO

* Closes: #39217

Authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
jorisvandenbossche authored Jan 8, 2024
1 parent ffcfabd commit dc40e5f
Showing 4 changed files with 95 additions and 8 deletions.
43 changes: 43 additions & 0 deletions python/pyarrow/ipc.pxi
@@ -883,6 +883,49 @@ cdef class RecordBatchReader(_Weakrefable):
        self.reader = c_reader
        return self

    @staticmethod
    def from_stream(data, schema=None):
        """
        Create RecordBatchReader from an Arrow-compatible stream object.

        This accepts objects implementing the Arrow PyCapsule Protocol for
        streams, i.e. objects that have a ``__arrow_c_stream__`` method.

        Parameters
        ----------
        data : Arrow-compatible stream object
            Any object that implements the Arrow PyCapsule Protocol for
            streams.
        schema : Schema, default None
            The schema to which the stream should be cast, if supported
            by the stream object.

        Returns
        -------
        RecordBatchReader
        """
        if not hasattr(data, "__arrow_c_stream__"):
            raise TypeError(
                "Expected an object implementing the Arrow PyCapsule Protocol for "
                "streams (i.e. having a `__arrow_c_stream__` method), "
                f"got {type(data)!r}."
            )

        if schema is not None:
            if not hasattr(schema, "__arrow_c_schema__"):
                raise TypeError(
                    "Expected an object implementing the Arrow PyCapsule Protocol for "
                    "schema (i.e. having a `__arrow_c_schema__` method), "
                    f"got {type(schema)!r}."
                )
            requested = schema.__arrow_c_schema__()
        else:
            requested = None

        capsule = data.__arrow_c_stream__(requested)
        return RecordBatchReader._import_from_c_capsule(capsule)

    @staticmethod
    def from_batches(Schema schema not None, batches):
        """
4 changes: 2 additions & 2 deletions python/pyarrow/tests/test_array.py
@@ -3351,8 +3351,8 @@ class ArrayWrapper:
         def __init__(self, data):
             self.data = data

-        def __arrow_c_array__(self, requested_type=None):
-            return self.data.__arrow_c_array__(requested_type)
+        def __arrow_c_array__(self, requested_schema=None):
+            return self.data.__arrow_c_array__(requested_schema)

     # Can roundtrip through the C array protocol
     arr = ArrayWrapper(pa.array([1, 2, 3], type=pa.int64()))
44 changes: 44 additions & 0 deletions python/pyarrow/tests/test_ipc.py
@@ -1194,3 +1194,47 @@ def make_batches():
    with pytest.raises(TypeError):
        reader = pa.RecordBatchReader.from_batches(None, batches)
        pass


def test_record_batch_reader_from_arrow_stream():

    class StreamWrapper:
        def __init__(self, batches):
            self.batches = batches

        def __arrow_c_stream__(self, requested_schema=None):
            reader = pa.RecordBatchReader.from_batches(
                self.batches[0].schema, self.batches)
            return reader.__arrow_c_stream__(requested_schema)

    data = [
        pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
        pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a'])
    ]
    wrapper = StreamWrapper(data)

    # Can roundtrip a pyarrow stream-like object
    expected = pa.Table.from_batches(data)
    reader = pa.RecordBatchReader.from_stream(expected)
    assert reader.read_all() == expected

    # Can roundtrip through the wrapper
    reader = pa.RecordBatchReader.from_stream(wrapper)
    assert reader.read_all() == expected

    # Passing a schema works if the stream already has that schema
    reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema)
    assert reader.read_all() == expected

    # If the schema doesn't match, NotImplementedError is raised
    with pytest.raises(NotImplementedError):
        pa.RecordBatchReader.from_stream(
            wrapper, schema=pa.schema([pa.field('a', pa.int32())])
        )

    # Proper type errors for wrong input
    with pytest.raises(TypeError):
        pa.RecordBatchReader.from_stream(data[0]['a'])

    with pytest.raises(TypeError):
        pa.RecordBatchReader.from_stream(expected, schema=data[0])
12 changes: 6 additions & 6 deletions python/pyarrow/tests/test_table.py
@@ -558,8 +558,8 @@ class BatchWrapper:
         def __init__(self, batch):
             self.batch = batch

-        def __arrow_c_array__(self, requested_type=None):
-            return self.batch.__arrow_c_array__(requested_type)
+        def __arrow_c_array__(self, requested_schema=None):
+            return self.batch.__arrow_c_array__(requested_schema)

     data = pa.record_batch([
         pa.array([1, 2, 3], type=pa.int64())
@@ -586,8 +586,8 @@ class BatchWrapper:
         def __init__(self, batch):
             self.batch = batch

-        def __arrow_c_array__(self, requested_type=None):
-            return self.batch.__arrow_c_array__(requested_type)
+        def __arrow_c_array__(self, requested_schema=None):
+            return self.batch.__arrow_c_array__(requested_schema)

     data = pa.record_batch([
         pa.array([1, 2, 3], type=pa.int64())
@@ -615,10 +615,10 @@ class StreamWrapper:
         def __init__(self, batches):
             self.batches = batches

-        def __arrow_c_stream__(self, requested_type=None):
+        def __arrow_c_stream__(self, requested_schema=None):
             reader = pa.RecordBatchReader.from_batches(
                 self.batches[0].schema, self.batches)
-            return reader.__arrow_c_stream__(requested_type)
+            return reader.__arrow_c_stream__(requested_schema)

     data = [
         pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
