Skip to content

Commit

Permalink
Implemented chunked evaluation.
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Sep 20, 2023
1 parent e2cb30d commit 364d9e8
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 44 deletions.
32 changes: 22 additions & 10 deletions dense/src/array/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
use crate::array::*;
use crate::layout::convert_1d_nd_from_shape;
use rlst_common::traits::{RandomAccessByValue, RandomAccessMut};
use rlst_common::types::Scalar;

pub struct ArrayDefaultIterator<
Expand All @@ -14,6 +13,7 @@ pub struct ArrayDefaultIterator<
arr: &'a Array<Item, ArrayImpl, NDIM>,
shape: [usize; NDIM],
pos: usize,
nelements: usize,
}

pub struct ArrayDefaultIteratorMut<
Expand All @@ -27,6 +27,7 @@ pub struct ArrayDefaultIteratorMut<
arr: &'a mut Array<Item, ArrayImpl, NDIM>,
shape: [usize; NDIM],
pos: usize,
nelements: usize,
}

impl<
Expand All @@ -41,6 +42,7 @@ impl<
arr,
shape: arr.shape(),
pos: 0,
nelements: arr.shape().iter().product(),
}
}
}
Expand All @@ -56,7 +58,12 @@ impl<
{
fn new(arr: &'a mut Array<Item, ArrayImpl, NDIM>) -> Self {
let shape = arr.shape();
Self { arr, shape, pos: 0 }
Self {
arr,
shape,
pos: 0,
nelements: shape.iter().product(),
}
}
}

Expand All @@ -69,10 +76,12 @@ impl<
{
type Item = Item;
fn next(&mut self) -> Option<Self::Item> {
let indices = convert_1d_nd_from_shape(self.pos, self.shape)?;
let elem = self.arr.get_value(indices);
if self.pos >= self.nelements {
return None;
}
let indices = convert_1d_nd_from_shape(self.pos, self.shape);
self.pos += 1;
elem
unsafe { Some(self.arr.get_value_unchecked(indices)) }
}
}

