Skip to content

Commit

Permalink
WIP: Clean up traits
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Dec 3, 2023
1 parent ad1fa61 commit d80e827
Show file tree
Hide file tree
Showing 27 changed files with 302 additions and 310 deletions.
1 change: 0 additions & 1 deletion common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Common RLST data structures
#![cfg_attr(feature = "strict", deny(warnings))]

pub mod tools;
pub mod traits;
pub mod types;
142 changes: 0 additions & 142 deletions common/src/traits/accessors.rs
Original file line number Diff line number Diff line change
@@ -1,143 +1 @@
//! Traits for access to matrix data.
//!
//! Each random access trait has a two-dimensional and a one-dimensional access method,
//! namely `get` and `get1d` (together with their mutable and unsafe variants).
//! The two-dimensional access takes a row and a column and returns the corresponding
//! matrix element. The one-dimensional access takes a single `index` parameter that
//! iterates through the matrix elements.
//!
//! If the [`crate::traits::properties::Shape`] and [`crate::traits::properties::NumberOfElements`]
//! traits are implemented on top of [`UnsafeRandomAccessByValue`], [`UnsafeRandomAccessByRef`]
//! and [`UnsafeRandomAccessMut`] then the
//! corresponding bounds-checked traits [`RandomAccessByValue`], [`RandomAccessByRef`] and
//! [`RandomAccessMut`] are auto-implemented.
//!
//! To get raw access to the underlying data use the [`RawAccess`] and [`RawAccessMut`] traits.

use crate::traits::properties::Shape;
use crate::types::{DataChunk, Scalar};

/// This trait provides unsafe access by value to the underlying data.
pub trait UnsafeRandomAccessByValue<const NDIM: usize> {
type Item: Scalar;

/// Return the element at position determined by `multi_index`.
///
/// # Safety
/// `multi_index` must not be out of bounds.
unsafe fn get_value_unchecked(&self, multi_index: [usize; NDIM]) -> Self::Item;
}

/// This trait provides unsafe access by reference to the underlying data.
pub trait UnsafeRandomAccessByRef<const NDIM: usize> {
type Item: Scalar;

/// Return a mutable reference to the element at position determined by `multi_index`.
///
/// # Safety
/// `multi_index` must not be out of bounds.
unsafe fn get_unchecked(&self, multi_index: [usize; NDIM]) -> &Self::Item;
}

/// This trait provides unsafe mutable access to the underlying data.
pub trait UnsafeRandomAccessMut<const NDIM: usize> {
type Item: Scalar;

/// Return a mutable reference to the element at position determined by `multi_index`.
///
/// # Safety
/// `multi_index` must not be out of bounds.
unsafe fn get_unchecked_mut(&mut self, multi_index: [usize; NDIM]) -> &mut Self::Item;
}

/// This trait provides bounds checked access to the underlying data by value.
pub trait RandomAccessByValue<const NDIM: usize>: UnsafeRandomAccessByValue<NDIM> {
/// Return the element at position determined by `multi_index`.
fn get_value(&self, multi_index: [usize; NDIM]) -> Option<Self::Item>;
}

/// This trait provides bounds checked access to the underlying data by reference.
pub trait RandomAccessByRef<const NDIM: usize>: UnsafeRandomAccessByRef<NDIM> {
/// Return a reference to the element at position determined by `multi_index`.
fn get(&self, multi_index: [usize; NDIM]) -> Option<&Self::Item>;
}

/// This trait provides bounds checked mutable access to the underlying data.
pub trait RandomAccessMut<const NDIM: usize>: UnsafeRandomAccessMut<NDIM> {
/// Return a mutable reference to the element at position determined by `multi_index`.
fn get_mut(&mut self, multi_index: [usize; NDIM]) -> Option<&mut Self::Item>;
}

/// Return chunks of data of size N;
pub trait ChunkedAccess<const N: usize> {
type Item: Scalar;
fn get_chunk(&self, chunk_index: usize) -> Option<DataChunk<Self::Item, N>>;
}

/// Get raw access to the underlying data.
pub trait RawAccess {
type Item: Scalar;

/// Get a slice of the whole data.
fn data(&self) -> &[Self::Item];
}

/// Get mutable raw access to the underlying data.
pub trait RawAccessMut: RawAccess {
/// Get a mutable slice of the whole data.
fn data_mut(&mut self) -> &mut [Self::Item];
}

/// Check if `multi_index` not out of bounds with respect to `shape`.
#[inline]
fn check_dimension<const NDIM: usize>(multi_index: [usize; NDIM], shape: [usize; NDIM]) -> bool {
multi_index
.iter()
.zip(shape.iter())
.fold(true, |acc, (ind, s)| acc && (ind < s))
}

