Skip to content

Commit

Permalink
blas: test with more than one pattern in data
Browse files Browse the repository at this point in the history
Implement a checkerboard pattern in input data just to test with some
another kind of input.
  • Loading branch information
bluss committed Aug 9, 2024
1 parent 0153a37 commit 876ad01
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
16 changes: 10 additions & 6 deletions crates/blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use ndarray::linalg::general_mat_vec_mul;
use ndarray::Order;
use ndarray::{Data, Ix, LinalgScalar};
use ndarray_gen::array_builder::ArrayBuilder;
use ndarray_gen::array_builder::ElementGenerator;

use approx::assert_relative_eq;
use defmac::defmac;
Expand Down Expand Up @@ -230,7 +231,6 @@ fn gen_mat_mul()
let sizes = vec![
(4, 4, 4),
(8, 8, 8),
(10, 10, 10),
(8, 8, 1),
(1, 10, 10),
(10, 1, 10),
Expand All @@ -241,19 +241,23 @@ fn gen_mat_mul()
(4, 17, 3),
(17, 3, 22),
(19, 18, 2),
(16, 17, 15),
(15, 16, 17),
(67, 63, 62),
(67, 50, 62),
];
let strides = &[1, 2, -1, -2];
let cf_order = [Order::C, Order::F];
let generator = [ElementGenerator::Sequential, ElementGenerator::Checkerboard];

// test different strides and memory orders
for (&s1, &s2) in iproduct!(strides, strides) {
for (&s1, &s2, &gen) in iproduct!(strides, strides, &generator) {
for &(m, k, n) in &sizes {
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5;
println!("Case s1={}, s2={}, gen={:?}, orders={:?}, {:?}, {:?}", s1, s2, gen, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k))
.memory_order(ord1)
.generator(gen)
.build()
* 0.5;
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();

Expand Down
17 changes: 8 additions & 9 deletions crates/ndarray-gen/src/array_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct ArrayBuilder<D: Dimension>
pub enum ElementGenerator
{
Sequential,
Checkerboard,
Zero,
}

Expand Down Expand Up @@ -64,16 +65,14 @@ where D: Dimension
pub fn build<T>(self) -> Array<T, D>
where T: Num + Clone
{
let mut current = T::zero();
let zero = T::zero();
let size = self.dim.size();
let use_zeros = self.generator == ElementGenerator::Zero;
Array::from_iter((0..size).map(|_| {
let ret = current.clone();
if !use_zeros {
current = ret.clone() + T::one();
}
ret
}))
(match self.generator {
ElementGenerator::Sequential =>
Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)),
ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()),
ElementGenerator::Zero => Array::zeros(size),
})
.into_shape_with_order((self.dim, self.memory_order))
.unwrap()
}
Expand Down

0 comments on commit 876ad01

Please sign in to comment.