diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index f1e1bc42b..a9dca7e83 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -12,6 +12,7 @@ use ndarray::linalg::general_mat_vec_mul; use ndarray::Order; use ndarray::{Data, Ix, LinalgScalar}; use ndarray_gen::array_builder::ArrayBuilder; +use ndarray_gen::array_builder::ElementGenerator; use approx::assert_relative_eq; use defmac::defmac; @@ -230,7 +231,6 @@ fn gen_mat_mul() let sizes = vec![ (4, 4, 4), (8, 8, 8), - (10, 10, 10), (8, 8, 1), (1, 10, 10), (10, 1, 10), @@ -241,19 +241,23 @@ fn gen_mat_mul() (4, 17, 3), (17, 3, 22), (19, 18, 2), - (16, 17, 15), (15, 16, 17), - (67, 63, 62), + (67, 50, 62), ]; let strides = &[1, 2, -1, -2]; let cf_order = [Order::C, Order::F]; + let generator = [ElementGenerator::Sequential, ElementGenerator::Checkerboard]; // test different strides and memory orders - for (&s1, &s2) in iproduct!(strides, strides) { + for (&s1, &s2, &gen) in iproduct!(strides, strides, &generator) { for &(m, k, n) in &sizes { for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { - println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); - let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5; + println!("Case s1={}, s2={}, gen={:?}, orders={:?}, {:?}, {:?}", s1, s2, gen, ord1, ord2, ord3); + let a = ArrayBuilder::new((m, k)) + .memory_order(ord1) + .generator(gen) + .build() + * 0.5; let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); diff --git a/crates/ndarray-gen/src/array_builder.rs b/crates/ndarray-gen/src/array_builder.rs index a021e5252..9351aadc5 100644 --- a/crates/ndarray-gen/src/array_builder.rs +++ b/crates/ndarray-gen/src/array_builder.rs @@ -26,6 +26,7 @@ pub struct ArrayBuilder pub enum ElementGenerator { Sequential, + Checkerboard, Zero, } @@ -64,16 +65,14 @@ where D: Dimension pub fn build(self) -> Array where T: Num + Clone { - let mut current = T::zero(); + let zero = T::zero(); let size = self.dim.size(); - let use_zeros = self.generator == ElementGenerator::Zero; - Array::from_iter((0..size).map(|_| { - let ret = current.clone(); - if !use_zeros { - current = ret.clone() + T::one(); - } - ret - })) + (match self.generator { + ElementGenerator::Sequential => + Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)), + ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()), + ElementGenerator::Zero => Array::zeros(size), + }) .into_shape_with_order((self.dim, self.memory_order)) .unwrap() }