Skip to content

Commit

Permalink
view for numpy arrays (#987)
Browse files Browse the repository at this point in the history
* reshape

* more tests

* Update numpy.h

* Update test_numpy_array.py

* array view

* test

* Update test_numpy_array.cpp

* Update numpy.h

* Update numpy.h

* Update test_numpy_array.cpp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix merge bug

* Make clang-tidy happy

* Add xfail for PyPy

* Fix casting issue

* Fix formatting

* Apply clang-tidy

* Address reviews on additional tests

* Fix ordering

* Do a little more reordering

* Fix typo

* Try improving tests

* Fix error in reshape

* Add one more reshape test

* Fix bugs and add test

* Relax test

* streamlining new tests; removing a few stray msg

* Fix style revert

* Fix clang-tidy

* Misc tweaks:
* Comment: matching style in file (///), responsibility sentence, consistent punctuation.
* Replacing `unsigned char` with `uint8_t` for max consistency.
* Removing `1` from `array_view1` because there is only one.

* Partial clang-format-diff.

Co-authored-by: ncullen93 <[email protected]>
Co-authored-by: NC Cullen <[email protected]>
Co-authored-by: Aaron Gokaslan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ralf Grosse-Kunstleve <[email protected]>
  • Loading branch information
6 people authored Aug 26, 2021
1 parent db44afa commit 503ff2a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
18 changes: 18 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ struct npy_api {
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
PyObject* (*PyArray_View_)(PyObject*, PyObject*, PyObject*);

private:
enum functions {
Expand All @@ -216,6 +217,7 @@ struct npy_api {
API_PyArray_DescrNewFromType = 96,
API_PyArray_Newshape = 135,
API_PyArray_Squeeze = 136,
API_PyArray_View = 137,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
Expand Down Expand Up @@ -248,6 +250,7 @@ struct npy_api {
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_Newshape);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_View);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
Expand Down Expand Up @@ -802,6 +805,21 @@ class array : public buffer {
return new_array;
}

/// Create a view of an array in a different data type.
/// This function may fundamentally reinterpret the data in the array.
/// It is the responsibility of the caller to ensure that this is safe.
/// Only supports the `dtype` argument, the `type` argument is omitted,
/// to be added as needed.
array view(const std::string &dtype) {
auto &api = detail::npy_api::get();
auto new_view = reinterpret_steal<array>(api.PyArray_View_(
m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
if (!new_view) {
throw error_already_set();
}
return new_view;
}

/// Ensure that the argument is a NumPy array
/// In case of an error, nullptr is returned and the Python error is cleared.
static array ensure(handle h, int ExtraFlags = 0) {
Expand Down
3 changes: 3 additions & 0 deletions tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ TEST_SUBMODULE(numpy_array, sm) {
return a;
});

sm.def("array_view",
[](py::array_t<uint8_t> a, const std::string &dtype) { return a.view(dtype); });

sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
return a.reshape({N, M, O});
});
Expand Down
15 changes: 15 additions & 0 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,21 @@ def test_array_create_and_resize():
assert np.all(a == 42.0)


def test_array_view():
a = np.ones(100 * 4).astype("uint8")
a_float_view = m.array_view(a, "float32")
assert a_float_view.shape == (100 * 1,) # 1 / 4 bytes = 8 / 32

a_int16_view = m.array_view(a, "int16") # 1 / 2 bytes = 16 / 32
assert a_int16_view.shape == (100 * 2,)


def test_array_view_invalid():
a = np.ones(100 * 4).astype("uint8")
with pytest.raises(TypeError):
m.array_view(a, "deadly_dtype")


def test_reshape_initializer_list():
a = np.arange(2 * 7 * 3) + 1
x = m.reshape_initializer_list(a, 2, 7, 3)
Expand Down

0 comments on commit 503ff2a

Please sign in to comment.