impl<
Item: Scalar,
Mat: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
const NDIM: usize,
> RandomAccessByValue<NDIM> for Mat
{
fn get_value(&self, multi_index: [usize; NDIM]) -> Option<Self::Item> {
if check_dimension(multi_index, self.shape()) {
Some(unsafe { self.get_value_unchecked(multi_index) })
} else {
None
}
}
}

impl<
Item: Scalar,
Mat: UnsafeRandomAccessMut<NDIM, Item = Item> + Shape<NDIM>,
const NDIM: usize,
> RandomAccessMut<NDIM> for Mat
{
fn get_mut(&mut self, multi_index: [usize; NDIM]) -> Option<&mut Self::Item> {
if check_dimension(multi_index, self.shape()) {
unsafe { Some(self.get_unchecked_mut(multi_index)) }
} else {
None
}
}
}

impl<
Item: Scalar,
Mat: UnsafeRandomAccessByRef<NDIM, Item = Item> + Shape<NDIM>,
const NDIM: usize,
> RandomAccessByRef<NDIM> for Mat
{
fn get(&self, multi_index: [usize; NDIM]) -> Option<&Self::Item> {
if check_dimension(multi_index, self.shape()) {
unsafe { Some(self.get_unchecked(multi_index)) }
} else {
None
}
}
}
5 changes: 0 additions & 5 deletions common/src/traits/in_place_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,3 @@ pub trait SetDiag {
/// the length of the diagonal.
fn set_diag_from_slice(&mut self, diag: &[Self::Item]);
}

/// Resize an operator in place
pub trait ResizeInPlace<const NDIM: usize> {
fn resize_in_place(&mut self, shape: [usize; NDIM]);
}
20 changes: 0 additions & 20 deletions common/src/traits/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,6 @@ pub trait ColumnMajorIterator {
fn iter_col_major(&self) -> Self::Iter<'_>;
}

/// Default iterator.
pub trait DefaultIterator {
type Item: Scalar;
type Iter<'a>: std::iter::Iterator<Item = Self::Item>
where
Self: 'a;

fn iter(&self) -> Self::Iter<'_>;
}

