Skip to content

Commit

Permalink
WIP: Tests and clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Sep 20, 2023
1 parent b35c487 commit e2cb30d
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 56 deletions.
20 changes: 9 additions & 11 deletions dense/src/array/iterators.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Various iterator implementations
use crate::array::*;
use crate::layout::{convert_1d_nd, stride_from_shape};
use crate::layout::convert_1d_nd_from_shape;
use rlst_common::traits::{RandomAccessByValue, RandomAccessMut};
use rlst_common::types::Scalar;

Expand All @@ -12,7 +12,7 @@ pub struct ArrayDefaultIterator<
const NDIM: usize,
> {
arr: &'a Array<Item, ArrayImpl, NDIM>,
stride: [usize; NDIM],
shape: [usize; NDIM],
pos: usize,
}

Expand All @@ -25,7 +25,7 @@ pub struct ArrayDefaultIteratorMut<
const NDIM: usize,
> {
arr: &'a mut Array<Item, ArrayImpl, NDIM>,
stride: [usize; NDIM],
shape: [usize; NDIM],
pos: usize,
}

Expand All @@ -39,7 +39,7 @@ impl<
fn new(arr: &'a Array<Item, ArrayImpl, NDIM>) -> Self {
Self {
arr,
stride: stride_from_shape(arr.shape()),
shape: arr.shape(),
pos: 0,
}
}
Expand All @@ -56,11 +56,7 @@ impl<
{
fn new(arr: &'a mut Array<Item, ArrayImpl, NDIM>) -> Self {
let shape = arr.shape();
Self {
arr,
stride: stride_from_shape(shape),
pos: 0,
}
Self { arr, shape, pos: 0 }
}
}

