From 4a95f1dda18e6c85f4329686508855cec49559b6 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 23 Jan 2022 19:43:03 +0100 Subject: [PATCH] Manually implement PyTypeInfo to ensure that downcasting considers dtype and ndim. --- CHANGELOG.md | 1 + examples/simple-extension/tests/test_ext.py | 2 +- src/array.rs | 34 +++++++++++++-------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa33f0f23..dd31173c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Unreleased - Support object arrays ([#216](https://github.com/PyO3/rust-numpy/pull/216)) - Support borrowing arrays that are part of other Python objects via `PyArray::borrow_from_array` ([#230](https://github.com/PyO3/rust-numpy/pull/216)) + - Fixed downcasting ignoring element type and dimensionality ([#265](https://github.com/PyO3/rust-numpy/pull/265)) - `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220)) - `PyArray::from_exact_iter` does not unsoundly trust `ExactSizeIterator::len` any more ([#262](https://github.com/PyO3/rust-numpy/pull/262)) - `PyArray::as_cell_slice` was removed as it unsoundly interacts with `PyReadonlyArray` allowing safe code to violate aliasing rules ([#260](https://github.com/PyO3/rust-numpy/pull/260)) diff --git a/examples/simple-extension/tests/test_ext.py b/examples/simple-extension/tests/test_ext.py index 4554d5230..c163f71a0 100644 --- a/examples/simple-extension/tests/test_ext.py +++ b/examples/simple-extension/tests/test_ext.py @@ -25,6 +25,6 @@ def test_conj(): def test_extract(): - x = np.arange(5) + x = np.arange(5.0) d = { "x": x } np.testing.assert_almost_equal(extract(d), 10.0) diff --git a/src/array.rs b/src/array.rs index c94ff85f9..37462ac2f 100644 --- a/src/array.rs +++ b/src/array.rs @@ -13,9 +13,9 @@ use ndarray::{ }; use num_traits::AsPrimitive; use pyo3::{ - ffi, pyobject_native_type_info, pyobject_native_type_named, type_object, types::PyModule, - AsPyPointer, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject, - PyResult, Python, ToPyObject, + ffi, pyobject_native_type_named, type_object, types::PyModule, AsPyPointer, FromPyObject, + IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, }; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; @@ -110,16 +110,24 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> { } unsafe impl type_object::PyLayout> for npyffi::PyArrayObject {} + impl type_object::PySizedLayout> for npyffi::PyArrayObject {} -pyobject_native_type_info!( - PyArray, - *npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type), - Some("numpy"), - #checkfunction=npyffi::PyArray_Check - ; T - ; D -); +unsafe impl PyTypeInfo for PyArray { + type AsRefTarget = Self; + + const NAME: &'static str = "PyArray"; + const MODULE: ::std::option::Option<&'static str> = Some("numpy"); + + #[inline] + fn type_object_raw(_py: Python) -> *mut ffi::PyTypeObject { + unsafe { npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type) } + } + + fn is_type_of(ob: &PyAny) -> bool { + <&Self>::extract(ob).is_ok() + } +} pyobject_native_type_named!(PyArray ; T ; D); @@ -129,12 +137,12 @@ impl IntoPy for PyArray { } } -impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray { +impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray { // here we do type-check three times // 1. Checks if the object is PyArray // 2. Checks if the data type of the array is T // 3. Checks if the dimension is same as D - fn extract(ob: &'a PyAny) -> PyResult { + fn extract(ob: &'py PyAny) -> PyResult { let array = unsafe { if npyffi::PyArray_Check(ob.as_ptr()) == 0 { return Err(PyDowncastError::new(ob, "PyArray").into());