-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce the MulAcc trait to parameterize matrix products
This trait has less requirements than the current `Num` bound, and is also amenable to optimized usage for non-`Copy` types. It is therefore a good candidate to remove the `Copy` requirement.
- Loading branch information
1 parent
e857e1d
commit d56bd72
Showing
2 changed files
with
43 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
//! Multiply-accumulate (MAC) trait and implementations | ||
//! It's useful to define our own MAC trait as it's the main primitive we use | ||
//! in matrix products, and defining it ourselves means we can define an | ||
//! implementation that does not require cloning, which should prove useful | ||
//! when defining sparse matrices per blocks (eg BSR, BSC) | ||
|
||
/// Trait for types that have a multiply-accumulate operation, as required | ||
/// in dot products and matrix products. | ||
/// | ||
/// This trait is automatically implemented for numeric types that are `Copy`, | ||
/// however the implementation is open for more complex types, to allow them | ||
/// to provide the most performant implementation. For instance, we could have | ||
/// a default implementation for numeric types that are `Clone`, but it would | ||
/// make possibly unnecessary copies. | ||
pub trait MulAcc { | ||
/// Multiply and accumulate in this variable, formally `*self += a * b`. | ||
fn mul_acc(&mut self, a: &Self, b: &Self); | ||
} | ||
|
||
impl<N> MulAcc for N | ||
where | ||
N: Copy + num_traits::MulAdd<Output = N>, | ||
{ | ||
fn mul_acc(&mut self, a: &Self, b: &Self) { | ||
*self = a.mul_add(*b, *self); | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::MulAcc; | ||
|
||
#[test] | ||
fn mul_acc_f64() { | ||
let mut a = 1f64; | ||
let b = 2.; | ||
let c = 3.; | ||
a.mul_acc(&b, &c); | ||
assert_eq!(a, 7.); | ||
} | ||
} |