Skip to content

Commit

Permalink
Manually implement PyTypeInfo to ensure that downcasting considers dt…
Browse files Browse the repository at this point in the history
…ype and ndim.
  • Loading branch information
adamreichold committed Jan 23, 2022
1 parent f72e038 commit f2ac1b1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- Unreleased
- Support object arrays ([#216](https://github.com/PyO3/rust-numpy/pull/216))
- Support borrowing arrays that are part of other Python objects via `PyArray::borrow_from_array` ([#230](https://github.com/PyO3/rust-numpy/pull/216))
- Fixed downcasting ignoring element type and dimensionality ([#265](https://github.com/PyO3/rust-numpy/pull/265))
- `PyArray::new` is now `unsafe`, as it produces uninitialized arrays ([#220](https://github.com/PyO3/rust-numpy/pull/220))
- `PyArray::from_exact_iter` does not unsoundly trust `ExactSizeIterator::len` any more ([#262](https://github.com/PyO3/rust-numpy/pull/262))
- `PyArray::as_cell_slice` was removed as it unsoundly interacts with `PyReadonlyArray` allowing safe code to violate aliasing rules ([#260](https://github.com/PyO3/rust-numpy/pull/260))
Expand Down
2 changes: 1 addition & 1 deletion examples/simple-extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ def test_conj():


def test_extract():
x = np.arange(5)
x = np.arange(5.0)
d = { "x": x }
np.testing.assert_almost_equal(extract(d), 10.0)
32 changes: 20 additions & 12 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ use ndarray::{
};
use num_traits::AsPrimitive;
use pyo3::{
ffi, pyobject_native_type_info, pyobject_native_type_named, type_object, types::PyModule,
ffi, pyobject_native_type_named, type_object, types::PyModule,
AsPyPointer, FromPyObject, IntoPy, Py, PyAny, PyDowncastError, PyErr, PyNativeType, PyObject,
PyResult, Python, ToPyObject,
PyResult, Python, ToPyObject, PyTypeInfo
};

use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
Expand Down Expand Up @@ -110,16 +110,24 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> {
}

unsafe impl<T, D> type_object::PyLayout<PyArray<T, D>> for npyffi::PyArrayObject {}

impl<T, D> type_object::PySizedLayout<PyArray<T, D>> for npyffi::PyArrayObject {}

pyobject_native_type_info!(
PyArray<T, D>,
*npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
Some("numpy"),
#checkfunction=npyffi::PyArray_Check
; T
; D
);
unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
type AsRefTarget = Self;

const NAME: &'static str = "PyArray<T, D>";
const MODULE: ::std::option::Option<&'static str> = Some("numpy");

#[inline]
fn type_object_raw(_py: Python) -> *mut ffi::PyTypeObject {
unsafe { npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type) }
}

fn is_type_of(ob: &PyAny) -> bool {
<&Self>::extract(ob).is_ok()
}
}

pyobject_native_type_named!(PyArray<T, D> ; T ; D);

Expand All @@ -129,12 +137,12 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
}
}

impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
// here we do type-check three times
// 1. Checks if the object is PyArray
// 2. Checks if the data type of the array is T
// 3. Checks if the dimension is same as D
fn extract(ob: &'a PyAny) -> PyResult<Self> {
fn extract(ob: &'py PyAny) -> PyResult<Self> {
let array = unsafe {
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
return Err(PyDowncastError::new(ob, "PyArray<T, D>").into());
Expand Down

0 comments on commit f2ac1b1

Please sign in to comment.