From 686865e31d21a04e4de47586b5a1d3918c1aa077 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 9 Apr 2024 12:55:03 +0200 Subject: [PATCH] properly handle the case of copy=False --- python/pyarrow/array.pxi | 14 +++++++++++++- python/pyarrow/tests/test_array.py | 28 +++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index ccb51b7337817..b48f08f80e514 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1544,7 +1544,19 @@ cdef class Array(_PandasConvertible): return _array_like_to_pandas(self, options, types_mapper=types_mapper) def __array__(self, dtype=None, copy=None): - # TODO honor the copy keyword + # TODO honor the copy=True case + if copy is False: + try: + values = self.to_numpy(zero_copy_only=True) + except ArrowInvalid as exc: + raise ArrowInvalid( + "Unable to avoid a copy while creating a numpy array as requested.\n" + "If using `np.array(obj, copy=False)` replace it with " + "`np.asarray(obj)` to allow a copy when needed" + ) + # values is already a numpy array at this point, but calling np.array(..) + # again to handle the `dtype` keyword with a no-copy guarantee + return np.array(values, dtype=dtype, copy=False) values = self.to_numpy(zero_copy_only=False) if dtype is None: return values diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 0191f15502894..a961166a2982f 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -31,6 +31,7 @@ import pyarrow as pa import pyarrow.tests.strategies as past +from pyarrow.vendored.version import Version def test_total_bytes_allocated(): @@ -3309,9 +3310,34 @@ def test_numpy_array_protocol(): np.testing.assert_array_equal(result, expected) # this should not raise a deprecation warning with numpy 2.0+ - result = np.asarray(arr, copy=False) + result = np.array(arr, copy=False) np.testing.assert_array_equal(result, expected) + result = np.array(arr, dtype="int64", copy=False) + np.testing.assert_array_equal(result, expected) + + # no zero-copy is possible + arr = pa.array([1, 2, None]) + expected = np.array([1, 2, np.nan], dtype="float64") + result = np.asarray(arr) + np.testing.assert_array_equal(result, expected) + + if Version(np.__version__) < Version("2.0"): + # copy keyword is not strict and not passed down to __array__ + result = np.array(arr, copy=False) + np.testing.assert_array_equal(result, expected) + + result = np.array(arr, dtype="float64", copy=False) + np.testing.assert_array_equal(result, expected) + else: + # starting with numpy 2.0, the copy=False keyword is assumed to be strict + with pytest.raises(ValueError, match="Unable to avoid a copy"): + np.array(arr, copy=False) + + arr = pa.array([1, 2, 3]) + with pytest.raises(ValueError): + np.array(arr, dtype="float64", copy=False) + def test_array_protocol():