Expand All @@ -94,12 +103,15 @@ impl<
{
type Item = &'a mut Item;
fn next(&mut self) -> Option<Self::Item> {
let indices = convert_1d_nd_from_shape(self.pos, self.shape)?;
let elem = self.arr.get_mut(indices);
if self.pos >= self.nelements {
return None;
}
let indices = convert_1d_nd_from_shape(self.pos, self.shape);
self.pos += 1;
match elem {
None => None,
Some(inner) => Some(unsafe { std::mem::transmute::<&mut Item, &'a mut Item>(inner) }),
unsafe {
Some(std::mem::transmute::<&mut Item, &'a mut Item>(
self.arr.get_unchecked_mut(indices),
))
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions dense/src/array/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ impl<

for data_index in 0..chunk.valid_entries {
unsafe {
*self.get_unchecked_mut(
convert_1d_nd_from_shape(data_start + data_index, self.shape()).unwrap(),
) = chunk.data[data_index];
*self.get_unchecked_mut(convert_1d_nd_from_shape(
data_start + data_index,
self.shape(),
)) = chunk.data[data_index];
}
}
chunk_index += 1;
Expand Down Expand Up @@ -107,9 +108,10 @@ impl<

for data_index in 0..chunk.valid_entries {
unsafe {
*self.get_unchecked_mut(
convert_1d_nd_from_shape(data_index + data_start, self.shape()).unwrap(),
) = my_chunk.data[data_index];
*self.get_unchecked_mut(convert_1d_nd_from_shape(
data_index + data_start,
self.shape(),
)) = my_chunk.data[data_index];
}
}

Expand Down
23 changes: 14 additions & 9 deletions dense/src/array/views.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Views onto an array
use crate::layout::convert_1d_nd_from_shape;
use crate::layout::{check_indices_in_bounds, convert_1d_nd_from_shape};

use super::Array;
use rlst_common::traits::*;
Expand Down Expand Up @@ -107,6 +107,7 @@ impl<
type Item = Item;
#[inline]
unsafe fn get_value_unchecked(&self, indices: [usize; NDIM]) -> Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
self.arr
.get_value_unchecked(offset_indices(indices, self.offset))
}
Expand All @@ -124,6 +125,7 @@ impl<
type Item = Item;
#[inline]
unsafe fn get_unchecked(&self, indices: [usize; NDIM]) -> &Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
self.arr.get_unchecked(offset_indices(indices, self.offset))
}
}
Expand Down Expand Up @@ -151,10 +153,10 @@ impl<
if let Some(mut chunk) = super::empty_chunk(chunk_index, nelements) {
for count in 0..chunk.valid_entries {
unsafe {
chunk.data[count] = self.get_value_unchecked(
convert_1d_nd_from_shape(chunk.start_index + count, self.shape())
.unwrap(),
)
chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape(
chunk.start_index + count,
self.shape(),
))
}
}
Some(chunk)
Expand Down Expand Up @@ -193,6 +195,7 @@ impl<
type Item = Item;
#[inline]
unsafe fn get_value_unchecked(&self, indices: [usize; NDIM]) -> Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
self.arr
.get_value_unchecked(offset_indices(indices, self.offset))
}
Expand All @@ -211,6 +214,7 @@ impl<
type Item = Item;
#[inline]
unsafe fn get_unchecked(&self, indices: [usize; NDIM]) -> &Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
self.arr.get_unchecked(offset_indices(indices, self.offset))
}
}
Expand Down Expand Up @@ -242,10 +246,10 @@ impl<
if let Some(mut chunk) = super::empty_chunk(chunk_index, nelements) {
for count in 0..chunk.valid_entries {
unsafe {
chunk.data[count] = self.get_value_unchecked(
convert_1d_nd_from_shape(chunk.start_index + count, self.shape())
.unwrap(),
)
chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape(
chunk.start_index + count,
self.shape(),
))
}
}
Some(chunk)
Expand All @@ -268,6 +272,7 @@ impl<
type Item = Item;
#[inline]
unsafe fn get_unchecked_mut(&mut self, indices: [usize; NDIM]) -> &mut Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
self.arr
.get_unchecked_mut(offset_indices(indices, self.offset))
}
Expand Down
16 changes: 10 additions & 6 deletions dense/src/base_array.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::array::empty_chunk;
use crate::data_container::{DataContainer, DataContainerMut};
use crate::layout::{convert_1d_nd_from_shape, convert_nd_raw, stride_from_shape};
use rlst_common::traits::{
ChunkedAccess, RandomAccessByValue, UnsafeRandomAccessByValue, UnsafeRandomAccessMut,
use crate::layout::{
check_indices_in_bounds, convert_1d_nd_from_shape, convert_nd_raw, stride_from_shape,
};
use rlst_common::traits::{ChunkedAccess, UnsafeRandomAccessByValue, UnsafeRandomAccessMut};
use rlst_common::{
traits::{Shape, UnsafeRandomAccessByRef},
types::Scalar,
Expand Down Expand Up @@ -51,6 +51,7 @@ impl<Item: Scalar, Data: DataContainer<Item = Item>, const NDIM: usize>

#[inline]
unsafe fn get_unchecked(&self, indices: [usize; NDIM]) -> &Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
let index = convert_nd_raw(indices, self.stride);
self.data.get_unchecked(index)
}
Expand All @@ -63,6 +64,7 @@ impl<Item: Scalar, Data: DataContainer<Item = Item>, const NDIM: usize>

#[inline]
unsafe fn get_value_unchecked(&self, indices: [usize; NDIM]) -> Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
let index = convert_nd_raw(indices, self.stride);
self.data.get_unchecked_value(index)
}
Expand All @@ -75,6 +77,7 @@ impl<Item: Scalar, Data: DataContainerMut<Item = Item>, const NDIM: usize>

