Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
danepitkin committed Aug 29, 2024
1 parent a8cc140 commit f9bc668
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
1 change: 0 additions & 1 deletion python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,6 @@ cdef class RecordBatch(_Tabular):
Schema _schema

cdef void init(self, const shared_ptr[CRecordBatch]& table)
cdef void _assert_cpu(self) except *


cdef class Device(_Weakrefable):
Expand Down
30 changes: 14 additions & 16 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3392,7 +3392,7 @@ def cuda_recordbatch(cuda_context, cpu_recordbatch):
return cpu_recordbatch.copy_to(cuda_context.memory_manager)


def verify_recordbatch_on_cuda_device(batch, expected_schema):
def verify_cuda_recordbatch(batch, expected_schema):
batch.validate()
assert batch.device_type == pa.DeviceAllocationType.CUDA
assert batch.is_cpu is False
Expand All @@ -3406,7 +3406,7 @@ def verify_recordbatch_on_cuda_device(batch, expected_schema):

def test_recordbatch_non_cpu(cuda_context, cpu_recordbatch, cuda_recordbatch,
cuda_arrays, schema):
verify_recordbatch_on_cuda_device(cuda_recordbatch, expected_schema=schema)
verify_cuda_recordbatch(cuda_recordbatch, expected_schema=schema)
assert cuda_recordbatch.shape == (5, 2)

# columns() test
Expand All @@ -3430,9 +3430,7 @@ def test_recordbatch_non_cpu(cuda_context, cpu_recordbatch, cuda_recordbatch,

# remove_column() test
new_batch = cuda_recordbatch.remove_column(1)
assert len(new_batch.columns) == 1
assert new_batch.device_type == pa.DeviceAllocationType.CUDA
assert new_batch[0].device_type == pa.DeviceAllocationType.CUDA
verify_cuda_recordbatch(new_batch, expected_schema=schema.remove(1))

# drop_columns() test
new_batch = cuda_recordbatch.drop_columns(['c0', 'c1'])
Expand All @@ -3441,19 +3439,17 @@ def test_recordbatch_non_cpu(cuda_context, cpu_recordbatch, cuda_recordbatch,

# select() test
new_batch = cuda_recordbatch.select(['c0'])
verify_recordbatch_on_cuda_device(new_batch, expected_schema=schema.remove(1))
verify_cuda_recordbatch(new_batch, expected_schema=schema.remove(1))

# cast() test
new_schema = pa.schema([pa.field('c0', pa.int64()), pa.field('c1', pa.int64())])
with pytest.raises(NotImplementedError):
cuda_recordbatch.cast(new_schema)

# drop_null() test
validity = cuda_context.buffer_from_data(
np.array([True, False, True, False, True], dtype=np.bool_))
null_col = pa.Array.from_buffers(
pa.int32(), 5,
[validity, cuda_context.buffer_from_data(np.array([0] * 5, dtype=np.int32))])
null_col = pa.array([-2, -1, 0, 1, 2],
mask=[True, False, True, False, True]).copy_to(
cuda_context.memory_manager)
cuda_recordbatch_with_nulls = cuda_recordbatch.add_column(2, 'c2', null_col)
with pytest.raises(NotImplementedError):
cuda_recordbatch_with_nulls.drop_null()
Expand Down Expand Up @@ -3481,11 +3477,13 @@ def test_recordbatch_non_cpu(cuda_context, cpu_recordbatch, cuda_recordbatch,

# from_arrays() test
new_batch = pa.RecordBatch.from_arrays(cuda_arrays, ['c0', 'c1'])
verify_recordbatch_on_cuda_device(new_batch, expected_schema=schema)
verify_cuda_recordbatch(new_batch, expected_schema=schema)
assert new_batch.copy_to(pa.default_cpu_memory_manager()).equals(cpu_recordbatch)

# from_pydict() test
new_batch = pa.RecordBatch.from_pydict({'c0': cuda_arrays[0], 'c1': cuda_arrays[1]})
verify_recordbatch_on_cuda_device(new_batch, expected_schema=schema)
verify_cuda_recordbatch(new_batch, expected_schema=schema)
assert new_batch.copy_to(pa.default_cpu_memory_manager()).equals(cpu_recordbatch)

# from_struct_array() test
fields = [schema.field(i) for i in range(len(schema.names))]
Expand Down Expand Up @@ -3527,21 +3525,21 @@ def test_recordbatch_non_cpu(cuda_context, cpu_recordbatch, cuda_recordbatch,

# slice() test
new_batch = cuda_recordbatch.slice(1, 3)
verify_recordbatch_on_cuda_device(new_batch, expected_schema=schema)
verify_cuda_recordbatch(new_batch, expected_schema=schema)
assert new_batch.num_rows == 3
cpu_batch = new_batch.copy_to(pa.default_cpu_memory_manager())
assert cpu_batch == cpu_recordbatch.slice(1, 3)

# replace_schema_metadata() test
new_batch = cuda_recordbatch.replace_schema_metadata({b'key': b'value'})
verify_recordbatch_on_cuda_device(new_batch, expected_schema=schema)
verify_cuda_recordbatch(new_batch, expected_schema=schema)
assert new_batch.schema.metadata == {b'key': b'value'}

# rename_columns() test
new_batch = cuda_recordbatch.rename_columns(['col0', 'col1'])
expected_schema = pa.schema(
[pa.field('col0', pa.int16()), pa.field('col1', pa.int32())])
verify_recordbatch_on_cuda_device(new_batch, expected_schema=expected_schema)
verify_cuda_recordbatch(new_batch, expected_schema=expected_schema)

# validate() test
cuda_recordbatch.validate()
Expand Down

0 comments on commit f9bc668

Please sign in to comment.