diff --git a/sprs-ldl/src/lib.rs b/sprs-ldl/src/lib.rs index e82e55db..175ac8d8 100644 --- a/sprs-ldl/src/lib.rs +++ b/sprs-ldl/src/lib.rs @@ -52,7 +52,6 @@ // Copyright, this License, and the Availability note are retained, // and a notice that the code was modified is included. use std::ops::Deref; -use std::ops::IndexMut; use num_traits::Num; @@ -61,6 +60,7 @@ use sprs::indexing::SpIndex; use sprs::linalg; use sprs::stack::DStack; use sprs::{is_symmetric, CsMatViewI, PermOwnedI, Permutation}; +use sprs::{DenseVector, DenseVectorMut}; use sprs::{FillInReduction, PermutationCheck, SymmetryCheck}; #[cfg(feature = "sprs_suitesparse_ldl")] @@ -380,18 +380,31 @@ impl LdlNumeric { } /// Solve the system A x = rhs - pub fn solve<'a, V>(&self, rhs: &V) -> Vec + /// + /// The type constraints look complicated, but they simply mean that + /// `rhs` should be interpretable as a dense vector, and we will return + /// a dense vector of a compatible type (but owned). + pub fn solve<'a, V>( + &self, + rhs: V, + ) -> <::Owned as DenseVector>::Owned where - N: 'a + Copy + Num, - V: Deref, + N: 'a + Copy + Num + std::ops::SubAssign + std::ops::DivAssign, + V: DenseVector, + ::Owned: DenseVectorMut + DenseVector, + for<'b> &'b ::Owned: DenseVector, + for<'b> &'b mut ::Owned: + DenseVectorMut + DenseVector, + <::Owned as DenseVector>::Owned: + DenseVectorMut + DenseVector, { - let mut x = &self.symbolic.perm * &rhs[..]; + let mut x = &self.symbolic.perm * rhs; let l = self.l(); ldl_lsolve(&l, &mut x); linalg::diag_solve(&self.diag, &mut x); ldl_ltsolve(&l, &mut x); let pinv = self.symbolic.perm.inv(); - &pinv * &x + &pinv * x } /// The diagonal factor D of the LDL^T decomposition @@ -579,34 +592,34 @@ where /// Triangular solve specialized on lower triangular matrices /// produced by ldlt (diagonal terms are omitted and assumed to be 1). -pub fn ldl_lsolve(l: &CsMatViewI, x: &mut V) +pub fn ldl_lsolve(l: &CsMatViewI, mut x: V) where - N: Clone + Copy + Num, + N: Clone + Copy + Num + std::ops::SubAssign, I: SpIndex, - V: IndexMut, + V: DenseVectorMut + DenseVector, { for (col_ind, vec) in l.outer_iterator().enumerate() { - let x_col = x[col_ind]; + let x_col = *x.index(col_ind); for (row_ind, &value) in vec.iter() { - x[row_ind] = x[row_ind] - value * x_col; + *x.index_mut(row_ind) -= value * x_col; } } } /// Triangular transposed solve specialized on lower triangular matrices /// produced by ldlt (diagonal terms are omitted and assumed to be 1). -pub fn ldl_ltsolve(l: &CsMatViewI, x: &mut V) +pub fn ldl_ltsolve(l: &CsMatViewI, mut x: V) where - N: Clone + Copy + Num, + N: Clone + Copy + Num + std::ops::SubAssign, I: SpIndex, - V: IndexMut, + V: DenseVectorMut + DenseVector, { for (outer_ind, vec) in l.outer_iterator().enumerate().rev() { - let mut x_outer = x[outer_ind]; + let mut x_outer = *x.index(outer_ind); for (inner_ind, &value) in vec.iter() { - x_outer = x_outer - value * x[inner_ind]; + x_outer -= value * *x.index(inner_ind); } - x[outer_ind] = x_outer; + *x.index_mut(outer_ind) = x_outer; } } @@ -838,15 +851,15 @@ mod test { vec![1., 2., 21., 6., 6., 2., 2., 8.], ); - let b = vec![9., 60., 18., 34.]; - let x0 = vec![1., 2., 3., 4.]; + let b = ndarray::arr1(&[9., 60., 18., 34.]); + let x0 = ndarray::arr1(&[1., 2., 3., 4.]); let ldlt = super::Ldl::new() .check_symmetry(super::SymmetryCheck::DontCheckSymmetry) .fill_in_reduction(super::FillInReduction::ReverseCuthillMcKee) .numeric(mat.view()) .unwrap(); - let x = ldlt.solve(&b); + let x = ldlt.solve(b.view()); assert_eq!(x, x0); } diff --git a/src/dense_vector.rs b/src/dense_vector.rs new file mode 100644 index 00000000..afa1fc54 --- /dev/null +++ b/src/dense_vector.rs @@ -0,0 +1,326 @@ +use crate::Ix1; +use ndarray::{self, ArrayBase}; +use num_traits::identities::Zero; + +/// A trait for types representing dense vectors, useful for expressing +/// algorithms such as sparse-dense dot product, or linear solves. +/// +/// This trait is sealed, and cannot be implemented outside of the `sprs` +/// crate. +pub trait DenseVector: seal::Sealed { + type Owned; + type Scalar; + + /// The dimension of the vector + fn dim(&self) -> usize; + + /// Random access to an element in the vector. + /// + /// # Panics + /// + /// If the index is out of bounds + fn index(&self, idx: usize) -> &Self::Scalar; + + /// Create an owned version of this dense vector type, filled with zeros + fn zeros(dim: usize) -> Self::Owned; + + /// Copies this vector into an owned version + fn to_owned(&self) -> Self::Owned; +} + +impl DenseVector for [N] { + type Owned = Vec; + type Scalar = N; + + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } +} + +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a [N] { + type Owned = Vec; + type Scalar = N; + + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } +} + +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a mut [N] { + type Owned = Vec; + type Scalar = N; + + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } +} + +impl DenseVector for Vec { + type Owned = Vec; + type Scalar = N; + + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } +} + +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a Vec { + type Owned = Vec; + type Scalar = N; + + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } +} + +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a mut Vec { + type Owned = Vec; + type Scalar = N; + + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } +} + +impl DenseVector for ArrayBase +where + S: ndarray::Data, + N: Zero + Clone, +{ + type Owned = ndarray::Array; + type Scalar = N; + + fn dim(&self) -> usize { + self.shape()[0] + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[[idx]] + } + + fn zeros(dim: usize) -> Self::Owned { + ndarray::Array::zeros(dim) + } + + fn to_owned(&self) -> Self::Owned { + self.to_owned() + } +} + +impl<'a, N, S> DenseVector for &'a ArrayBase +where + S: ndarray::Data, + N: 'a + Zero + Clone, +{ + type Owned = ndarray::Array; + type Scalar = N; + + fn dim(&self) -> usize { + self.shape()[0] + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[[idx]] + } + + fn zeros(dim: usize) -> Self::Owned { + ndarray::Array::zeros(dim) + } + + fn to_owned(&self) -> Self::Owned { + ArrayBase::to_owned(self) + } +} + +impl<'a, N, S> DenseVector for &'a mut ArrayBase +where + S: ndarray::Data, + N: 'a + Zero + Clone, +{ + type Owned = ndarray::Array; + type Scalar = N; + + fn dim(&self) -> usize { + self.shape()[0] + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[[idx]] + } + + fn zeros(dim: usize) -> Self::Owned { + ndarray::Array::zeros(dim) + } + + fn to_owned(&self) -> Self::Owned { + ArrayBase::to_owned(self) + } +} + +/// Trait for dense vectors that can be modified, useful for expressing +/// algorithms which compute a resulting dense vector, such as solvers. +/// +/// This trait is sealed, and cannot be implemented outside of the `sprs` +/// crate. +pub trait DenseVectorMut: DenseVector { + /// Random mutable access to an element in the vector. + /// + /// # Panics + /// + /// If the index is out of bounds + fn index_mut(&mut self, idx: usize) -> &mut Self::Scalar; +} + +impl<'a, N: 'a + Zero + Clone> DenseVectorMut for [N] { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + +impl<'a, N: 'a + Zero + Clone> DenseVectorMut for &'a mut [N] { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + +impl DenseVectorMut for Vec { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + +impl<'a, N: 'a + Zero + Clone> DenseVectorMut for &'a mut Vec { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + +impl DenseVectorMut for ArrayBase +where + S: ndarray::DataMut, + N: Zero + Clone, +{ + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[[idx]] + } +} + +impl<'a, N, S> DenseVectorMut for &'a mut ArrayBase +where + S: ndarray::DataMut, + N: 'a + Zero + Clone, +{ + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[[idx]] + } +} + +mod seal { + pub trait Sealed {} + + impl Sealed for [N] {} + impl<'a, N: 'a> Sealed for &'a [N] {} + impl<'a, N: 'a> Sealed for &'a mut [N] {} + impl Sealed for Vec {} + impl<'a, N: 'a> Sealed for &'a Vec {} + impl<'a, N: 'a> Sealed for &'a mut Vec {} + impl> Sealed + for ndarray::ArrayBase + { + } + impl<'a, N: 'a, S: ndarray::Data> Sealed + for &'a ndarray::ArrayBase + { + } + impl<'a, N: 'a, S: ndarray::Data> Sealed + for &'a mut ndarray::ArrayBase + { + } +} diff --git a/src/lib.rs b/src/lib.rs index d5e8de22..ddefaa0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,7 @@ assert_eq!(a, b.to_csc()); */ pub mod array_backend; +mod dense_vector; pub mod errors; pub mod indexing; #[cfg(not(miri))] @@ -100,6 +101,8 @@ pub use crate::sparse::{ TriMatIter, TriMatView, TriMatViewI, TriMatViewMut, TriMatViewMutI, }; +pub use crate::dense_vector::{DenseVector, DenseVectorMut}; + pub use crate::sparse::symmetric::is_symmetric; pub use crate::sparse::permutation::{ diff --git a/src/sparse/linalg.rs b/src/sparse/linalg.rs index 15908b8d..b26cc12a 100644 --- a/src/sparse/linalg.rs +++ b/src/sparse/linalg.rs @@ -1,9 +1,9 @@ +use crate::{DenseVector, DenseVectorMut}; ///! Sparse linear algebra ///! ///! This module contains solvers for sparse linear systems. Currently ///! there are solver for sparse triangular systems and symmetric systems. use num_traits::Num; -use std::iter::IntoIterator; pub mod etree; pub mod ordering; @@ -12,13 +12,15 @@ pub mod trisolve; pub use self::ordering::reverse_cuthill_mckee; /// Diagonal solve -pub fn diag_solve<'a, N, I1, I2>(diag: I1, x: I2) +pub fn diag_solve<'a, N, V1, V2>(diag: V1, mut x: V2) where - N: 'a + Copy + Num, - I1: IntoIterator, - I2: IntoIterator, + N: 'a + Copy + Num + std::ops::DivAssign, + V1: DenseVector, + V2: DenseVectorMut + DenseVector, { - for (xv, dv) in x.into_iter().zip(diag.into_iter()) { - *xv = *xv / *dv; + let n = x.dim(); + assert_eq!(diag.dim(), n); + for i in 0..n { + *x.index_mut(i) /= *diag.index(i); } } diff --git a/src/sparse/linalg/trisolve.rs b/src/sparse/linalg/trisolve.rs index ea7be741..093d5614 100644 --- a/src/sparse/linalg/trisolve.rs +++ b/src/sparse/linalg/trisolve.rs @@ -1,19 +1,18 @@ +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::errors::{LinalgError, SingularMatrixInfo}; use crate::indexing::SpIndex; -use crate::sparse::vec; use crate::sparse::CsMatViewI; use crate::sparse::CsVecViewI; use crate::stack::{self, DStack, StackVal}; use num_traits::Num; /// Sparse triangular solves -use std::ops::IndexMut; -fn check_solver_dimensions( +fn check_solver_dimensions( lower_tri_mat: &CsMatViewI, rhs: &V, ) where N: Copy + Num, - V: vec::VecDim, + V: DenseVector + ?Sized, I: SpIndex, Iptr: SpIndex, { @@ -33,17 +32,17 @@ fn check_solver_dimensions( /// /// This solve does not assume the input matrix to actually be /// triangular, instead it ignores the upper triangular part. -pub fn lsolve_csr_dense_rhs( +pub fn lsolve_csr_dense_rhs( lower_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&lower_tri_mat, rhs); + check_solver_dimensions(&lower_tri_mat, &rhs); if !lower_tri_mat.is_csr() { panic!("Storage mismatch"); } @@ -57,7 +56,7 @@ where for (row_ind, row) in lower_tri_mat.outer_iterator().enumerate() { let mut diag_val = N::zero(); - let mut x = rhs[row_ind]; + let mut x = *rhs.index(row_ind); for (col_ind, &val) in row.iter() { if col_ind == row_ind { diag_val = val; @@ -66,7 +65,7 @@ where if col_ind > row_ind { continue; } - x = x - val * rhs[col_ind]; + x -= val * *rhs.index(col_ind); } if diag_val == N::zero() { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -74,7 +73,7 @@ where reason: "diagonal element is 0", })); } - rhs[row_ind] = x / diag_val; + *rhs.index_mut(row_ind) = x / diag_val; } Ok(()) } @@ -89,17 +88,17 @@ where /// is the diagonal element (thus actual sorted lower triangular matrices work /// best). Otherwise, logarithmic search for the diagonal element /// has to be performed for each column. -pub fn lsolve_csc_dense_rhs( +pub fn lsolve_csc_dense_rhs( lower_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&lower_tri_mat, rhs); + check_solver_dimensions(&lower_tri_mat, &rhs); if !lower_tri_mat.is_csc() { panic!("Storage mismatch"); } @@ -113,18 +112,19 @@ where // L_1_1 x1 = b_1 - x0*l_1_0 for (col_ind, col) in lower_tri_mat.outer_iterator().enumerate() { - lspsolve_csc_process_col(col, col_ind, rhs)?; + lspsolve_csc_process_col(col, col_ind, &mut rhs)?; } Ok(()) } -fn lspsolve_csc_process_col( +fn lspsolve_csc_process_col( col: CsVecViewI, col_ind: usize, rhs: &mut V, ) -> Result<(), LinalgError> where - V: vec::VecDim + IndexMut, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, { if let Some(&diag_val) = col.get(col_ind) { @@ -134,15 +134,14 @@ where reason: "diagonal element is a numeric 0", })); } - let b = rhs[col_ind]; + let b = *rhs.index(col_ind); let x = b / diag_val; - rhs[col_ind] = x; + *rhs.index_mut(col_ind) = x; for (row_ind, &val) in col.iter() { if row_ind <= col_ind { continue; } - let b = rhs[row_ind]; - rhs[row_ind] = b - val * x; + *rhs.index_mut(row_ind) -= val * x; } } else { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -163,17 +162,17 @@ where /// is the diagonal element (thus actual sorted lower triangular matrices work /// best). Otherwise, logarithmic search for the diagonal element /// has to be performed for each column. -pub fn usolve_csc_dense_rhs( +pub fn usolve_csc_dense_rhs( upper_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&upper_tri_mat, rhs); + check_solver_dimensions(&upper_tri_mat, &rhs); if !upper_tri_mat.is_csc() { panic!("Storage mismatch"); } @@ -194,15 +193,14 @@ where reason: "diagonal element is a numeric 0", })); } - let b = rhs[col_ind]; + let b = *rhs.index(col_ind); let x = b / diag_val; - rhs[col_ind] = x; + *rhs.index_mut(col_ind) = x; for (row_ind, &val) in col.iter() { if row_ind >= col_ind { continue; } - let b = rhs[row_ind]; - rhs[row_ind] = b - val * x; + *rhs.index_mut(row_ind) -= val * x; } } else { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -222,17 +220,17 @@ where /// /// This solve does not assume the input matrix to actually be /// triangular, instead it ignores the upper triangular part. -pub fn usolve_csr_dense_rhs( +pub fn usolve_csr_dense_rhs( upper_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut + DenseVector, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&upper_tri_mat, rhs); + check_solver_dimensions(&upper_tri_mat, &rhs); if !upper_tri_mat.is_csr() { panic!("Storage mismatch"); } @@ -245,7 +243,7 @@ where // x0 = (b_0 - u_0_1^T.x_1) / u_0_0 for (row_ind, row) in upper_tri_mat.outer_iterator().enumerate().rev() { let mut diag_val = N::zero(); - let mut x = rhs[row_ind]; + let mut x = *rhs.index(row_ind); for (col_ind, &val) in row.iter() { if col_ind == row_ind { diag_val = val; @@ -254,7 +252,7 @@ where if col_ind < row_ind { continue; } - x = x - val * rhs[col_ind]; + x -= val * *rhs.index(col_ind); } if diag_val == N::zero() { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -262,7 +260,7 @@ where reason: "diagonal element is a numeric 0", })); } - rhs[row_ind] = x / diag_val; + *rhs.index_mut(row_ind) = x / diag_val; } Ok(()) } @@ -289,15 +287,16 @@ where /// * if dstack is not empty /// * if `w_workspace` is not of length n /// -pub fn lsolve_csc_sparse_rhs( +pub fn lsolve_csc_sparse_rhs( lower_tri_mat: CsMatViewI, rhs: CsVecViewI, dstack: &mut DStack>, - x_workspace: &mut [N], + mut x_workspace: V, visited: &mut [bool], ) -> Result<(), LinalgError> where - N: Copy + Num, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut + DenseVector, I: SpIndex, Iptr: SpIndex, { @@ -310,7 +309,7 @@ where dstack.is_left_empty() && dstack.is_right_empty(), "dstack should be empty" ); - assert!(x_workspace.len() == n, "x should be of len n"); + assert!(x_workspace.dim() == n, "x should be of len n"); // the solve works out the sparsity of the solution using depth first // search on the matrix's graph @@ -353,11 +352,11 @@ where } // solve for the non-zero values into dense workspace - rhs.scatter(x_workspace); + rhs.scatter(&mut x_workspace); for &ind in dstack.iter_right().map(stack::extract_stack_val) { println!("ind: {}", ind); let col = lower_tri_mat.outer_view(ind).expect("ind not in bounds"); - lspsolve_csc_process_col(col, ind, x_workspace)?; + lspsolve_csc_process_col(col, ind, &mut x_workspace)?; } Ok(()) } @@ -367,6 +366,7 @@ mod test { use crate::sparse::{CsMat, CsVec}; use crate::stack::{self, DStack}; + use ndarray::arr1; use std::collections::HashSet; #[test] @@ -380,11 +380,11 @@ mod test { vec![0, 1, 0, 2], vec![1, 2, 1, 1], ); - let b = vec![3, 2, 4]; + let b = arr1(&[3, 2, 4]); let mut x = b.clone(); - super::lsolve_csr_dense_rhs(l.view(), &mut x).unwrap(); - assert_eq!(x, vec![3, 1, 1]); + super::lsolve_csr_dense_rhs(l.view(), x.view_mut()).unwrap(); + assert_eq!(x, arr1(&[3, 1, 1])); } #[test] @@ -403,6 +403,10 @@ mod test { super::lsolve_csc_dense_rhs(l.view(), &mut x).unwrap(); assert_eq!(x, vec![3, 1, 1]); + + let x: &mut [i32] = &mut [3, 5, 3]; + super::lsolve_csc_dense_rhs(l.view(), &mut x[..]).unwrap(); + assert_eq!(x, &[3, 1, 1]); } #[test] diff --git a/src/sparse/permutation.rs b/src/sparse/permutation.rs index 56ebffbd..4383363a 100644 --- a/src/sparse/permutation.rs +++ b/src/sparse/permutation.rs @@ -3,6 +3,7 @@ /// Both the permutation matrices and its inverse are stored use std::ops::{Deref, Mul}; +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::indexing::SpIndex; use crate::sparse::{CompressedStorage, CsMatI, CsMatViewI}; @@ -251,21 +252,24 @@ where } } -impl<'a, 'b, N, I, IndStorage> Mul<&'a [N]> for &'b Permutation +impl<'b, V, I, IndStorage> Mul for &'b Permutation where IndStorage: 'b + Deref, - N: 'a + Copy, + V: DenseVector, + ::Owned: + DenseVectorMut + DenseVector::Scalar>, + ::Scalar: Clone, I: SpIndex, { - type Output = Vec; - fn mul(self, rhs: &'a [N]) -> Vec { - assert_eq!(self.dim, rhs.len()); - let mut res = rhs.to_vec(); + type Output = V::Owned; + fn mul(self, rhs: V) -> Self::Output { + assert_eq!(self.dim, rhs.dim()); + let mut res = rhs.to_owned(); match self.storage { Identity => res, FinitePerm { perm: ref p, .. } => { - for (pi, r) in p.iter().zip(res.iter_mut()) { - *r = rhs[pi.index_unchecked()]; + for (i, pi) in p.iter().enumerate() { + *res.index_mut(i) = rhs.index(pi.index_unchecked()).clone(); } res } @@ -273,6 +277,21 @@ where } } +impl Mul for Permutation +where + IndStorage: Deref, + V: DenseVector, + ::Owned: + DenseVectorMut + DenseVector::Scalar>, + ::Scalar: Clone, + I: SpIndex, +{ + type Output = V::Owned; + fn mul(self, rhs: V) -> Self::Output { + &self * rhs + } +} + /// Compute the square matrix resulting from the product P * A * P^T pub fn transform_mat_papt( mat: CsMatViewI, @@ -344,6 +363,10 @@ mod test { let y = &p * &x; assert_eq!(&y, &[2, 1, 3, 5, 4]); + + let x = ndarray::arr1(&[5, 1, 2, 3, 4]); + let y = p.view() * x.view(); + assert_eq!(y, ndarray::arr1(&[2, 1, 3, 5, 4])); } #[test] diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs index 12cbb82d..8dddcff7 100644 --- a/src/sparse/prod.rs +++ b/src/sparse/prod.rs @@ -1,8 +1,8 @@ +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::indexing::SpIndex; use crate::sparse::compressed::SpMatView; ///! Sparse matrix product use crate::sparse::prelude::*; -use crate::sparse::vec::DenseVector; use crate::Ix2; use ndarray::{ArrayView, ArrayViewMut, Axis}; use num_traits::Num; @@ -47,18 +47,19 @@ where /// Multiply a sparse CSC matrix with a dense vector and accumulate the result /// into another dense vector -pub fn mul_acc_mat_vec_csc( +pub fn mul_acc_mat_vec_csc( mat: CsMatViewI, in_vec: V, - res_vec: &mut [N], + mut res_vec: VRes, ) where - N: Num + Copy, + N: Num + Copy + std::ops::AddAssign, I: SpIndex, Iptr: SpIndex, - V: DenseVector, + V: DenseVector, + VRes: DenseVectorMut, { let mat = mat.view(); - if mat.cols() != in_vec.dim() || mat.rows() != res_vec.len() { + if mat.cols() != in_vec.dim() || mat.rows() != res_vec.dim() { panic!("Dimension mismatch"); } if !mat.is_csc() { @@ -69,24 +70,25 @@ pub fn mul_acc_mat_vec_csc( let multiplier = in_vec.index(col_ind); for (row_ind, &value) in vec.iter() { // TODO: unsafe access to value? needs bench - res_vec[row_ind] = res_vec[row_ind] + *multiplier * value; + *res_vec.index_mut(row_ind) += *multiplier * value; } } } /// Multiply a sparse CSR matrix with a dense vector and accumulate the result /// into another dense vector -pub fn mul_acc_mat_vec_csr( +pub fn mul_acc_mat_vec_csr( mat: CsMatViewI, in_vec: V, - res_vec: &mut [N], + mut res_vec: VRes, ) where - N: Num + Copy, + N: Num + Copy + std::ops::AddAssign, I: SpIndex, Iptr: SpIndex, - V: DenseVector, + V: DenseVector, + VRes: DenseVectorMut, { - if mat.cols() != in_vec.dim() || mat.rows() != res_vec.len() { + if mat.cols() != in_vec.dim() || mat.rows() != res_vec.dim() { panic!("Dimension mismatch"); } if !mat.is_csr() { @@ -94,13 +96,10 @@ pub fn mul_acc_mat_vec_csr( } for (row_ind, vec) in mat.outer_iterator().enumerate() { - // this unwrap is ok because we did the check before to ensure - // mat.row() == res_vec.len() and now the row_ind is within the - // range of [0, mat.row). So it should be safe. - let tv = res_vec.get_mut(row_ind).unwrap(); + let tv = res_vec.index_mut(row_ind); for (col_ind, &value) in vec.iter() { // TODO: unsafe access to value? needs bench - *tv = *tv + *in_vec.index(col_ind) * value; + *tv += *in_vec.index(col_ind) * value; } } } @@ -328,7 +327,7 @@ mod test { mat_dense2, }; use ndarray::linalg::Dot; - use ndarray::{arr2, s, Array, Array2, Dimension, ShapeBuilder}; + use ndarray::{arr1, arr2, s, Array, Array2, Dimension, ShapeBuilder}; #[test] fn test_csvec_dot_by_binary_search() { @@ -368,6 +367,31 @@ mod test { .all(|(x, y)| (*x - *y).abs() < epsilon)); } + #[test] + fn mul_csc_vec_ndarray() { + let indptr: &[usize] = &[0, 2, 4, 5, 6, 7]; + let indices: &[usize] = &[2, 3, 3, 4, 2, 1, 3]; + let data: &[f64] = &[ + 0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123, + 0.88132896, 0.72527863, + ]; + + let mat = CsMatView::new_csc((5, 5), indptr, indices, data); + let vector = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]); + let mut res_vec = Array::zeros(5); + mul_acc_mat_vec_csc(mat, vector, res_vec.view_mut()); + + let expected_output = + vec![0., 0.26439869, -0.01803924, 0.75120319, 0.11616419]; + + let epsilon = 1e-7; // TODO: get better values and increase precision + + assert!(res_vec + .iter() + .zip(expected_output.iter()) + .all(|(x, y)| (*x - *y).abs() < epsilon)); + } + #[test] fn mul_csr_vec() { let indptr: &[usize] = &[0, 3, 3, 5, 6, 7]; @@ -393,6 +417,31 @@ mod test { .all(|(x, y)| (*x - *y).abs() < epsilon)); } + #[test] + fn mul_csr_vec_ndarray() { + let indptr: &[usize] = &[0, 3, 3, 5, 6, 7]; + let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4]; + let data: &[f64] = &[ + 0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315, + 0.39244208, 0.57202407, + ]; + + let mat = CsMatView::new((5, 5), indptr, indices, data); + let vec = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]); + let mut res_vec = Array::zeros(5); + mul_acc_mat_vec_csr(mat, vec.view(), res_vec.view_mut()); + + let expected_output = + [0.22527496, 0., 0.17814121, 0.35319787, 0.51482166]; + + let epsilon = 1e-7; // TODO: get better values and increase precision + + assert!(res_vec + .iter() + .zip(expected_output.iter()) + .all(|(x, y)| (*x - *y).abs() < epsilon)); + } + #[test] fn mul_csr_csr() { let a = mat1(); diff --git a/src/sparse/vec.rs b/src/sparse/vec.rs index 7d472eb2..9ce80bcb 100644 --- a/src/sparse/vec.rs +++ b/src/sparse/vec.rs @@ -1,7 +1,7 @@ +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::sparse::to_dense::assign_vector_to_dense; use crate::Ix1; use ndarray::Array; -use ndarray::{self, ArrayBase}; use std::cmp; use std::collections::HashSet; use std::convert::AsRef; @@ -280,129 +280,48 @@ where } } -impl<'a, N: 'a> IntoSparseVecIter<'a, N> for &'a [N] { - type IterType = Enumerate>; - - fn dim(&self) -> usize { - self.len() - } - - fn into_sparse_vec_iter(self) -> Enumerate> { - self.iter().enumerate() - } - - fn is_dense(&self) -> bool { - true - } - - fn index(self, idx: usize) -> &'a N { - &self[idx] - } -} - -impl<'a, N: 'a> IntoSparseVecIter<'a, N> for &'a Vec { - type IterType = Enumerate>; - - fn dim(&self) -> usize { - self.len() - } - - fn into_sparse_vec_iter(self) -> Enumerate> { - self.iter().enumerate() - } - - fn is_dense(&self) -> bool { - true - } - - fn index(self, idx: usize) -> &'a N { - &self[idx] - } -} - -impl<'a, N: 'a, S> IntoSparseVecIter<'a, N> for &'a ArrayBase +impl<'a, N: 'a, V: ?Sized> IntoSparseVecIter<'a, N> for &'a V where - S: ndarray::Data, + V: DenseVector, { - type IterType = Enumerate>; - - fn dim(&self) -> usize { - self.shape()[0] + // FIXME we want + // type IterType = impl Iterator + #[allow(clippy::type_complexity)] + type IterType = std::iter::Map< + std::iter::Zip, std::ops::Range>, + fn((&'a V, usize)) -> (usize, &'a N), + >; + + #[inline(always)] + fn into_sparse_vec_iter(self) -> Self::IterType { + let n = DenseVector::dim(self); + // FIXME since it's not possible to have an existential type as an + // associated type yet, I'm using a trick to send the necessary + // context to a plain function, which enables specifying the type + // Needless to say, this needs to go when it's no longer necessary + #[inline(always)] + fn hack_instead_of_closure(vi: (&V, usize)) -> (usize, &N) + where + V: DenseVector, + { + (vi.1, vi.0.index(vi.1)) + } + std::iter::repeat(self) + .zip(0..n) + .map(hack_instead_of_closure) } - fn into_sparse_vec_iter( - self, - ) -> Enumerate> { - self.iter().enumerate() + fn dim(&self) -> usize { + DenseVector::dim(*self) } fn is_dense(&self) -> bool { true } + #[inline(always)] fn index(self, idx: usize) -> &'a N { - &self[[idx]] - } -} - -/// A trait for types representing dense vectors, useful for -/// defining a fast sparse-dense dot product. -pub trait DenseVector { - /// The dimension of the vector - fn dim(&self) -> usize; - - /// Random access to an element in the vector. - /// - /// # Panics - /// - /// If the index is out of bounds - fn index(&self, idx: usize) -> &N; -} - -impl<'a, N: 'a> DenseVector for &'a [N] { - fn dim(&self) -> usize { - self.len() - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[idx] - } -} - -impl DenseVector for Vec { - fn dim(&self) -> usize { - self.len() - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[idx] - } -} - -impl<'a, N: 'a> DenseVector for &'a Vec { - fn dim(&self) -> usize { - self.len() - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[idx] - } -} - -impl DenseVector for ArrayBase -where - S: ndarray::Data, -{ - fn dim(&self) -> usize { - self.shape()[0] - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[[idx]] + DenseVector::index(self, idx) } } @@ -957,7 +876,7 @@ where /// If the dimension of the vectors do not match. pub fn dot_dense(&self, rhs: T) -> N where - T: DenseVector, + T: DenseVector, N: Num + Copy + Sum, { assert_eq!(self.dim(), rhs.dim()); @@ -1019,12 +938,17 @@ where } /// Fill a dense vector with our values - pub fn scatter(&self, out: &mut [N]) + // FIXME I'm uneasy with this &mut V, can't I get rid of it with more + // trait magic? I would probably need to define what a mutable view is... + // But it's valuable. But I cannot find a way with the current trait system. + // Would probably require something link existential lifetimes. + pub fn scatter(&self, out: &mut V) where N: Clone, + V: DenseVectorMut + ?Sized, { for (ind, val) in self.iter() { - out[ind] = val.clone(); + *out.index_mut(ind) = val.clone(); } } @@ -1913,6 +1837,20 @@ mod test { assert_eq!(vector, CsVec::new(4, vec![1, 2, 3], vec![0_i32, 1, 2])); } + #[test] + fn scatter() { + let vector = CsVec::new(4, vec![1, 2, 3], vec![1_i32, 3, 4]); + let mut res = vec![0; 4]; + vector.scatter(&mut res); + assert_eq!(res, &[0, 1, 3, 4]); + let mut res = Array::zeros(4); + vector.scatter(&mut res); + assert_eq!(res, ndarray::arr1(&[0, 1, 3, 4])); + let res: &mut [i32] = &mut [0; 4]; + vector.scatter(res); + assert_eq!(res, &[0, 1, 3, 4]); + } + #[cfg(feature = "approx")] mod approx { use crate::*;