Skip to content

Commit

Permalink
Remove need for Copy bound
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarrielle committed Mar 5, 2021
1 parent d56bd72 commit d2f0da7
Show file tree
Hide file tree
Showing 14 changed files with 322 additions and 308 deletions.
1 change: 1 addition & 0 deletions sprs-ldl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ impl<N, I: SpIndex> LdlNumeric<N, I> {
) -> <<V as DenseVector>::Owned as DenseVector>::Owned
where
N: 'a + Copy + Num + std::ops::SubAssign + std::ops::DivAssign,
N: for<'r> std::ops::DivAssign<&'r N>,
V: DenseVector<Scalar = N>,
<V as DenseVector>::Owned: DenseVectorMut + DenseVector<Scalar = N>,
for<'b> &'b <V as DenseVector>::Owned: DenseVector<Scalar = N>,
Expand Down
4 changes: 2 additions & 2 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ pub fn write_matrix_market<'a, N, I, M, P>(
) -> Result<(), io::Error>
where
I: 'a + SpIndex + fmt::Display,
N: 'a + PrimitiveKind + Copy + fmt::Display,
N: 'a + PrimitiveKind + fmt::Display,
M: IntoIterator<Item = (&'a N, (I, I))> + SparseMat,
P: AsRef<Path>,
{
Expand Down Expand Up @@ -319,7 +319,7 @@ pub fn write_matrix_market_sym<'a, N, I, M, P>(
) -> Result<(), io::Error>
where
I: 'a + SpIndex + fmt::Display,
N: 'a + PrimitiveKind + Copy + fmt::Display,
N: 'a + PrimitiveKind + fmt::Display,
M: IntoIterator<Item = (&'a N, (I, I))> + SparseMat,
P: AsRef<Path>,
{
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ assert_eq!(a, b.to_csc());

pub mod array_backend;
mod dense_vector;
mod mul_acc;
pub mod errors;
pub mod indexing;
#[cfg(not(miri))]
pub mod io;
mod mul_acc;
pub mod num_kinds;
mod range;
mod sparse;
Expand Down
10 changes: 6 additions & 4 deletions src/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ pub(crate) mod utils {
true
}

pub fn sort_indices_data_slices<N: Copy, I: SpIndex>(
pub fn sort_indices_data_slices<N: Clone, I: SpIndex>(
indices: &mut [I],
data: &mut [N],
buf: &mut Vec<(I, N)>,
Expand All @@ -379,13 +379,15 @@ pub(crate) mod utils {
buf.clear();
buf.reserve_exact(len);
for (i, v) in indices.iter().zip(data.iter()) {
buf.push((*i, *v));
buf.push((*i, v.clone()));
}

buf.sort_unstable_by_key(|x| x.0);

for (&(i, x), (ind, v)) in
buf.iter().zip(indices.iter_mut().zip(data.iter_mut()))
for ((i, x), (ind, v)) in buf
.iter()
.cloned()
.zip(indices.iter_mut().zip(data.iter_mut()))
{
*ind = i;
*v = x;
Expand Down
148 changes: 98 additions & 50 deletions src/sparse/binop.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Sparse matrix addition, subtraction

use std::ops::{Add, Deref, Mul, Sub};

use crate::errors::StructureError;
use crate::indexing::SpIndex;
use crate::sparse::compressed::SpMatView;
Expand All @@ -15,34 +17,68 @@ use num_traits::Num;

use crate::Ix2;

/// Sparse matrix addition, with matrices sharing the same storage type
pub fn add_mat_same_storage<N, I, Iptr, Mat1, Mat2>(
lhs: &Mat1,
rhs: &Mat2,
) -> CsMatI<N, I, Iptr>
impl<'a, 'b, N, I, Iptr, IpStorage, IStorage, DStorage, IpS2, IS2, DS2>
Add<&'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>>
for &'a CsMatBase<N, I, IpStorage, IStorage, DStorage, Iptr>
where
N: Num + Copy,
I: SpIndex,
Iptr: SpIndex,
Mat1: SpMatView<N, I, Iptr>,
Mat2: SpMatView<N, I, Iptr>,
N: num_traits::Zero + PartialEq + Clone + Default,
for<'r> &'r N: Add<&'r N, Output = N>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
IpStorage: 'a + Deref<Target = [Iptr]>,
IStorage: 'a + Deref<Target = [I]>,
DStorage: 'a + Deref<Target = [N]>,
IpS2: 'a + Deref<Target = [Iptr]>,
IS2: 'a + Deref<Target = [I]>,
DS2: 'a + Deref<Target = [N]>,
{
csmat_binop(lhs.view(), rhs.view(), |&x, &y| x + y)
type Output = CsMatI<N, I, Iptr>;

fn add(
self,
rhs: &'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>,
) -> CsMatI<N, I, Iptr> {
if self.storage() != rhs.view().storage() {
return csmat_binop(
self.view(),
rhs.to_other_storage().view(),
|x, y| x + y,
);
}
csmat_binop(self.view(), rhs.view(), |x, y| x + y)
}
}

/// Sparse matrix subtraction, with same storage type
pub fn sub_mat_same_storage<N, I, Iptr, Mat1, Mat2>(
lhs: &Mat1,
rhs: &Mat2,
) -> CsMatI<N, I, Iptr>
impl<'a, 'b, N, I, Iptr, IpStorage, IStorage, DStorage, IpS2, IS2, DS2>
Sub<&'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>>
for &'a CsMatBase<N, I, IpStorage, IStorage, DStorage, Iptr>
where
N: Num + Copy,
I: SpIndex,
Iptr: SpIndex,
Mat1: SpMatView<N, I, Iptr>,
Mat2: SpMatView<N, I, Iptr>,
N: num_traits::Zero + PartialEq + Clone + Default,
for<'r> &'r N: Sub<&'r N, Output = N>,
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
IpStorage: 'a + Deref<Target = [Iptr]>,
IStorage: 'a + Deref<Target = [I]>,
DStorage: 'a + Deref<Target = [N]>,
IpS2: 'a + Deref<Target = [Iptr]>,
IS2: 'a + Deref<Target = [I]>,
DS2: 'a + Deref<Target = [N]>,
{
csmat_binop(lhs.view(), rhs.view(), |&x, &y| x - y)
type Output = CsMatI<N, I, Iptr>;

fn sub(
self,
rhs: &'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>,
) -> CsMatI<N, I, Iptr> {
if self.storage() != rhs.view().storage() {
return csmat_binop(
self.view(),
rhs.to_other_storage().view(),
|x, y| x - y,
);
}
csmat_binop(self.view(), rhs.view(), |x, y| x - y)
}
}

/// Sparse matrix scalar multiplication, with same storage type
Expand All @@ -51,35 +87,57 @@ pub fn mul_mat_same_storage<N, I, Iptr, Mat1, Mat2>(
rhs: &Mat2,
) -> CsMatI<N, I, Iptr>
where
N: Num + Copy,
N: num_traits::Zero + PartialEq + Clone,
for<'r> &'r N: std::ops::Mul<&'r N, Output = N>,
I: SpIndex,
Iptr: SpIndex,
Mat1: SpMatView<N, I, Iptr>,
Mat2: SpMatView<N, I, Iptr>,
{
csmat_binop(lhs.view(), rhs.view(), |&x, &y| x * y)
csmat_binop(lhs.view(), rhs.view(), |x, y| x * y)
}

/// Sparse matrix multiplication by a scalar
pub fn scalar_mul_mat<N, I, Iptr, Mat>(mat: &Mat, val: N) -> CsMatI<N, I, Iptr>
where
N: Num + Copy,
I: SpIndex,
Iptr: SpIndex,
Mat: SpMatView<N, I, Iptr>,
{
let mat = mat.view();
mat.map(|&x| x * val)
macro_rules! sparse_scalar_mul {
($scalar: ident) => {
impl<'a, I, Iptr, IpStorage, IStorage, DStorage> Mul<$scalar>
for &'a CsMatBase<$scalar, I, IpStorage, IStorage, DStorage, Iptr>
where
I: 'a + SpIndex,
Iptr: 'a + SpIndex,
IpStorage: 'a + Deref<Target = [Iptr]>,
IStorage: 'a + Deref<Target = [I]>,
DStorage: 'a + Deref<Target = [$scalar]>,
{
type Output = CsMatI<$scalar, I, Iptr>;

fn mul(self, rhs: $scalar) -> Self::Output {
self.map(|x| x * rhs)
}
}
};
}

sparse_scalar_mul!(u8);
sparse_scalar_mul!(i8);
sparse_scalar_mul!(u16);
sparse_scalar_mul!(i16);
sparse_scalar_mul!(u32);
sparse_scalar_mul!(i32);
sparse_scalar_mul!(u64);
sparse_scalar_mul!(i64);
sparse_scalar_mul!(isize);
sparse_scalar_mul!(usize);
sparse_scalar_mul!(f32);
sparse_scalar_mul!(f64);

/// Applies a binary operation to matching non-zero elements
/// of two sparse matrices. When e.g. only the `lhs` has a non-zero at a
/// given location, `0` is inferred for the non-zero value of the other matrix.
/// Both matrices should have the same storage.
///
/// Thus the behaviour is correct iff `binop(N::zero(), N::zero()) == N::zero()`
///
/// # Errors
/// # Panics
///
/// - on incompatible dimensions
/// - on incomatible storage
Expand All @@ -89,7 +147,7 @@ pub fn csmat_binop<N, I, Iptr, F>(
binop: F,
) -> CsMatI<N, I, Iptr>
where
N: Num + Clone,
N: num_traits::Zero + PartialEq + Clone,
I: SpIndex,
Iptr: SpIndex,
F: Fn(&N, &N) -> N,
Expand All @@ -107,10 +165,6 @@ where
let max_nnz = lhs.nnz() + rhs.nnz();
let mut out_indptr = vec![Iptr::zero(); lhs.outer_dims() + 1];
let mut out_indices = vec![I::zero(); max_nnz];

// Sadly the vec! macro requires Clone, but we don't want to force
// Clone on our consumers, so we have to use this workaround.
// This should compile to decent code however.
let mut out_data = vec![N::zero(); max_nnz];

let nnz = csmat_binop_same_storage_raw(
Expand Down Expand Up @@ -146,7 +200,7 @@ pub fn csmat_binop_same_storage_raw<N, I, Iptr, F>(
out_data: &mut [N],
) -> usize
where
N: Num,
N: num_traits::Zero + PartialEq,
I: SpIndex,
Iptr: SpIndex,
F: Fn(&N, &N) -> N,
Expand Down Expand Up @@ -391,11 +445,8 @@ mod test {
let a = mat1();
let b = mat2();

let c = super::add_mat_same_storage(&a, &b);
let c_true = mat1_plus_mat2();
assert_eq!(c, c_true);

let c = &a + &b;
let c_true = mat1_plus_mat2();
assert_eq!(c, c_true);

// test with CSR matrices having differ row patterns
Expand All @@ -416,11 +467,8 @@ mod test {
let a = mat1();
let b = mat2();

let c = super::sub_mat_same_storage(&a, &b);
let c_true = mat1_minus_mat2();
assert_eq!(c, c_true);

let c = &a - &b;
let c_true = mat1_minus_mat2();
assert_eq!(c, c_true);
}

Expand All @@ -439,7 +487,7 @@ mod test {
#[test]
fn test_smul() {
let a = mat1();
let c = super::scalar_mul_mat(&a, 2.);
let c = &a * 2.;
let c_true = mat1_times_2();
assert_eq!(c.indptr(), c_true.indptr());
assert_eq!(c.indices(), c_true.indices());
Expand Down
Loading

0 comments on commit d2f0da7

Please sign in to comment.