Skip to content

Commit

Permalink
WIP: Tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Sep 29, 2023
1 parent ac2d9a0 commit 1afe700
Showing 1 changed file with 67 additions and 2 deletions.
69 changes: 67 additions & 2 deletions rlst/tests/array_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,75 @@ fn test_conj() {
#[test]
fn test_transpose() {
let shape = [3, 4, 8];
let permutation = [1, 2, 0];
let new_shape = [4, 8, 3];
let mut arr1 = rlst_dynamic_array3!(f64, shape);
arr1.fill_from_seed_equally_distributed(0);
let mut res = rlst_dynamic_array3!(f64, new_shape);
let mut res_chunked = rlst_dynamic_array3!(f64, new_shape);
let mut expected = rlst_dynamic_array3!(f64, new_shape);

let res = arr1.permute_axes([1, 2, 0]);
let arr3 = arr1.view().permute_axes(permutation);
res.fill_from(arr3.view());
res_chunked.fill_from_chunked::<_, 31>(arr3.view());

assert_eq!(arr3.shape(), [shape[1], shape[2], shape[0]]);

for (multi_index, elem) in arr3.iter().enumerate().multi_index(res.shape()) {
let original_index = [multi_index[2], multi_index[0], multi_index[1]];

assert_eq!(elem, arr1[original_index]);
expected[multi_index] = elem;
}

assert_array_relative_eq!(res, expected, 1E-14);
assert_array_relative_eq!(res_chunked, expected, 1E-14);
}

#[test]
fn test_cmp_wise_division() {
let shape = [3, 4, 8];
let mut arr1 = rlst_dynamic_array3!(f64, shape);
let mut arr2 = rlst_dynamic_array3!(f64, shape);
let mut res_chunked = rlst_dynamic_array3!(f64, shape);
let mut res = rlst_dynamic_array3!(f64, shape);
let mut expected = rlst_dynamic_array3!(f64, shape);

arr1.fill_from_seed_equally_distributed(0);
arr2.fill_from_seed_equally_distributed(1);

let arr3 = arr1.view() / arr2.view();

res_chunked.fill_from_chunked::<_, 31>(arr3.view());
res.fill_from(arr3.view());

assert_eq!(res.shape(), [4, 8, 3]);
for (multi_index, elem) in expected.iter_mut().enumerate().multi_index(shape) {
*elem = arr1[multi_index] / arr2[multi_index];
}

assert_array_relative_eq!(res_chunked, expected, 1E-14);
assert_array_relative_eq!(res, expected, 1E-14);
}

#[test]
fn test_to_complex() {
let shape = [3, 4, 8];
let mut arr1 = rlst_dynamic_array3!(f64, shape);
let mut res_chunked = rlst_dynamic_array3!(c64, shape);
let mut res = rlst_dynamic_array3!(c64, shape);
let mut expected = rlst_dynamic_array3!(c64, shape);

arr1.fill_from_seed_equally_distributed(0);

let arr3 = arr1.view().to_complex();

res_chunked.fill_from_chunked::<_, 31>(arr3.view());
res.fill_from(arr3.view());

for (multi_index, elem) in expected.iter_mut().enumerate().multi_index(shape) {
*elem = c64::from_real(arr1[multi_index]);
}

assert_array_relative_eq!(res_chunked, expected, 1E-14);
assert_array_relative_eq!(res, expected, 1E-14);
}

0 comments on commit 1afe700

Please sign in to comment.