Expand All @@ -73,7 +69,8 @@ impl<
{
type Item = Item;
fn next(&mut self) -> Option<Self::Item> {
let elem = self.arr.get_value(convert_1d_nd(self.pos, self.stride));
let indices = convert_1d_nd_from_shape(self.pos, self.shape)?;
let elem = self.arr.get_value(indices);
self.pos += 1;
elem
}
Expand All @@ -97,7 +94,8 @@ impl<
{
type Item = &'a mut Item;
fn next(&mut self) -> Option<Self::Item> {
let elem = self.arr.get_mut(convert_1d_nd(self.pos, self.stride));
let indices = convert_1d_nd_from_shape(self.pos, self.shape)?;
let elem = self.arr.get_mut(indices);
self.pos += 1;
match elem {
None => None,
Expand Down
15 changes: 7 additions & 8 deletions dense/src/array/operations.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Operations on arrays
use crate::layout::{convert_1d_nd, stride_from_shape};
use crate::layout::convert_1d_nd_from_shape;

use super::*;
use rlst_common::traits::*;
Expand Down Expand Up @@ -39,7 +39,6 @@ impl<
other: Other,
) {
assert_eq!(self.shape(), other.shape());
let stride = stride_from_shape(self.shape());

let mut chunk_index = 0;

Expand All @@ -48,11 +47,11 @@ impl<

for data_index in 0..chunk.valid_entries {
unsafe {
*self.get_unchecked_mut(convert_1d_nd(data_start + data_index, stride)) =
chunk.data[data_index];
*self.get_unchecked_mut(
convert_1d_nd_from_shape(data_start + data_index, self.shape()).unwrap(),
) = chunk.data[data_index];
}
}

chunk_index += 1;
}
}
Expand Down Expand Up @@ -94,7 +93,6 @@ impl<
Self: ChunkedAccess<N, Item = Item>,
{
assert_eq!(self.shape(), other.shape());
let stride = stride_from_shape(self.shape());

let mut chunk_index = 0;

Expand All @@ -109,8 +107,9 @@ impl<

for data_index in 0..chunk.valid_entries {
unsafe {
*self.get_unchecked_mut(convert_1d_nd(data_index + data_start, stride)) =
my_chunk.data[data_index];
*self.get_unchecked_mut(
convert_1d_nd_from_shape(data_index + data_start, self.shape()).unwrap(),
) = my_chunk.data[data_index];
}
}

Expand Down
28 changes: 14 additions & 14 deletions dense/src/array/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@ pub mod addition;
pub mod other;
pub mod scalar_mult;

// pub fn test_simd() {
// use crate::rlst_dynamic_array2;
pub fn test_simd() {
use crate::rlst_dynamic_array2;

// let shape = [200, 300];
// let mut arr1 = rlst_dynamic_array2!(f32, shape);
// let mut arr2 = rlst_dynamic_array2!(f32, shape);
// let mut res = rlst_dynamic_array2!(f32, shape);
let shape = [200, 300];
let mut arr1 = rlst_dynamic_array2!(f32, shape);
let mut arr2 = rlst_dynamic_array2!(f32, shape);
let mut res = rlst_dynamic_array2!(f32, shape);

// arr1.fill_from_seed_equally_distributed(0);
// arr2.fill_from_seed_equally_distributed(0);
arr1.fill_from_seed_equally_distributed(0);
arr2.fill_from_seed_equally_distributed(0);

// // let arr3 = arr1.view() + arr2.view();
// let arr3 = arr1.view() + arr2.view();

// let arr3 = 3.0 * arr1 + arr2;
let arr3 = 3.0 * arr1 + arr2;

// res.fill_from_chunked::<_, 512>(arr3.view());
// //res.fill_from(arr3.view());
res.fill_from_chunked::<_, 512>(arr3.view());
//res.fill_from(arr3.view());

// println!("{}", res[[0, 0]]);
// }
println!("{}", res[[0, 0]]);
}
16 changes: 9 additions & 7 deletions dense/src/array/views.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Views onto an array
use crate::layout::{convert_1d_nd, stride_from_shape};
use crate::layout::convert_1d_nd_from_shape;

use super::Array;
use rlst_common::traits::*;
Expand Down Expand Up @@ -148,12 +148,13 @@ impl<
} else {
// If the view is on a subsection of the array have to recalcuate the chunk
let nelements = self.shape().iter().product();
let stride = stride_from_shape(self.shape());
if let Some(mut chunk) = super::empty_chunk(chunk_index, nelements) {
for count in 0..chunk.valid_entries {
unsafe {
chunk.data[count] = self
.get_value_unchecked(convert_1d_nd(chunk.start_index + count, stride));
chunk.data[count] = self.get_value_unchecked(
convert_1d_nd_from_shape(chunk.start_index + count, self.shape())
.unwrap(),
)
}
}
Some(chunk)
Expand Down Expand Up @@ -238,12 +239,13 @@ impl<
} else {
// If the view is on a subsection of the array have to recalcuate the chunk
let nelements = self.shape().iter().product();
let stride = stride_from_shape(self.shape());
if let Some(mut chunk) = super::empty_chunk(chunk_index, nelements) {
for count in 0..chunk.valid_entries {
unsafe {
chunk.data[count] = self
.get_value_unchecked(convert_1d_nd(chunk.start_index + count, stride));
chunk.data[count] = self.get_value_unchecked(
convert_1d_nd_from_shape(chunk.start_index + count, self.shape())
.unwrap(),
)
}
}
Some(chunk)
Expand Down
11 changes: 7 additions & 4 deletions dense/src/base_array.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::array::empty_chunk;
use crate::data_container::{DataContainer, DataContainerMut};
use crate::layout::{convert_1d_nd, convert_nd_raw, stride_from_shape};
use rlst_common::traits::{ChunkedAccess, UnsafeRandomAccessByValue, UnsafeRandomAccessMut};
use crate::layout::{convert_1d_nd_from_shape, convert_nd_raw, stride_from_shape};
use rlst_common::traits::{
ChunkedAccess, RandomAccessByValue, UnsafeRandomAccessByValue, UnsafeRandomAccessMut,
};
use rlst_common::{
traits::{Shape, UnsafeRandomAccessByRef},
types::Scalar,
Expand Down Expand Up @@ -92,8 +94,9 @@ impl<Item: Scalar, Data: DataContainerMut<Item = Item>, const N: usize, const ND
if let Some(mut chunk) = empty_chunk(chunk_index, nelements) {
for count in 0..chunk.valid_entries {
unsafe {
chunk.data[count] = self
.get_value_unchecked(convert_1d_nd(chunk.start_index + count, self.stride));
chunk.data[count] = self.get_value_unchecked(
convert_1d_nd_from_shape(chunk.start_index + count, self.shape()).unwrap(),
)
}
}
Some(chunk)
Expand Down
25 changes: 15 additions & 10 deletions dense/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,19 @@ pub fn convert_nd_raw<const NDIM: usize>(indices: [usize; NDIM], stride: [usize;
pub fn convert_1d_nd_from_shape<const NDIM: usize>(
mut index: usize,
shape: [usize; NDIM],
) -> [usize; NDIM] {
) -> Option<[usize; NDIM]> {
let mut res = [0; NDIM];
let stride = stride_from_shape(shape);
let nelements = shape.iter().product();
if index >= nelements {
None
} else {
for ind in 0..NDIM {
res[ind] = index % shape[ind];
index /= shape[ind];
}

for ind in (0..NDIM).rev() {
res[ind] = index / stride[ind];
index %= stride[ind];
Some(res)
}

res
}

#[cfg(test)]
Expand All @@ -51,12 +54,14 @@ mod test {

#[test]
fn test_convert_1d_nd() {
let indices: [usize; 4] = [3, 7, 14, 5];
let shape: [usize; 4] = [4, 15, 17, 6];
let indices: [usize; 3] = [2, 3, 7];
let shape: [usize; 3] = [3, 4, 8];
let stride = stride_from_shape(shape);

let index_1d = convert_nd_raw(indices, stride);
let actual_nd = convert_1d_nd(index_1d, stride);
let actual_nd = convert_1d_nd_from_shape(index_1d, shape).unwrap();

println!("{}, {:#?}", index_1d, actual_nd);

for (&expected, actual) in indices.iter().zip(actual_nd) {
assert_eq!(expected, actual)
Expand Down
4 changes: 2 additions & 2 deletions rlst/tests/array_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use rlst::rlst_dynamic_array3;
use rlst_common::traits::*;
use rlst_dense::layout::{convert_1d_nd, stride_from_shape};
use rlst_dense::layout::convert_1d_nd_from_shape;

#[test]
fn test_addition() {
Expand Down Expand Up @@ -61,7 +61,7 @@ fn test_multiple_operations() {
res.sum_into_chunked::<_, 64>(arr3.view());

for index in 0..nelements {
let indices = convert_1d_nd(index, stride_from_shape(shape));
let indices = convert_1d_nd_from_shape(index, res.shape()).unwrap();

approx::assert_relative_eq!(
res[indices],
Expand Down

0 comments on commit e2cb30d

Please sign in to comment.