diff --git a/benches/array.rs b/benches/array.rs index eed05afef..3102e4d87 100644 --- a/benches/array.rs +++ b/benches/array.rs @@ -30,6 +30,27 @@ fn extract_failure(bencher: &mut Bencher) { }); } +#[bench] +fn downcast_success(bencher: &mut Bencher) { + Python::with_gil(|py| { + let any: &PyAny = PyArray2::::zeros(py, (10, 10), false); + + bencher.iter(|| { + black_box(any).downcast::>().unwrap(); + }); + }); +} + +#[bench] +fn downcast_failure(bencher: &mut Bencher) { + Python::with_gil(|py| { + let any: &PyAny = PyArray2::::zeros(py, (10, 10), false); + + bencher.iter(|| { + black_box(any).downcast::>().unwrap_err(); + }); + }); +} struct Iter(Range); impl Iterator for Iter { diff --git a/src/array.rs b/src/array.rs index f6907ea1e..38342293d 100644 --- a/src/array.rs +++ b/src/array.rs @@ -26,7 +26,7 @@ use crate::cold; use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; use crate::dtype::{Element, PyArrayDescr}; use crate::error::{ - BorrowError, DimensionalityError, FromVecError, NotContiguousError, TypeError, + BorrowError, DimensionalityError, FromVecError, IgnoreError, NotContiguousError, TypeError, DIMENSIONALITY_MISMATCH_ERR, MAX_DIMENSIONALITY_ERR, }; use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API}; @@ -131,7 +131,7 @@ unsafe impl PyTypeInfo for PyArray { } fn is_type_of(ob: &PyAny) -> bool { - <&Self>::extract(ob).is_ok() + Self::extract::(ob).is_ok() } } @@ -145,30 +145,7 @@ impl IntoPy for PyArray { impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray { fn extract(ob: &'py PyAny) -> PyResult { - // Check if the object is an array. - let array = unsafe { - if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 { - return Err(PyDowncastError::new(ob, PyArray::::NAME).into()); - } - &*(ob as *const PyAny as *const PyArray) - }; - - // Check if the dimensionality matches `D`. - let src_ndim = array.ndim(); - if let Some(dst_ndim) = D::NDIM { - if src_ndim != dst_ndim { - return Err(DimensionalityError::new(src_ndim, dst_ndim).into()); - } - } - - // Check if the element type matches `T`. - let src_dtype = array.dtype(); - let dst_dtype = T::get_dtype(ob.py()); - if !src_dtype.is_equiv_to(dst_dtype) { - return Err(TypeError::new(src_dtype, dst_dtype).into()); - } - - Ok(array) + PyArray::extract(ob) } } @@ -390,6 +367,36 @@ impl PyArray { } impl PyArray { + fn extract<'py, E>(ob: &'py PyAny) -> Result<&'py Self, E> + where + E: From> + From + From>, + { + // Check if the object is an array. + let array = unsafe { + if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 { + return Err(PyDowncastError::new(ob, Self::NAME).into()); + } + &*(ob as *const PyAny as *const Self) + }; + + // Check if the dimensionality matches `D`. + let src_ndim = array.ndim(); + if let Some(dst_ndim) = D::NDIM { + if src_ndim != dst_ndim { + return Err(DimensionalityError::new(src_ndim, dst_ndim).into()); + } + } + + // Check if the element type matches `T`. + let src_dtype = array.dtype(); + let dst_dtype = T::get_dtype(ob.py()); + if !src_dtype.is_equiv_to(dst_dtype) { + return Err(TypeError::new(src_dtype, dst_dtype).into()); + } + + Ok(array) + } + /// Same as [`shape`][Self::shape], but returns `D` insead of `&[usize]`. #[inline(always)] pub fn dims(&self) -> D { diff --git a/src/dtype.rs b/src/dtype.rs index 9c008e213..7bbd257f6 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -117,9 +117,12 @@ impl PyArrayDescr { /// Returns true if two type descriptors are equivalent. pub fn is_equiv_to(&self, other: &Self) -> bool { + let self_ptr = self.as_dtype_ptr(); + let other_ptr = other.as_dtype_ptr(); + unsafe { - PY_ARRAY_API.PyArray_EquivTypes(self.py(), self.as_dtype_ptr(), other.as_dtype_ptr()) - != 0 + self_ptr == other_ptr + || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0 } } @@ -413,7 +416,7 @@ fn npy_int_type_lookup(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES { fn npy_int_type() -> NPY_TYPES { let is_unsigned = T::min_value() == T::zero(); - let bit_width = size_of::() << 3; + let bit_width = 8 * size_of::(); match (is_unsigned, bit_width) { (false, 8) => NPY_TYPES::NPY_BYTE, @@ -449,6 +452,7 @@ macro_rules! impl_element_scalar { $(#[$meta])* unsafe impl Element for $ty { const IS_COPY: bool = true; + fn get_dtype(py: Python) -> &PyArrayDescr { PyArrayDescr::from_npy_type(py, $npy_type) } diff --git a/src/error.rs b/src/error.rs index 13586fa84..ca8825b16 100644 --- a/src/error.rs +++ b/src/error.rs @@ -162,3 +162,18 @@ impl fmt::Display for BorrowError { } impl_pyerr!(BorrowError); + +/// An internal type used to ignore certain error conditions +/// +/// This is beneficial when those errors will never reach a public API anyway +/// but dropping them will improve performance. +pub(crate) struct IgnoreError; + +impl From for IgnoreError +where + PyErr: From, +{ + fn from(_err: E) -> Self { + Self + } +}