From 1afe700495b9f37c055600239872aed36eca8661 Mon Sep 17 00:00:00 2001 From: Timo Betcke Date: Fri, 29 Sep 2023 21:40:45 +0100 Subject: [PATCH] WIP: Tensors --- rlst/tests/array_operations.rs | 69 +++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/rlst/tests/array_operations.rs b/rlst/tests/array_operations.rs index c42de41e..cbe9ea24 100644 --- a/rlst/tests/array_operations.rs +++ b/rlst/tests/array_operations.rs @@ -144,10 +144,75 @@ fn test_conj() { #[test] fn test_transpose() { let shape = [3, 4, 8]; + let permutation = [1, 2, 0]; + let new_shape = [4, 8, 3]; let mut arr1 = rlst_dynamic_array3!(f64, shape); arr1.fill_from_seed_equally_distributed(0); + let mut res = rlst_dynamic_array3!(f64, new_shape); + let mut res_chunked = rlst_dynamic_array3!(f64, new_shape); + let mut expected = rlst_dynamic_array3!(f64, new_shape); - let res = arr1.permute_axes([1, 2, 0]); + let arr3 = arr1.view().permute_axes(permutation); + res.fill_from(arr3.view()); + res_chunked.fill_from_chunked::<_, 31>(arr3.view()); + + assert_eq!(arr3.shape(), [shape[1], shape[2], shape[0]]); + + for (multi_index, elem) in arr3.iter().enumerate().multi_index(res.shape()) { + let original_index = [multi_index[2], multi_index[0], multi_index[1]]; + + assert_eq!(elem, arr1[original_index]); + expected[multi_index] = elem; + } + + assert_array_relative_eq!(res, expected, 1E-14); + assert_array_relative_eq!(res_chunked, expected, 1E-14); +} + +#[test] +fn test_cmp_wise_division() { + let shape = [3, 4, 8]; + let mut arr1 = rlst_dynamic_array3!(f64, shape); + let mut arr2 = rlst_dynamic_array3!(f64, shape); + let mut res_chunked = rlst_dynamic_array3!(f64, shape); + let mut res = rlst_dynamic_array3!(f64, shape); + let mut expected = rlst_dynamic_array3!(f64, shape); + + arr1.fill_from_seed_equally_distributed(0); + arr2.fill_from_seed_equally_distributed(1); + + let arr3 = arr1.view() / arr2.view(); + + res_chunked.fill_from_chunked::<_, 31>(arr3.view()); + res.fill_from(arr3.view()); - assert_eq!(res.shape(), [4, 8, 3]); + for (multi_index, elem) in expected.iter_mut().enumerate().multi_index(shape) { + *elem = arr1[multi_index] / arr2[multi_index]; + } + + assert_array_relative_eq!(res_chunked, expected, 1E-14); + assert_array_relative_eq!(res, expected, 1E-14); +} + +#[test] +fn test_to_complex() { + let shape = [3, 4, 8]; + let mut arr1 = rlst_dynamic_array3!(f64, shape); + let mut res_chunked = rlst_dynamic_array3!(c64, shape); + let mut res = rlst_dynamic_array3!(c64, shape); + let mut expected = rlst_dynamic_array3!(c64, shape); + + arr1.fill_from_seed_equally_distributed(0); + + let arr3 = arr1.view().to_complex(); + + res_chunked.fill_from_chunked::<_, 31>(arr3.view()); + res.fill_from(arr3.view()); + + for (multi_index, elem) in expected.iter_mut().enumerate().multi_index(shape) { + *elem = c64::from_real(arr1[multi_index]); + } + + assert_array_relative_eq!(res_chunked, expected, 1E-14); + assert_array_relative_eq!(res, expected, 1E-14); }