diff --git a/dense/src/array/iterators.rs b/dense/src/array/iterators.rs index 57abaa86..77985a9c 100644 --- a/dense/src/array/iterators.rs +++ b/dense/src/array/iterators.rs @@ -2,7 +2,6 @@ use crate::array::*; use crate::layout::convert_1d_nd_from_shape; -use rlst_common::traits::{RandomAccessByValue, RandomAccessMut}; use rlst_common::types::Scalar; pub struct ArrayDefaultIterator< @@ -14,6 +13,7 @@ pub struct ArrayDefaultIterator< arr: &'a Array, shape: [usize; NDIM], pos: usize, + nelements: usize, } pub struct ArrayDefaultIteratorMut< @@ -27,6 +27,7 @@ pub struct ArrayDefaultIteratorMut< arr: &'a mut Array, shape: [usize; NDIM], pos: usize, + nelements: usize, } impl< @@ -41,6 +42,7 @@ impl< arr, shape: arr.shape(), pos: 0, + nelements: arr.shape().iter().product(), } } } @@ -56,7 +58,12 @@ impl< { fn new(arr: &'a mut Array) -> Self { let shape = arr.shape(); - Self { arr, shape, pos: 0 } + Self { + arr, + shape, + pos: 0, + nelements: shape.iter().product(), + } } } @@ -69,10 +76,12 @@ impl< { type Item = Item; fn next(&mut self) -> Option { - let indices = convert_1d_nd_from_shape(self.pos, self.shape)?; - let elem = self.arr.get_value(indices); + if self.pos >= self.nelements { + return None; + } + let indices = convert_1d_nd_from_shape(self.pos, self.shape); self.pos += 1; - elem + unsafe { Some(self.arr.get_value_unchecked(indices)) } } } @@ -94,12 +103,15 @@ impl< { type Item = &'a mut Item; fn next(&mut self) -> Option { - let indices = convert_1d_nd_from_shape(self.pos, self.shape)?; - let elem = self.arr.get_mut(indices); + if self.pos >= self.nelements { + return None; + } + let indices = convert_1d_nd_from_shape(self.pos, self.shape); self.pos += 1; - match elem { - None => None, - Some(inner) => Some(unsafe { std::mem::transmute::<&mut Item, &'a mut Item>(inner) }), + unsafe { + Some(std::mem::transmute::<&mut Item, &'a mut Item>( + self.arr.get_unchecked_mut(indices), + )) } } } diff --git a/dense/src/array/operations.rs b/dense/src/array/operations.rs index 91872f28..5f6be18c 100644 --- a/dense/src/array/operations.rs +++ b/dense/src/array/operations.rs @@ -47,9 +47,10 @@ impl< for data_index in 0..chunk.valid_entries { unsafe { - *self.get_unchecked_mut( - convert_1d_nd_from_shape(data_start + data_index, self.shape()).unwrap(), - ) = chunk.data[data_index]; + *self.get_unchecked_mut(convert_1d_nd_from_shape( + data_start + data_index, + self.shape(), + )) = chunk.data[data_index]; } } chunk_index += 1; @@ -107,9 +108,10 @@ impl< for data_index in 0..chunk.valid_entries { unsafe { - *self.get_unchecked_mut( - convert_1d_nd_from_shape(data_index + data_start, self.shape()).unwrap(), - ) = my_chunk.data[data_index]; + *self.get_unchecked_mut(convert_1d_nd_from_shape( + data_index + data_start, + self.shape(), + )) = my_chunk.data[data_index]; } } diff --git a/dense/src/array/views.rs b/dense/src/array/views.rs index c124781b..2c44da2b 100644 --- a/dense/src/array/views.rs +++ b/dense/src/array/views.rs @@ -1,6 +1,6 @@ //! Views onto an array -use crate::layout::convert_1d_nd_from_shape; +use crate::layout::{check_indices_in_bounds, convert_1d_nd_from_shape}; use super::Array; use rlst_common::traits::*; @@ -107,6 +107,7 @@ impl< type Item = Item; #[inline] unsafe fn get_value_unchecked(&self, indices: [usize; NDIM]) -> Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); self.arr .get_value_unchecked(offset_indices(indices, self.offset)) } @@ -124,6 +125,7 @@ impl< type Item = Item; #[inline] unsafe fn get_unchecked(&self, indices: [usize; NDIM]) -> &Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); self.arr.get_unchecked(offset_indices(indices, self.offset)) } } @@ -151,10 +153,10 @@ impl< if let Some(mut chunk) = super::empty_chunk(chunk_index, nelements) { for count in 0..chunk.valid_entries { unsafe { - chunk.data[count] = self.get_value_unchecked( - convert_1d_nd_from_shape(chunk.start_index + count, self.shape()) - .unwrap(), - ) + chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape( + chunk.start_index + count, + self.shape(), + )) } } Some(chunk) @@ -193,6 +195,7 @@ impl< type Item = Item; #[inline] unsafe fn get_value_unchecked(&self, indices: [usize; NDIM]) -> Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); self.arr .get_value_unchecked(offset_indices(indices, self.offset)) } @@ -211,6 +214,7 @@ impl< type Item = Item; #[inline] unsafe fn get_unchecked(&self, indices: [usize; NDIM]) -> &Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); self.arr.get_unchecked(offset_indices(indices, self.offset)) } } @@ -242,10 +246,10 @@ impl< if let Some(mut chunk) = super::empty_chunk(chunk_index, nelements) { for count in 0..chunk.valid_entries { unsafe { - chunk.data[count] = self.get_value_unchecked( - convert_1d_nd_from_shape(chunk.start_index + count, self.shape()) - .unwrap(), - ) + chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape( + chunk.start_index + count, + self.shape(), + )) } } Some(chunk) @@ -268,6 +272,7 @@ impl< type Item = Item; #[inline] unsafe fn get_unchecked_mut(&mut self, indices: [usize; NDIM]) -> &mut Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); self.arr .get_unchecked_mut(offset_indices(indices, self.offset)) } diff --git a/dense/src/base_array.rs b/dense/src/base_array.rs index bf870456..c443bf62 100644 --- a/dense/src/base_array.rs +++ b/dense/src/base_array.rs @@ -1,9 +1,9 @@ use crate::array::empty_chunk; use crate::data_container::{DataContainer, DataContainerMut}; -use crate::layout::{convert_1d_nd_from_shape, convert_nd_raw, stride_from_shape}; -use rlst_common::traits::{ - ChunkedAccess, RandomAccessByValue, UnsafeRandomAccessByValue, UnsafeRandomAccessMut, +use crate::layout::{ + check_indices_in_bounds, convert_1d_nd_from_shape, convert_nd_raw, stride_from_shape, }; +use rlst_common::traits::{ChunkedAccess, UnsafeRandomAccessByValue, UnsafeRandomAccessMut}; use rlst_common::{ traits::{Shape, UnsafeRandomAccessByRef}, types::Scalar, @@ -51,6 +51,7 @@ impl, const NDIM: usize> #[inline] unsafe fn get_unchecked(&self, indices: [usize; NDIM]) -> &Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); let index = convert_nd_raw(indices, self.stride); self.data.get_unchecked(index) } @@ -63,6 +64,7 @@ impl, const NDIM: usize> #[inline] unsafe fn get_value_unchecked(&self, indices: [usize; NDIM]) -> Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); let index = convert_nd_raw(indices, self.stride); self.data.get_unchecked_value(index) } @@ -75,6 +77,7 @@ impl, const NDIM: usize> #[inline] unsafe fn get_unchecked_mut(&mut self, indices: [usize; NDIM]) -> &mut Self::Item { + debug_assert!(check_indices_in_bounds(indices, self.shape())); let index = convert_nd_raw(indices, self.stride); self.data.get_unchecked_mut(index) } @@ -94,9 +97,10 @@ impl, const N: usize, const ND if let Some(mut chunk) = empty_chunk(chunk_index, nelements) { for count in 0..chunk.valid_entries { unsafe { - chunk.data[count] = self.get_value_unchecked( - convert_1d_nd_from_shape(chunk.start_index + count, self.shape()).unwrap(), - ) + chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape( + chunk.start_index + count, + self.shape(), + )) } } Some(chunk) diff --git a/dense/src/data_container.rs b/dense/src/data_container.rs index 5accb273..787fcbea 100644 --- a/dense/src/data_container.rs +++ b/dense/src/data_container.rs @@ -165,10 +165,12 @@ impl DataContainer for VectorContainer { type Item = Item; unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item { + debug_assert!(index < self.number_of_elements()); *self.data.get_unchecked(index) } unsafe fn get_unchecked(&self, index: usize) -> &Self::Item { + debug_assert!(index < self.number_of_elements()); self.data.get_unchecked(index) } @@ -183,6 +185,7 @@ impl DataContainer for VectorContainer { impl DataContainerMut for VectorContainer { unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut Self::Item { + debug_assert!(index < self.number_of_elements()); self.data.get_unchecked_mut(index) } @@ -195,10 +198,12 @@ impl DataContainer for ArrayContainer { type Item = Item; unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item { + debug_assert!(index < self.number_of_elements()); *self.data.get_unchecked(index) } unsafe fn get_unchecked(&self, index: usize) -> &Self::Item { + debug_assert!(index < self.number_of_elements()); self.data.get_unchecked(index) } @@ -213,6 +218,7 @@ impl DataContainer for ArrayContainer { impl DataContainerMut for ArrayContainer { unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut Self::Item { + debug_assert!(index < self.number_of_elements()); self.data.get_unchecked_mut(index) } @@ -225,10 +231,12 @@ impl<'a, Item: Scalar> DataContainer for SliceContainer<'a, Item> { type Item = Item; unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item { + debug_assert!(index < self.number_of_elements()); *self.data.get_unchecked(index) } unsafe fn get_unchecked(&self, index: usize) -> &Self::Item { + debug_assert!(index < self.number_of_elements()); self.data.get_unchecked(index) } @@ -245,10 +253,12 @@ impl<'a, Item: Scalar> DataContainer for SliceContainerMut<'a, Item> { type Item = Item; unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item { + debug_assert!(index < self.number_of_elements()); *self.data.get_unchecked(index) } unsafe fn get_unchecked(&self, index: usize) -> &Self::Item { + debug_assert!(index < self.number_of_elements()); self.data.get_unchecked(index) } @@ -263,6 +273,7 @@ impl<'a, Item: Scalar> DataContainer for SliceContainerMut<'a, Item> { impl<'a, Item: Scalar> DataContainerMut for SliceContainerMut<'a, Item> { unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut Self::Item { + debug_assert!(index < self.number_of_elements()); self.data.get_unchecked_mut(index) } diff --git a/dense/src/layout.rs b/dense/src/layout.rs index 75347f4b..78a84de0 100644 --- a/dense/src/layout.rs +++ b/dense/src/layout.rs @@ -16,6 +16,16 @@ pub fn stride_from_shape(shape: [usize; NDIM]) -> [usize; NDI output } +/// Return true if `indices` in bounds with respect to `shape`. +pub fn check_indices_in_bounds(indices: [usize; N], shape: [usize; N]) -> bool { + for (ind, s) in indices.iter().zip(shape.iter()) { + if ind >= s { + return false; + } + } + true +} + /// Convert an n-d index into a 1d index. #[inline] pub fn convert_nd_raw(indices: [usize; NDIM], stride: [usize; NDIM]) -> usize { @@ -33,19 +43,14 @@ pub fn convert_nd_raw(indices: [usize; NDIM], stride: [usize; pub fn convert_1d_nd_from_shape( mut index: usize, shape: [usize; NDIM], -) -> Option<[usize; NDIM]> { +) -> [usize; NDIM] { let mut res = [0; NDIM]; - let nelements = shape.iter().product(); - if index >= nelements { - None - } else { - for ind in 0..NDIM { - res[ind] = index % shape[ind]; - index /= shape[ind]; - } - - Some(res) + debug_assert!(index < shape.iter().product()); + for ind in 0..NDIM { + res[ind] = index % shape[ind]; + index /= shape[ind]; } + res } #[cfg(test)] @@ -59,7 +64,7 @@ mod test { let stride = stride_from_shape(shape); let index_1d = convert_nd_raw(indices, stride); - let actual_nd = convert_1d_nd_from_shape(index_1d, shape).unwrap(); + let actual_nd = convert_1d_nd_from_shape(index_1d, shape); println!("{}, {:#?}", index_1d, actual_nd); diff --git a/rlst/tests/array_operations.rs b/rlst/tests/array_operations.rs index c28537e4..733c330c 100644 --- a/rlst/tests/array_operations.rs +++ b/rlst/tests/array_operations.rs @@ -61,7 +61,7 @@ fn test_multiple_operations() { res.sum_into_chunked::<_, 64>(arr3.view()); for index in 0..nelements { - let indices = convert_1d_nd_from_shape(index, res.shape()).unwrap(); + let indices = convert_1d_nd_from_shape(index, res.shape()); approx::assert_relative_eq!( res[indices],