Skip to content

Commit

Permalink
Expand dtype accessors (#3868)
Browse files Browse the repository at this point in the history
* Added constructor based on typenum, based on PyArray_DescrFromType

Added accessors for typenum, alignment, byteorder and flags fields of
PyArray_Descr struct.

* Added tests for new py::dtype constructor, and for accessors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed the comment for alignment method

* Update include/pybind11/numpy.h

Co-authored-by: Aaron Gokaslan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aaron Gokaslan <[email protected]>
  • Loading branch information
3 people authored Apr 14, 2022
1 parent fa98804 commit ba7a0fa
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
24 changes: 24 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,13 @@ class dtype : public object {
m_ptr = from_args(std::move(args)).release().ptr();
}

explicit dtype(int typenum)
: object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) {
if (m_ptr == nullptr) {
throw error_already_set();
}
}

/// This is essentially the same as calling numpy.dtype(args) in Python.
static dtype from_args(object args) {
PyObject *ptr = nullptr;
Expand Down Expand Up @@ -596,6 +603,23 @@ class dtype : public object {
return detail::array_descriptor_proxy(m_ptr)->type;
}

/// type number of dtype.
ssize_t num() const {
// Note: The signature, `dtype::num` follows the naming of NumPy's public
// Python API (i.e., ``dtype.num``), rather than its internal
// C API (``PyArray_Descr::type_num``).
return detail::array_descriptor_proxy(m_ptr)->type_num;
}

/// Single character for byteorder
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }

/// Alignment of the data type
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }

/// Flags for the array descriptor
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }

private:
static object _dtype_from_pep3118() {
static PyObject *obj = module_::import("numpy.core._internal")
Expand Down
29 changes: 29 additions & 0 deletions tests/test_numpy_dtypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ py::list test_dtype_ctors() {
list.append(py::dtype(names, formats, offsets, 20));
list.append(py::dtype(py::buffer_info((void *) 0, sizeof(unsigned int), "I", 1)));
list.append(py::dtype(py::buffer_info((void *) 0, 0, "T{i:a:f:b:}", 1)));
list.append(py::dtype(py::detail::npy_api::NPY_DOUBLE_));
return list;
}

Expand Down Expand Up @@ -440,6 +441,34 @@ TEST_SUBMODULE(numpy_dtypes, m) {
}
return list;
});
m.def("test_dtype_num", [dtype_names]() {
py::list list;
for (const auto &dt_name : dtype_names) {
list.append(py::dtype(dt_name).num());
}
return list;
});
m.def("test_dtype_byteorder", [dtype_names]() {
py::list list;
for (const auto &dt_name : dtype_names) {
list.append(py::dtype(dt_name).byteorder());
}
return list;
});
m.def("test_dtype_alignment", [dtype_names]() {
py::list list;
for (const auto &dt_name : dtype_names) {
list.append(py::dtype(dt_name).alignment());
}
return list;
});
m.def("test_dtype_flags", [dtype_names]() {
py::list list;
for (const auto &dt_name : dtype_names) {
list.append(py::dtype(dt_name).flags());
}
return list;
});
m.def("test_dtype_methods", []() {
py::list list;
auto dt1 = py::dtype::of<int32_t>();
Expand Down
8 changes: 7 additions & 1 deletion tests/test_numpy_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def test_dtype(simple_dtype):
d1,
np.dtype("uint32"),
d2,
np.dtype("d"),
]

assert m.test_dtype_methods() == [
Expand All @@ -175,8 +176,13 @@ def test_dtype(simple_dtype):
np.zeros(1, m.trailing_padding_dtype())
)

expected_chars = "bhilqBHILQefdgFDG?MmO"
assert m.test_dtype_kind() == list("iiiiiuuuuuffffcccbMmO")
assert m.test_dtype_char_() == list("bhilqBHILQefdgFDG?MmO")
assert m.test_dtype_char_() == list(expected_chars)
assert m.test_dtype_num() == [np.dtype(ch).num for ch in expected_chars]
assert m.test_dtype_byteorder() == [np.dtype(ch).byteorder for ch in expected_chars]
assert m.test_dtype_alignment() == [np.dtype(ch).alignment for ch in expected_chars]
assert m.test_dtype_flags() == [chr(np.dtype(ch).flags) for ch in expected_chars]


def test_recarray(simple_dtype, packed_dtype):
Expand Down

0 comments on commit ba7a0fa

Please sign in to comment.