diff --git a/common/src/lib.rs b/common/src/lib.rs index 0ade0ded..e188d25c 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,6 +1,5 @@ //! Common RLST data structures #![cfg_attr(feature = "strict", deny(warnings))] -pub mod tools; pub mod traits; pub mod types; diff --git a/common/src/traits/accessors.rs b/common/src/traits/accessors.rs index 1ea7f436..8b137891 100644 --- a/common/src/traits/accessors.rs +++ b/common/src/traits/accessors.rs @@ -1,143 +1 @@ -//! Traits for access to matrix data. -//! -//! Each random access trait has a two-dimensional and a one-dimensional access method, -//! namely `get` and `get1d` (together with their mutable and unsafe variants). -//! The two-dimensional access takes a row and a column and returns the corresponding -//! matrix element. The one-dimensional access takes a single `index` parameter that -//! iterates through the matrix elements. -//! -//! If the [`crate::traits::properties::Shape`] and [`crate::traits::properties::NumberOfElements`] -//! traits are implemented on top of [`UnsafeRandomAccessByValue`], [`UnsafeRandomAccessByRef`] -//! and [`UnsafeRandomAccessMut`] then the -//! corresponding bounds-checked traits [`RandomAccessByValue`], [`RandomAccessByRef`] and -//! [`RandomAccessMut`] are auto-implemented. -//! -//! To get raw access to the underlying data use the [`RawAccess`] and [`RawAccessMut`] traits. -use crate::traits::properties::Shape; -use crate::types::{DataChunk, Scalar}; - -/// This trait provides unsafe access by value to the underlying data. -pub trait UnsafeRandomAccessByValue { - type Item: Scalar; - - /// Return the element at position determined by `multi_index`. - /// - /// # Safety - /// `multi_index` must not be out of bounds. - unsafe fn get_value_unchecked(&self, multi_index: [usize; NDIM]) -> Self::Item; -} - -/// This trait provides unsafe access by reference to the underlying data. -pub trait UnsafeRandomAccessByRef { - type Item: Scalar; - - /// Return a mutable reference to the element at position determined by `multi_index`. - /// - /// # Safety - /// `multi_index` must not be out of bounds. - unsafe fn get_unchecked(&self, multi_index: [usize; NDIM]) -> &Self::Item; -} - -/// This trait provides unsafe mutable access to the underlying data. -pub trait UnsafeRandomAccessMut { - type Item: Scalar; - - /// Return a mutable reference to the element at position determined by `multi_index`. - /// - /// # Safety - /// `multi_index` must not be out of bounds. - unsafe fn get_unchecked_mut(&mut self, multi_index: [usize; NDIM]) -> &mut Self::Item; -} - -/// This trait provides bounds checked access to the underlying data by value. -pub trait RandomAccessByValue: UnsafeRandomAccessByValue { - /// Return the element at position determined by `multi_index`. - fn get_value(&self, multi_index: [usize; NDIM]) -> Option; -} - -/// This trait provides bounds checked access to the underlying data by reference. -pub trait RandomAccessByRef: UnsafeRandomAccessByRef { - /// Return a reference to the element at position determined by `multi_index`. - fn get(&self, multi_index: [usize; NDIM]) -> Option<&Self::Item>; -} - -/// This trait provides bounds checked mutable access to the underlying data. -pub trait RandomAccessMut: UnsafeRandomAccessMut { - /// Return a mutable reference to the element at position determined by `multi_index`. - fn get_mut(&mut self, multi_index: [usize; NDIM]) -> Option<&mut Self::Item>; -} - -/// Return chunks of data of size N; -pub trait ChunkedAccess { - type Item: Scalar; - fn get_chunk(&self, chunk_index: usize) -> Option>; -} - -/// Get raw access to the underlying data. -pub trait RawAccess { - type Item: Scalar; - - /// Get a slice of the whole data. - fn data(&self) -> &[Self::Item]; -} - -/// Get mutable raw access to the underlying data. -pub trait RawAccessMut: RawAccess { - /// Get a mutable slice of the whole data. - fn data_mut(&mut self) -> &mut [Self::Item]; -} - -/// Check if `multi_index` not out of bounds with respect to `shape`. -#[inline] -fn check_dimension(multi_index: [usize; NDIM], shape: [usize; NDIM]) -> bool { - multi_index - .iter() - .zip(shape.iter()) - .fold(true, |acc, (ind, s)| acc && (ind < s)) -} - -impl< - Item: Scalar, - Mat: UnsafeRandomAccessByValue + Shape, - const NDIM: usize, - > RandomAccessByValue for Mat -{ - fn get_value(&self, multi_index: [usize; NDIM]) -> Option { - if check_dimension(multi_index, self.shape()) { - Some(unsafe { self.get_value_unchecked(multi_index) }) - } else { - None - } - } -} - -impl< - Item: Scalar, - Mat: UnsafeRandomAccessMut + Shape, - const NDIM: usize, - > RandomAccessMut for Mat -{ - fn get_mut(&mut self, multi_index: [usize; NDIM]) -> Option<&mut Self::Item> { - if check_dimension(multi_index, self.shape()) { - unsafe { Some(self.get_unchecked_mut(multi_index)) } - } else { - None - } - } -} - -impl< - Item: Scalar, - Mat: UnsafeRandomAccessByRef + Shape, - const NDIM: usize, - > RandomAccessByRef for Mat -{ - fn get(&self, multi_index: [usize; NDIM]) -> Option<&Self::Item> { - if check_dimension(multi_index, self.shape()) { - unsafe { Some(self.get_unchecked(multi_index)) } - } else { - None - } - } -} diff --git a/common/src/traits/in_place_operations.rs b/common/src/traits/in_place_operations.rs index f5f82fd2..1ac67c09 100644 --- a/common/src/traits/in_place_operations.rs +++ b/common/src/traits/in_place_operations.rs @@ -35,8 +35,3 @@ pub trait SetDiag { /// the length of the diagonal. fn set_diag_from_slice(&mut self, diag: &[Self::Item]); } - -/// Resize an operator in place -pub trait ResizeInPlace { - fn resize_in_place(&mut self, shape: [usize; NDIM]); -} diff --git a/common/src/traits/iterators.rs b/common/src/traits/iterators.rs index 060eaaf2..f1db6ba3 100644 --- a/common/src/traits/iterators.rs +++ b/common/src/traits/iterators.rs @@ -24,16 +24,6 @@ pub trait ColumnMajorIterator { fn iter_col_major(&self) -> Self::Iter<'_>; } -/// Default iterator. -pub trait DefaultIterator { - type Item: Scalar; - type Iter<'a>: std::iter::Iterator - where - Self: 'a; - - fn iter(&self) -> Self::Iter<'_>; -} - /// Mutable iterator through the elements in column-major ordering. pub trait ColumnMajorIteratorMut { type Item: Scalar; @@ -44,16 +34,6 @@ pub trait ColumnMajorIteratorMut { fn iter_col_major_mut(&mut self) -> Self::IterMut<'_>; } -/// Mutable default iterator. -pub trait DefaultIteratorMut { - type Item: Scalar; - type IterMut<'a>: std::iter::Iterator - where - Self: 'a; - - fn iter_mut(&mut self) -> Self::IterMut<'_>; -} - /// Iterate through the diagonal. pub trait DiagonalIterator { type Item: Scalar; diff --git a/common/src/traits/operations.rs b/common/src/traits/operations.rs index 23222a1a..13912b0c 100644 --- a/common/src/traits/operations.rs +++ b/common/src/traits/operations.rs @@ -130,57 +130,3 @@ pub trait PermuteRows { fn permute_rows(&self, permutation: &[usize]) -> Self::Out; } - -/// Multiply First * Second and sum into Self -pub trait MultInto { - type Item: Scalar; - fn simple_mult_into(self, arr_a: First, arr_b: Second) -> Self - where - Self: Sized, - { - self.mult_into( - TransMode::NoTrans, - TransMode::NoTrans, - ::one(), - arr_a, - arr_b, - ::zero(), - ) - } - fn mult_into( - self, - transa: TransMode, - transb: TransMode, - alpha: Self::Item, - arr_a: First, - arr_b: Second, - beta: Self::Item, - ) -> Self; -} - -/// Multiply First * Second and sum into Self. Allow to resize Self if necessary -pub trait MultIntoResize { - type Item: Scalar; - fn simple_mult_into_resize(self, arr_a: First, arr_b: Second) -> Self - where - Self: Sized, - { - self.mult_into_resize( - TransMode::NoTrans, - TransMode::NoTrans, - ::one(), - arr_a, - arr_b, - ::zero(), - ) - } - fn mult_into_resize( - self, - transa: TransMode, - transb: TransMode, - alpha: Self::Item, - arr_a: First, - arr_b: Second, - beta: Self::Item, - ) -> Self; -} diff --git a/common/src/traits/properties.rs b/common/src/traits/properties.rs index f9ad869d..8b137891 100644 --- a/common/src/traits/properties.rs +++ b/common/src/traits/properties.rs @@ -1,27 +1 @@ -//! Traits describing properties of objects. -/// Return the shape of the object. -pub trait Shape { - fn shape(&self) -> [usize; NDIM]; - - /// Return true if a dimension is 0. - fn is_empty(&self) -> bool { - let shape = self.shape(); - for elem in shape { - if elem == 0 { - return true; - } - } - false - } -} - -/// Return the stride of the object. -pub trait Stride { - fn stride(&self) -> [usize; NDIM]; -} - -/// Return the number of elements. -pub trait NumberOfElements { - fn number_of_elements(&self) -> usize; -} diff --git a/dense/src/array.rs b/dense/src/array.rs index 7eae1961..141af0a7 100644 --- a/dense/src/array.rs +++ b/dense/src/array.rs @@ -2,18 +2,7 @@ use crate::base_array::BaseArray; use crate::data_container::VectorContainer; -use rlst_common::traits::ChunkedAccess; -use rlst_common::traits::NumberOfElements; -use rlst_common::traits::RandomAccessByRef; -use rlst_common::traits::RandomAccessMut; -use rlst_common::traits::RawAccess; -use rlst_common::traits::RawAccessMut; -use rlst_common::traits::ResizeInPlace; -use rlst_common::traits::Shape; -use rlst_common::traits::Stride; -use rlst_common::traits::UnsafeRandomAccessByRef; -use rlst_common::traits::UnsafeRandomAccessByValue; -use rlst_common::traits::UnsafeRandomAccessMut; +use crate::traits::*; use rlst_common::types::DataChunk; use rlst_common::types::Scalar; diff --git a/dense/src/array/empty_axis.rs b/dense/src/array/empty_axis.rs index 33759f59..ecf0ebf1 100644 --- a/dense/src/array/empty_axis.rs +++ b/dense/src/array/empty_axis.rs @@ -233,7 +233,7 @@ where #[cfg(test)] mod test { - use rlst_common::traits::{Shape, Stride}; + use crate::traits::{Shape, Stride}; use crate::{array::empty_axis::AxisPosition, rlst_dynamic_array3}; diff --git a/dense/src/array/iterators.rs b/dense/src/array/iterators.rs index 9318bd7e..9884ef02 100644 --- a/dense/src/array/iterators.rs +++ b/dense/src/array/iterators.rs @@ -149,7 +149,7 @@ impl< Item: Scalar, ArrayImpl: UnsafeRandomAccessByValue + Shape, const NDIM: usize, - > rlst_common::traits::iterators::DefaultIterator for Array + > crate::traits::DefaultIterator for Array { type Item = Item; type Iter<'a> = ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM> where Self: 'a; @@ -165,7 +165,7 @@ impl< + Shape + UnsafeRandomAccessMut, const NDIM: usize, - > rlst_common::traits::iterators::DefaultIteratorMut for Array + > crate::traits::DefaultIteratorMut for Array { type Item = Item; type IterMut<'a> = ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM> where Self: 'a; @@ -178,7 +178,7 @@ impl< #[cfg(test)] mod test { - use rlst_common::traits::*; + use crate::traits::*; #[test] fn test_iter() { diff --git a/dense/src/array/mult_into.rs b/dense/src/array/mult_into.rs index b61ab0d1..a603719a 100644 --- a/dense/src/array/mult_into.rs +++ b/dense/src/array/mult_into.rs @@ -1,9 +1,9 @@ //! Multiplication of Arrays +use crate::traits::MultInto; +use crate::traits::MultIntoResize; use rlst_blis::interface::gemm::Gemm; pub use rlst_blis::interface::types::TransMode; -use rlst_common::traits::MultInto; -use rlst_common::traits::MultIntoResize; use super::{empty_axis::AxisPosition, *}; diff --git a/dense/src/array/operators/to_complex.rs b/dense/src/array/operators/to_complex.rs index 05474f70..4e594094 100644 --- a/dense/src/array/operators/to_complex.rs +++ b/dense/src/array/operators/to_complex.rs @@ -82,41 +82,33 @@ impl< } impl + Shape, const NDIM: usize> - rlst_common::traits::ToComplex for Array + Array { - type Out = Array, NDIM>; - - fn to_complex(self) -> Self::Out { + pub fn to_complex(self) -> Array, NDIM> { Array::new(ArrayToComplex::new(self)) } } impl + Shape, const NDIM: usize> - rlst_common::traits::ToComplex for Array + Array { - type Out = Array, NDIM>; - - fn to_complex(self) -> Self::Out { + pub fn to_complex(self) -> Array, NDIM> { Array::new(ArrayToComplex::new(self)) } } impl + Shape, const NDIM: usize> - rlst_common::traits::ToComplex for Array + Array { - type Out = Self; - - fn to_complex(self) -> Self::Out { + pub fn to_complex(self) -> Self { self } } impl + Shape, const NDIM: usize> - rlst_common::traits::ToComplex for Array + Array { - type Out = Self; - - fn to_complex(self) -> Self::Out { + pub fn to_complex(self) -> Self { self } } diff --git a/dense/src/array/random.rs b/dense/src/array/random.rs index 18d9b06c..00349733 100644 --- a/dense/src/array/random.rs +++ b/dense/src/array/random.rs @@ -1,12 +1,13 @@ //! Methods for the creation of random matrices. use crate::data_container::DataContainerMut; +use crate::tools::*; +use crate::traits::*; use rand::prelude::*; use rand_chacha::ChaCha8Rng; use rand_distr::Standard; use rand_distr::StandardNormal; -use rlst_common::tools::*; -use rlst_common::traits::*; +use rlst_common::types::Scalar; use super::Array; use crate::base_array::BaseArray; diff --git a/dense/src/array/slice.rs b/dense/src/array/slice.rs index 21795b60..3efea210 100644 --- a/dense/src/array/slice.rs +++ b/dense/src/array/slice.rs @@ -307,8 +307,8 @@ fn compute_raw_range( #[cfg(test)] mod test { + use crate::traits::*; use crate::{layout::convert_nd_raw, rlst_dynamic_array3}; - use rlst_common::traits::*; #[test] fn test_create_slice() { diff --git a/dense/src/array/views.rs b/dense/src/array/views.rs index 48cf9f6a..a02d52be 100644 --- a/dense/src/array/views.rs +++ b/dense/src/array/views.rs @@ -3,7 +3,8 @@ use crate::layout::{check_multi_index_in_bounds, convert_1d_nd_from_shape}; use super::Array; -use rlst_common::traits::*; +use crate::traits::*; +use rlst_common::types::*; pub struct ArrayView< 'a, diff --git a/dense/src/base_array.rs b/dense/src/base_array.rs index 9b121b9c..81a18bac 100644 --- a/dense/src/base_array.rs +++ b/dense/src/base_array.rs @@ -3,14 +3,8 @@ use crate::data_container::{DataContainer, DataContainerMut, ResizeableDataConta use crate::layout::{ check_multi_index_in_bounds, convert_1d_nd_from_shape, convert_nd_raw, stride_from_shape, }; -use rlst_common::traits::{ - ChunkedAccess, RawAccess, RawAccessMut, ResizeInPlace, Stride, UnsafeRandomAccessByValue, - UnsafeRandomAccessMut, -}; -use rlst_common::{ - traits::{Shape, UnsafeRandomAccessByRef}, - types::Scalar, -}; +use crate::traits::*; +use rlst_common::types::Scalar; pub struct BaseArray, const NDIM: usize> { data: Data, diff --git a/dense/src/lib.rs b/dense/src/lib.rs index 6251a118..ff4eaf20 100644 --- a/dense/src/lib.rs +++ b/dense/src/lib.rs @@ -40,6 +40,8 @@ pub mod base_array; pub mod data_container; pub mod linalg; pub mod number_types; +pub mod tools; +pub mod traits; // pub mod base_array; // pub mod base_matrix; diff --git a/dense/src/linalg/inverse.rs b/dense/src/linalg/inverse.rs index abff76e0..692c9b47 100644 --- a/dense/src/linalg/inverse.rs +++ b/dense/src/linalg/inverse.rs @@ -1,8 +1,8 @@ //! Implement the Inverse use crate::array::Array; +use crate::traits::*; use lapack::{cgetrf, cgetri, dgetrf, dgetri, sgetrf, sgetri, zgetrf, zgetri}; use num::traits::Zero; -use rlst_common::traits::*; use rlst_common::types::{c32, c64, RlstError, RlstResult, Scalar}; use super::assert_lapack_stride; @@ -72,8 +72,8 @@ mod test { use super::*; + use crate::assert_array_abs_diff_eq; use paste::paste; - use rlst_common::assert_array_abs_diff_eq; use crate::array::empty_array; use crate::rlst_dynamic_array2; diff --git a/dense/src/linalg/lu.rs b/dense/src/linalg/lu.rs index c3ab9a36..c29b573b 100644 --- a/dense/src/linalg/lu.rs +++ b/dense/src/linalg/lu.rs @@ -2,6 +2,7 @@ use super::assert_lapack_stride; use super::Trans; use crate::array::Array; +use crate::traits::*; use lapack::{dgetrf, dgetrs}; use num::One; use rlst_common::traits::*; @@ -285,7 +286,7 @@ impl_lu!(f64, dgetrf, dgetrs); #[cfg(test)] mod test { - use rlst_common::assert_array_relative_eq; + use crate::assert_array_relative_eq; use crate::rlst_dynamic_array2; diff --git a/dense/src/linalg/pseudo_inverse.rs b/dense/src/linalg/pseudo_inverse.rs index 674cd84a..5f2840d2 100644 --- a/dense/src/linalg/pseudo_inverse.rs +++ b/dense/src/linalg/pseudo_inverse.rs @@ -1,6 +1,7 @@ //! Implement the Pseudo-Inverse use crate::array::Array; use crate::rlst_dynamic_array2; +use crate::traits::*; use itertools::Itertools; use num::traits::{One, Zero}; use rlst_common::traits::*; @@ -105,8 +106,8 @@ mod test { use super::*; use crate::array::empty_array; + use crate::assert_array_abs_diff_eq; use paste::paste; - use rlst_common::assert_array_abs_diff_eq; use crate::rlst_dynamic_array2; diff --git a/dense/src/linalg/qr.rs b/dense/src/linalg/qr.rs index 15e36948..d3c6cb6b 100644 --- a/dense/src/linalg/qr.rs +++ b/dense/src/linalg/qr.rs @@ -2,6 +2,7 @@ use super::assert_lapack_stride; use crate::array::Array; +use crate::traits::*; use itertools::Itertools; use lapack::{cgeqp3, cunmqr, dgeqp3, dormqr, sgeqp3, sormqr, zgeqp3, zunmqr}; @@ -551,8 +552,8 @@ implement_qr_complex!(c32, cgeqp3, cunmqr); #[cfg(test)] mod test { + use crate::{assert_array_abs_diff_eq, assert_array_relative_eq, traits::*}; use rlst_common::types::*; - use rlst_common::{assert_array_abs_diff_eq, assert_array_relative_eq, traits::*}; use crate::array::empty_array; use crate::rlst_dynamic_array2; diff --git a/dense/src/linalg/svd.rs b/dense/src/linalg/svd.rs index 06abd54c..581bfb89 100644 --- a/dense/src/linalg/svd.rs +++ b/dense/src/linalg/svd.rs @@ -1,8 +1,8 @@ //! Implement the SVD use crate::array::Array; +use crate::traits::*; use lapack::{cgesvd, dgesvd, sgesvd, zgesvd}; use num::traits::Zero; -use rlst_common::traits::*; use rlst_common::types::{c32, c64, RlstError, RlstResult, Scalar}; use super::assert_lapack_stride; @@ -410,9 +410,9 @@ mod test { use super::*; + use crate::assert_array_relative_eq; use approx::assert_relative_eq; use paste::paste; - use rlst_common::assert_array_relative_eq; use crate::array::empty_array; use crate::{rlst_dynamic_array1, rlst_dynamic_array2}; diff --git a/dense/src/matrix_multiply.rs b/dense/src/matrix_multiply.rs index 6fb87b07..ad55e801 100644 --- a/dense/src/matrix_multiply.rs +++ b/dense/src/matrix_multiply.rs @@ -3,10 +3,11 @@ //! This module implements the matrix multiplication. The current implementation //! uses the [rlst-blis] crate. +use crate::traits::*; +use crate::traits::*; use rlst_blis::interface::gemm::Gemm; use rlst_blis::interface::types::TransMode; -use rlst_common::traits::*; -use rlst_common::{traits::RawAccess, types::Scalar}; +use rlst_common::types::Scalar; pub fn matrix_multiply< Item: Scalar + Gemm, @@ -74,7 +75,7 @@ pub fn matrix_multiply< #[cfg(test)] mod test { - use rlst_common::assert_array_relative_eq; + use crate::assert_array_relative_eq; use rlst_common::types::{c32, c64}; use super::*; diff --git a/common/src/tools.rs b/dense/src/tools.rs similarity index 99% rename from common/src/tools.rs rename to dense/src/tools.rs index 8b37d5c4..2264d16b 100644 --- a/common/src/tools.rs +++ b/dense/src/tools.rs @@ -1,11 +1,10 @@ //! Useful library tools. -use crate::{ - traits::{RandomAccessByValue, Shape}, - types::*, -}; +use crate::traits::*; +use crate::traits::*; use rand::prelude::*; use rand_distr::Distribution; +use rlst_common::types::*; /// This trait implements a simple convenient function to return random scalars /// from a given random number generator and distribution. For complex types the diff --git a/dense/src/traits.rs b/dense/src/traits.rs new file mode 100644 index 00000000..e9fe7246 --- /dev/null +++ b/dense/src/traits.rs @@ -0,0 +1,113 @@ +//! Dense matrix traits + +pub mod accessors; + +pub use accessors::*; + +use rlst_blis::interface::types::TransMode; +use rlst_common::types::*; + +/// Return the shape of the object. +pub trait Shape { + fn shape(&self) -> [usize; NDIM]; + + /// Return true if a dimension is 0. + fn is_empty(&self) -> bool { + let shape = self.shape(); + for elem in shape { + if elem == 0 { + return true; + } + } + false + } +} + +/// Return the stride of the object. +pub trait Stride { + fn stride(&self) -> [usize; NDIM]; +} + +/// Return the number of elements. +pub trait NumberOfElements { + fn number_of_elements(&self) -> usize; +} + +/// Resize an operator in place +pub trait ResizeInPlace { + fn resize_in_place(&mut self, shape: [usize; NDIM]); +} + +/// Multiply First * Second and sum into Self +pub trait MultInto { + type Item: Scalar; + fn simple_mult_into(self, arr_a: First, arr_b: Second) -> Self + where + Self: Sized, + { + self.mult_into( + TransMode::NoTrans, + TransMode::NoTrans, + ::one(), + arr_a, + arr_b, + ::zero(), + ) + } + fn mult_into( + self, + transa: TransMode, + transb: TransMode, + alpha: Self::Item, + arr_a: First, + arr_b: Second, + beta: Self::Item, + ) -> Self; +} + +/// Multiply First * Second and sum into Self. Allow to resize Self if necessary +pub trait MultIntoResize { + type Item: Scalar; + fn simple_mult_into_resize(self, arr_a: First, arr_b: Second) -> Self + where + Self: Sized, + { + self.mult_into_resize( + TransMode::NoTrans, + TransMode::NoTrans, + ::one(), + arr_a, + arr_b, + ::zero(), + ) + } + fn mult_into_resize( + self, + transa: TransMode, + transb: TransMode, + alpha: Self::Item, + arr_a: First, + arr_b: Second, + beta: Self::Item, + ) -> Self; +} + +/// Default iterator. +pub trait DefaultIterator { + type Item: Scalar; + type Iter<'a>: std::iter::Iterator + where + Self: 'a; + + fn iter(&self) -> Self::Iter<'_>; +} + +/// Mutable default iterator. +pub trait DefaultIteratorMut { + type Item: Scalar; + type IterMut<'a>: std::iter::Iterator + where + Self: 'a; + + fn iter_mut(&mut self) -> Self::IterMut<'_>; +} diff --git a/dense/src/traits/accessors.rs b/dense/src/traits/accessors.rs new file mode 100644 index 00000000..1622dae7 --- /dev/null +++ b/dense/src/traits/accessors.rs @@ -0,0 +1,145 @@ +//! Fundamental traits for dense arrays. + +//! Traits for access to matrix data. +//! +//! Each random access trait has a two-dimensional and a one-dimensional access method, +//! namely `get` and `get1d` (together with their mutable and unsafe variants). +//! The two-dimensional access takes a row and a column and returns the corresponding +//! matrix element. The one-dimensional access takes a single `index` parameter that +//! iterates through the matrix elements. +//! +//! If the [`crate::traits::properties::Shape`] and [`crate::traits::properties::NumberOfElements`] +//! traits are implemented on top of [`UnsafeRandomAccessByValue`], [`UnsafeRandomAccessByRef`] +//! and [`UnsafeRandomAccessMut`] then the +//! corresponding bounds-checked traits [`RandomAccessByValue`], [`RandomAccessByRef`] and +//! [`RandomAccessMut`] are auto-implemented. +//! +//! To get raw access to the underlying data use the [`RawAccess`] and [`RawAccessMut`] traits. + +use crate::traits::Shape; +use rlst_common::types::{DataChunk, Scalar}; + +/// This trait provides unsafe access by value to the underlying data. +pub trait UnsafeRandomAccessByValue { + type Item: Scalar; + + /// Return the element at position determined by `multi_index`. + /// + /// # Safety + /// `multi_index` must not be out of bounds. + unsafe fn get_value_unchecked(&self, multi_index: [usize; NDIM]) -> Self::Item; +} + +/// This trait provides unsafe access by reference to the underlying data. +pub trait UnsafeRandomAccessByRef { + type Item: Scalar; + + /// Return a mutable reference to the element at position determined by `multi_index`. + /// + /// # Safety + /// `multi_index` must not be out of bounds. + unsafe fn get_unchecked(&self, multi_index: [usize; NDIM]) -> &Self::Item; +} + +/// This trait provides unsafe mutable access to the underlying data. +pub trait UnsafeRandomAccessMut { + type Item: Scalar; + + /// Return a mutable reference to the element at position determined by `multi_index`. + /// + /// # Safety + /// `multi_index` must not be out of bounds. + unsafe fn get_unchecked_mut(&mut self, multi_index: [usize; NDIM]) -> &mut Self::Item; +} + +/// This trait provides bounds checked access to the underlying data by value. +pub trait RandomAccessByValue: UnsafeRandomAccessByValue { + /// Return the element at position determined by `multi_index`. + fn get_value(&self, multi_index: [usize; NDIM]) -> Option; +} + +/// This trait provides bounds checked access to the underlying data by reference. +pub trait RandomAccessByRef: UnsafeRandomAccessByRef { + /// Return a reference to the element at position determined by `multi_index`. + fn get(&self, multi_index: [usize; NDIM]) -> Option<&Self::Item>; +} + +/// This trait provides bounds checked mutable access to the underlying data. +pub trait RandomAccessMut: UnsafeRandomAccessMut { + /// Return a mutable reference to the element at position determined by `multi_index`. + fn get_mut(&mut self, multi_index: [usize; NDIM]) -> Option<&mut Self::Item>; +} + +/// Return chunks of data of size N; +pub trait ChunkedAccess { + type Item: Scalar; + fn get_chunk(&self, chunk_index: usize) -> Option>; +} + +/// Get raw access to the underlying data. +pub trait RawAccess { + type Item: Scalar; + + /// Get a slice of the whole data. + fn data(&self) -> &[Self::Item]; +} + +/// Get mutable raw access to the underlying data. +pub trait RawAccessMut: RawAccess { + /// Get a mutable slice of the whole data. + fn data_mut(&mut self) -> &mut [Self::Item]; +} + +/// Check if `multi_index` not out of bounds with respect to `shape`. +#[inline] +fn check_dimension(multi_index: [usize; NDIM], shape: [usize; NDIM]) -> bool { + multi_index + .iter() + .zip(shape.iter()) + .fold(true, |acc, (ind, s)| acc && (ind < s)) +} + +impl< + Item: Scalar, + Mat: UnsafeRandomAccessByValue + Shape, + const NDIM: usize, + > RandomAccessByValue for Mat +{ + fn get_value(&self, multi_index: [usize; NDIM]) -> Option { + if check_dimension(multi_index, self.shape()) { + Some(unsafe { self.get_value_unchecked(multi_index) }) + } else { + None + } + } +} + +impl< + Item: Scalar, + Mat: UnsafeRandomAccessMut + Shape, + const NDIM: usize, + > RandomAccessMut for Mat +{ + fn get_mut(&mut self, multi_index: [usize; NDIM]) -> Option<&mut Self::Item> { + if check_dimension(multi_index, self.shape()) { + unsafe { Some(self.get_unchecked_mut(multi_index)) } + } else { + None + } + } +} + +impl< + Item: Scalar, + Mat: UnsafeRandomAccessByRef + Shape, + const NDIM: usize, + > RandomAccessByRef for Mat +{ + fn get(&self, multi_index: [usize; NDIM]) -> Option<&Self::Item> { + if check_dimension(multi_index, self.shape()) { + unsafe { Some(self.get_unchecked(multi_index)) } + } else { + None + } + } +} diff --git a/rlst/examples/array_operations.rs b/rlst/examples/array_operations.rs index ee9ed46d..b0f5b5d4 100644 --- a/rlst/examples/array_operations.rs +++ b/rlst/examples/array_operations.rs @@ -1,7 +1,7 @@ //! Tests of array algebray operations use rlst::rlst_dynamic_array3; -use rlst_common::traits::*; +use rlst_dense::traits::*; pub fn main() { let shape = [3, 4, 8]; diff --git a/rlst/tests/array_operations.rs b/rlst/tests/array_operations.rs index cbe9ea24..23288330 100644 --- a/rlst/tests/array_operations.rs +++ b/rlst/tests/array_operations.rs @@ -3,8 +3,8 @@ use approx::assert_relative_eq; use rlst::rlst_dynamic_array3; use rlst_common::types::*; -use rlst_common::{assert_array_relative_eq, traits::*}; use rlst_dense::{array::iterators::AsMultiIndex, layout::convert_1d_nd_from_shape}; +use rlst_dense::{assert_array_relative_eq, traits::*}; #[test] fn test_addition() {