Skip to content

Commit

Permalink
Introduce the MulAcc trait to parameterize matrix products
Browse files Browse the repository at this point in the history
This trait has less requirements than the current `Num` bound, and is
also amenable to optimized usage for non-`Copy` types. It is therefore a
good candidate to remove the `Copy` requirement.
  • Loading branch information
vbarrielle committed Feb 23, 2021
1 parent e857e1d commit d56bd72
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ assert_eq!(a, b.to_csc());

pub mod array_backend;
mod dense_vector;
mod mul_acc;
pub mod errors;
pub mod indexing;
#[cfg(not(miri))]
Expand All @@ -102,6 +103,7 @@ pub use crate::sparse::{
};

pub use crate::dense_vector::{DenseVector, DenseVectorMut};
pub use crate::mul_acc::MulAcc;

pub use crate::sparse::symmetric::is_symmetric;

Expand Down
41 changes: 41 additions & 0 deletions src/mul_acc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//! Multiply-accumulate (MAC) trait and implementations
//! It's useful to define our own MAC trait as it's the main primitive we use
//! in matrix products, and defining it ourselves means we can define an
//! implementation that does not require cloning, which should prove useful
//! when defining sparse matrices per blocks (eg BSR, BSC)

/// Trait for types that have a multiply-accumulate operation, as required
/// in dot products and matrix products.
///
/// This trait is automatically implemented for numeric types that are `Copy`,
/// however the implementation is open for more complex types, to allow them
/// to provide the most performant implementation. For instance, we could have
/// a default implementation for numeric types that are `Clone`, but it would
/// make possibly unnecessary copies.
pub trait MulAcc {
/// Multiply and accumulate in this variable, formally `*self += a * b`.
fn mul_acc(&mut self, a: &Self, b: &Self);
}

impl<N> MulAcc for N
where
N: Copy + num_traits::MulAdd<Output = N>,
{
fn mul_acc(&mut self, a: &Self, b: &Self) {
*self = a.mul_add(*b, *self);
}
}

#[cfg(test)]
mod tests {
use super::MulAcc;

#[test]
fn mul_acc_f64() {
let mut a = 1f64;
let b = 2.;
let c = 3.;
a.mul_acc(&b, &c);
assert_eq!(a, 7.);
}
}

0 comments on commit d56bd72

Please sign in to comment.