From dc9b59a8f0448fa2ead4a719a42439a7b7df98c0 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Fri, 8 Mar 2024 09:25:24 +0100 Subject: [PATCH] API: Add `numpy2.h` instead and make `numpy.h` safe This means that users of `numpy.h` cannot be broken, but need to update to `numpy2.h` if they want to compile for NumPy 2. Using Macros simply and didn't bother to try to remove unnecessary code paths. --- CMakeLists.txt | 1 + include/pybind11/eigen/matrix.h | 2 +- include/pybind11/eigen/tensor.h | 2 +- include/pybind11/numpy.h | 67 ++++++++++++++++++------ include/pybind11/numpy2.h | 5 ++ tests/extra_python_package/test_files.py | 1 + 6 files changed, 61 insertions(+), 17 deletions(-) create mode 100644 include/pybind11/numpy2.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 1e75e99eb95..8bafdacd463 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,6 +161,7 @@ set(PYBIND11_HEADERS include/pybind11/iostream.h include/pybind11/functional.h include/pybind11/numpy.h + include/pybind11/numpy2.h include/pybind11/operators.h include/pybind11/pybind11.h include/pybind11/pytypes.h diff --git a/include/pybind11/eigen/matrix.h b/include/pybind11/eigen/matrix.h index 8d4342f81bb..0554ca4e4f5 100644 --- a/include/pybind11/eigen/matrix.h +++ b/include/pybind11/eigen/matrix.h @@ -9,7 +9,7 @@ #pragma once -#include "../numpy.h" +#include "../numpy2.h" #include "common.h" /* HINT: To suppress warnings originating from the Eigen headers, use -isystem. diff --git a/include/pybind11/eigen/tensor.h b/include/pybind11/eigen/tensor.h index 25d12baca13..21e73fa1a00 100644 --- a/include/pybind11/eigen/tensor.h +++ b/include/pybind11/eigen/tensor.h @@ -7,7 +7,7 @@ #pragma once -#include "../numpy.h" +#include "../numpy2.h" #include "common.h" #if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 6058d81d1c3..1b704f84a1f 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -53,17 +53,6 @@ struct handle_type_name { template struct npy_format_descriptor; -struct PyArrayDescr_Proxy { - PyObject_HEAD - PyObject *typeobj; - char kind; - char type; - char byteorder; - char _former_flags; - int type_num; - /* Additional fields are NumPy version specific. */ -}; - /* NumPy 1 proxy (always includes legacy fields) */ struct PyArrayDescr1_Proxy { PyObject_HEAD @@ -80,6 +69,22 @@ struct PyArrayDescr1_Proxy { PyObject *names; }; +#ifdef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT +struct PyArrayDescr_Proxy { + PyObject_HEAD + PyObject *typeobj; + char kind; + char type; + char byteorder; + char _former_flags; + int type_num; + /* Additional fields are NumPy version specific. */ +}; +#else +/* NumPy 1.x only, we can expose all fields */ +typedef PyArrayDescr1_Proxy PyArrayDescr_Proxy; +#endif + /* NumPy 2 proxy, including legacy fields */ struct PyArrayDescr2_Proxy { PyObject_HEAD @@ -164,6 +169,13 @@ PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name object numpy_version = numpy_lib.attr("NumpyVersion")(version_string); int major_version = numpy_version.attr("major").cast(); +#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT + if (major_version >= 2) { + throw std::runtime_error("module compiled without NumPy 2 support. Please modify the " + "`pybind11/numpy.h` include to `pybind11/numpy2.h` and recompile " + "(this remains NumPy 1.x compatible but has minor changes)."); + } +#endif /* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially became a private module. */ std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core"; @@ -276,6 +288,16 @@ struct npy_api { PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *); int (*PyArray_DescrConverter_)(PyObject *, PyObject **); bool (*PyArray_EquivTypes_)(PyObject *, PyObject *); +#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT + int (*PyArray_GetArrayParamsFromObject_)(PyObject *, + PyObject *, + unsigned char, + PyObject **, + int *, + Py_intptr_t *, + PyObject **, + PyObject *); +#endif PyObject *(*PyArray_Squeeze_)(PyObject *); // Unused. Not removed because that affects ABI of the class. int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); @@ -302,6 +324,9 @@ struct npy_api { API_PyArray_View = 137, API_PyArray_DescrConverter = 174, API_PyArray_EquivTypes = 182, +#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT + API_PyArray_GetArrayParamsFromObject = 278, +#endif API_PyArray_SetBaseObject = 282 }; @@ -644,12 +669,16 @@ class dtype : public object { } /// Size of the data type in bytes. +#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT + int itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; } +#else ssize_t itemsize() const { if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) { return detail::array_descriptor1_proxy(m_ptr)->elsize; } return detail::array_descriptor2_proxy(m_ptr)->elsize; } +#endif /// Returns true for structured data types. bool has_fields() const { @@ -686,21 +715,29 @@ class dtype : public object { /// Single character for byteorder char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; } - /// Alignment of the data type +/// Alignment of the data type +#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT + int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; } +#else ssize_t alignment() const { if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) { - return detail::array_descriptor1_proxy(m_ptr)->alignment; + return detail::array_descriptor2_proxy(m_ptr)->alignment; } - return detail::array_descriptor2_proxy(m_ptr)->alignment; + return detail::array_descriptor1_proxy(m_ptr)->alignment; } +#endif - /// Flags for the array descriptor +/// Flags for the array descriptor +#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT + char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; } +#else std::uint64_t flags() const { if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) { return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags; } return detail::array_descriptor2_proxy(m_ptr)->flags; } +#endif private: static object &_dtype_from_pep3118() { diff --git a/include/pybind11/numpy2.h b/include/pybind11/numpy2.h new file mode 100644 index 00000000000..35039f1b3dc --- /dev/null +++ b/include/pybind11/numpy2.h @@ -0,0 +1,5 @@ +#pragma once + +#define PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT +#include "numpy.h" +#undef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT diff --git a/tests/extra_python_package/test_files.py b/tests/extra_python_package/test_files.py index 344e70d5db8..262fadc8ded 100644 --- a/tests/extra_python_package/test_files.py +++ b/tests/extra_python_package/test_files.py @@ -38,6 +38,7 @@ "include/pybind11/gil_safe_call_once.h", "include/pybind11/iostream.h", "include/pybind11/numpy.h", + "include/pybind11/numpy2.h", "include/pybind11/operators.h", "include/pybind11/options.h", "include/pybind11/pybind11.h",