#[inline]
unsafe fn get_unchecked_mut(&mut self, indices: [usize; NDIM]) -> &mut Self::Item {
debug_assert!(check_indices_in_bounds(indices, self.shape()));
let index = convert_nd_raw(indices, self.stride);
self.data.get_unchecked_mut(index)
}
Expand All @@ -94,9 +97,10 @@ impl<Item: Scalar, Data: DataContainerMut<Item = Item>, const N: usize, const ND
if let Some(mut chunk) = empty_chunk(chunk_index, nelements) {
for count in 0..chunk.valid_entries {
unsafe {
chunk.data[count] = self.get_value_unchecked(
convert_1d_nd_from_shape(chunk.start_index + count, self.shape()).unwrap(),
)
chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape(
chunk.start_index + count,
self.shape(),
))
}
}
Some(chunk)
Expand Down
11 changes: 11 additions & 0 deletions dense/src/data_container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,12 @@ impl<Item: Scalar> DataContainer for VectorContainer<Item> {
type Item = Item;

unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item {
debug_assert!(index < self.number_of_elements());
*self.data.get_unchecked(index)
}

unsafe fn get_unchecked(&self, index: usize) -> &Self::Item {
debug_assert!(index < self.number_of_elements());
self.data.get_unchecked(index)
}

Expand All @@ -183,6 +185,7 @@ impl<Item: Scalar> DataContainer for VectorContainer<Item> {

impl<Item: Scalar> DataContainerMut for VectorContainer<Item> {
unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut Self::Item {
debug_assert!(index < self.number_of_elements());
self.data.get_unchecked_mut(index)
}

Expand All @@ -195,10 +198,12 @@ impl<Item: Scalar, const N: usize> DataContainer for ArrayContainer<Item, N> {
type Item = Item;

unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item {
debug_assert!(index < self.number_of_elements());
*self.data.get_unchecked(index)
}

unsafe fn get_unchecked(&self, index: usize) -> &Self::Item {
debug_assert!(index < self.number_of_elements());
self.data.get_unchecked(index)
}

Expand All @@ -213,6 +218,7 @@ impl<Item: Scalar, const N: usize> DataContainer for ArrayContainer<Item, N> {

impl<Item: Scalar, const N: usize> DataContainerMut for ArrayContainer<Item, N> {
unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut Self::Item {
debug_assert!(index < self.number_of_elements());
self.data.get_unchecked_mut(index)
}

Expand All @@ -225,10 +231,12 @@ impl<'a, Item: Scalar> DataContainer for SliceContainer<'a, Item> {
type Item = Item;

unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item {
debug_assert!(index < self.number_of_elements());
*self.data.get_unchecked(index)
}

unsafe fn get_unchecked(&self, index: usize) -> &Self::Item {
debug_assert!(index < self.number_of_elements());
self.data.get_unchecked(index)
}

Expand All @@ -245,10 +253,12 @@ impl<'a, Item: Scalar> DataContainer for SliceContainerMut<'a, Item> {
type Item = Item;

unsafe fn get_unchecked_value(&self, index: usize) -> Self::Item {
debug_assert!(index < self.number_of_elements());
*self.data.get_unchecked(index)
}

unsafe fn get_unchecked(&self, index: usize) -> &Self::Item {
debug_assert!(index < self.number_of_elements());
self.data.get_unchecked(index)
}

Expand All @@ -263,6 +273,7 @@ impl<'a, Item: Scalar> DataContainer for SliceContainerMut<'a, Item> {

impl<'a, Item: Scalar> DataContainerMut for SliceContainerMut<'a, Item> {
unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut Self::Item {
debug_assert!(index < self.number_of_elements());
self.data.get_unchecked_mut(index)
}

Expand Down
29 changes: 17 additions & 12 deletions dense/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ pub fn stride_from_shape<const NDIM: usize>(shape: [usize; NDIM]) -> [usize; NDI
output
}

/// Return true if `indices` in bounds with respect to `shape`.
pub fn check_indices_in_bounds<const N: usize>(indices: [usize; N], shape: [usize; N]) -> bool {
for (ind, s) in indices.iter().zip(shape.iter()) {
if ind >= s {
return false;
}
}
true
}

/// Convert an n-d index into a 1d index.
#[inline]
pub fn convert_nd_raw<const NDIM: usize>(indices: [usize; NDIM], stride: [usize; NDIM]) -> usize {
Expand All @@ -33,19 +43,14 @@ pub fn convert_nd_raw<const NDIM: usize>(indices: [usize; NDIM], stride: [usize;
pub fn convert_1d_nd_from_shape<const NDIM: usize>(
mut index: usize,
shape: [usize; NDIM],
) -> Option<[usize; NDIM]> {
) -> [usize; NDIM] {
let mut res = [0; NDIM];
let nelements = shape.iter().product();
if index >= nelements {
None
} else {
for ind in 0..NDIM {
res[ind] = index % shape[ind];
index /= shape[ind];
}

Some(res)
debug_assert!(index < shape.iter().product());
for ind in 0..NDIM {
res[ind] = index % shape[ind];
index /= shape[ind];
}
res
}

#[cfg(test)]
Expand All @@ -59,7 +64,7 @@ mod test {
let stride = stride_from_shape(shape);

let index_1d = convert_nd_raw(indices, stride);
let actual_nd = convert_1d_nd_from_shape(index_1d, shape).unwrap();
let actual_nd = convert_1d_nd_from_shape(index_1d, shape);

println!("{}, {:#?}", index_1d, actual_nd);

Expand Down
2 changes: 1 addition & 1 deletion rlst/tests/array_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn test_multiple_operations() {
res.sum_into_chunked::<_, 64>(arr3.view());

for index in 0..nelements {
let indices = convert_1d_nd_from_shape(index, res.shape()).unwrap();
let indices = convert_1d_nd_from_shape(index, res.shape());

approx::assert_relative_eq!(
res[indices],
Expand Down

0 comments on commit 364d9e8

Please sign in to comment.