Skip to content

Commit

Permalink
[Python] Table fails gracefully on non-cpu devices
Browse files Browse the repository at this point in the history
  • Loading branch information
danepitkin committed Sep 5, 2024
1 parent 50219ef commit 8d7ece6
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ cdef class Table(_Tabular):
cdef:
shared_ptr[CTable] sp_table
CTable* table
c_bool _is_cpu
c_bool _init_is_cpu

cdef void init(self, const shared_ptr[CTable]& table)

Expand Down
15 changes: 15 additions & 0 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -4180,6 +4180,7 @@ cdef class Table(_Tabular):

def __cinit__(self):
self.table = NULL
self._init_is_cpu = False

cdef void init(self, const shared_ptr[CTable]& table):
self.sp_table = table
Expand Down Expand Up @@ -5799,6 +5800,20 @@ cdef class Table(_Tabular):
"""
return self.to_reader().__arrow_c_stream__(requested_schema)

@property
def is_cpu(self):
"""
Whether all ChunkedArrays are CPU-accessible.
"""
if not self._init_is_cpu:
self._is_cpu = all(c.is_cpu for c in self.itercolumns())
self._init_is_cpu = True
return self._is_cpu

cdef void _assert_cpu(self) except *:
if not self.is_cpu:
raise NotImplementedError("Implemented only for data on CPU device")


def _reconstruct_table(arrays, schema):
"""
Expand Down
21 changes: 21 additions & 0 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3430,6 +3430,21 @@ def cuda_recordbatch(cuda_context, cpu_recordbatch):
return cpu_recordbatch.copy_to(cuda_context.memory_manager)


@pytest.fixture
def cpu_table(schema, cpu_chunked_array):
return pa.table([cpu_chunked_array, cpu_chunked_array], schema=schema)


@pytest.fixture
def cuda_table(schema, cuda_chunked_array):
return pa.table([cuda_chunked_array, cuda_chunked_array], schema=schema)


@pytest.fixture
def cpu_and_cuda_table(schema, cpu_chunked_array, cuda_chunked_array):
return pa.table([cpu_chunked_array, cuda_chunked_array], schema=schema)


def test_chunked_array_non_cpu(cuda_context, cpu_chunked_array, cuda_chunked_array,
cpu_and_cuda_chunked_array):
# type test
Expand Down Expand Up @@ -3737,3 +3752,9 @@ def test_recordbatch_non_cpu(cuda_context, cpu_recordbatch, cuda_recordbatch,
# __dataframe__() test
with pytest.raises(NotImplementedError):
from_dataframe(cuda_recordbatch.__dataframe__())


def test_table_non_cpu(cpu_table, cuda_table, cpu_and_cuda_table):
assert cpu_table.is_cpu
assert not cuda_table.is_cpu
assert not cpu_and_cuda_table.is_cpu

0 comments on commit 8d7ece6

Please sign in to comment.