/// Mutable iterator through the elements in column-major ordering.
pub trait ColumnMajorIteratorMut {
type Item: Scalar;
Expand All @@ -44,16 +34,6 @@ pub trait ColumnMajorIteratorMut {
fn iter_col_major_mut(&mut self) -> Self::IterMut<'_>;
}

/// Mutable default iterator.
pub trait DefaultIteratorMut {
type Item: Scalar;
type IterMut<'a>: std::iter::Iterator<Item = &'a mut Self::Item>
where
Self: 'a;

fn iter_mut(&mut self) -> Self::IterMut<'_>;
}

/// Iterate through the diagonal.
pub trait DiagonalIterator {
type Item: Scalar;
Expand Down
54 changes: 0 additions & 54 deletions common/src/traits/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,57 +130,3 @@ pub trait PermuteRows {

fn permute_rows(&self, permutation: &[usize]) -> Self::Out;
}

/// Multiply First * Second and sum into Self
pub trait MultInto<First, Second> {
type Item: Scalar;
fn simple_mult_into(self, arr_a: First, arr_b: Second) -> Self
where
Self: Sized,
{
self.mult_into(
TransMode::NoTrans,
TransMode::NoTrans,
<Self::Item as num::One>::one(),
arr_a,
arr_b,
<Self::Item as num::Zero>::zero(),
)
}
fn mult_into(
self,
transa: TransMode,
transb: TransMode,
alpha: Self::Item,
arr_a: First,
arr_b: Second,
beta: Self::Item,
) -> Self;
}

/// Multiply First * Second and sum into Self. Allow to resize Self if necessary
pub trait MultIntoResize<First, Second> {
type Item: Scalar;
fn simple_mult_into_resize(self, arr_a: First, arr_b: Second) -> Self
where
Self: Sized,
{
self.mult_into_resize(
TransMode::NoTrans,
TransMode::NoTrans,
<Self::Item as num::One>::one(),
arr_a,
arr_b,
<Self::Item as num::Zero>::zero(),
)
}
fn mult_into_resize(
self,
transa: TransMode,
transb: TransMode,
alpha: Self::Item,
arr_a: First,
arr_b: Second,
beta: Self::Item,
) -> Self;
}
26 changes: 0 additions & 26 deletions common/src/traits/properties.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1 @@
//! Traits describing properties of objects.

/// Return the shape of the object.
pub trait Shape<const NDIM: usize> {
fn shape(&self) -> [usize; NDIM];

/// Return true if a dimension is 0.
fn is_empty(&self) -> bool {
let shape = self.shape();
for elem in shape {
if elem == 0 {
return true;
}
}
false
}
}

/// Return the stride of the object.
pub trait Stride<const NDIM: usize> {
fn stride(&self) -> [usize; NDIM];
}

/// Return the number of elements.
pub trait NumberOfElements {
fn number_of_elements(&self) -> usize;
}
13 changes: 1 addition & 12 deletions dense/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,7 @@
use crate::base_array::BaseArray;
use crate::data_container::VectorContainer;
use rlst_common::traits::ChunkedAccess;
use rlst_common::traits::NumberOfElements;
use rlst_common::traits::RandomAccessByRef;
use rlst_common::traits::RandomAccessMut;
use rlst_common::traits::RawAccess;
use rlst_common::traits::RawAccessMut;
use rlst_common::traits::ResizeInPlace;
use rlst_common::traits::Shape;
use rlst_common::traits::Stride;
use rlst_common::traits::UnsafeRandomAccessByRef;
use rlst_common::traits::UnsafeRandomAccessByValue;
use rlst_common::traits::UnsafeRandomAccessMut;
use crate::traits::*;
use rlst_common::types::DataChunk;
use rlst_common::types::Scalar;

Expand Down
2 changes: 1 addition & 1 deletion dense/src/array/empty_axis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ where

#[cfg(test)]
mod test {
use rlst_common::traits::{Shape, Stride};
use crate::traits::{Shape, Stride};

use crate::{array::empty_axis::AxisPosition, rlst_dynamic_array3};

Expand Down
6 changes: 3 additions & 3 deletions dense/src/array/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl<
Item: Scalar,
ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
const NDIM: usize,
> rlst_common::traits::iterators::DefaultIterator for Array<Item, ArrayImpl, NDIM>
> crate::traits::DefaultIterator for Array<Item, ArrayImpl, NDIM>
{
type Item = Item;
type Iter<'a> = ArrayDefaultIterator<'a, Item, ArrayImpl, NDIM> where Self: 'a;
Expand All @@ -165,7 +165,7 @@ impl<
+ Shape<NDIM>
+ UnsafeRandomAccessMut<NDIM, Item = Item>,
const NDIM: usize,
> rlst_common::traits::iterators::DefaultIteratorMut for Array<Item, ArrayImpl, NDIM>
> crate::traits::DefaultIteratorMut for Array<Item, ArrayImpl, NDIM>
{
type Item = Item;
type IterMut<'a> = ArrayDefaultIteratorMut<'a, Item, ArrayImpl, NDIM> where Self: 'a;
Expand All @@ -178,7 +178,7 @@ impl<
#[cfg(test)]
mod test {

use rlst_common::traits::*;
use crate::traits::*;

#[test]
fn test_iter() {
Expand Down
4 changes: 2 additions & 2 deletions dense/src/array/mult_into.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Multiplication of Arrays
use crate::traits::MultInto;
use crate::traits::MultIntoResize;
use rlst_blis::interface::gemm::Gemm;
pub use rlst_blis::interface::types::TransMode;
use rlst_common::traits::MultInto;
use rlst_common::traits::MultIntoResize;

use super::{empty_axis::AxisPosition, *};

Expand Down
24 changes: 8 additions & 16 deletions dense/src/array/operators/to_complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,41 +82,33 @@ impl<
}

impl<ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = f32> + Shape<NDIM>, const NDIM: usize>
rlst_common::traits::ToComplex for Array<f32, ArrayImpl, NDIM>
Array<f32, ArrayImpl, NDIM>
{
type Out = Array<c32, ArrayToComplex<c32, ArrayImpl, NDIM>, NDIM>;

fn to_complex(self) -> Self::Out {
pub fn to_complex(self) -> Array<c32, ArrayToComplex<c32, ArrayImpl, NDIM>, NDIM> {
Array::new(ArrayToComplex::new(self))
}
}

impl<ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = f64> + Shape<NDIM>, const NDIM: usize>
rlst_common::traits::ToComplex for Array<f64, ArrayImpl, NDIM>
Array<f64, ArrayImpl, NDIM>
{
type Out = Array<c64, ArrayToComplex<c64, ArrayImpl, NDIM>, NDIM>;

fn to_complex(self) -> Self::Out {
pub fn to_complex(self) -> Array<c64, ArrayToComplex<c64, ArrayImpl, NDIM>, NDIM> {
Array::new(ArrayToComplex::new(self))
}
}

impl<ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = c64> + Shape<NDIM>, const NDIM: usize>
rlst_common::traits::ToComplex for Array<c64, ArrayImpl, NDIM>
Array<c64, ArrayImpl, NDIM>
{
type Out = Self;

fn to_complex(self) -> Self::Out {
pub fn to_complex(self) -> Self {
self
}
}

impl<ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = c32> + Shape<NDIM>, const NDIM: usize>
rlst_common::traits::ToComplex for Array<c32, ArrayImpl, NDIM>
Array<c32, ArrayImpl, NDIM>
{
type Out = Self;

fn to_complex(self) -> Self::Out {
pub fn to_complex(self) -> Self {
self
}
}
Loading

0 comments on commit d80e827

Please sign in to comment.