From fc0a0c0d00285c3c1fa85e496b221b49d5fe6635 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Wed, 20 Jan 2021 23:30:14 +0100 Subject: [PATCH 1/7] Support generic dense vectors in more APIs Lots of APIs were using slices as their input, particularly when using output buffers. However, this was not a good fit for linear algebra, since consumers of the APIs are more likely to be using eg an array from ndarray. Here we extend the `DenseVector` trait with a mutable version to be able to use it on output parameters. Since these are sealed traits we should be able to add unsafe indexing if necessary without a breaking change. It should also be possible to support eg nalgebra in the future without too much trouble. This should improve the situation discussed in #93, though it's probably not done yet. As suggested by @mulimoen, the `index` and `index_mut` implementations hint as `#[inline(always)]` as we want them to be zero-cost abstractions, and it can be critical to have them inlined to allow the compiler to remove bounds checks when possible. --- src/dense_vector.rs | 125 ++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/sparse/linalg/trisolve.rs | 100 +++++++++++++-------------- src/sparse/prod.rs | 81 +++++++++++++++++----- src/sparse/vec.rs | 77 ++++----------------- 5 files changed, 255 insertions(+), 129 deletions(-) create mode 100644 src/dense_vector.rs diff --git a/src/dense_vector.rs b/src/dense_vector.rs new file mode 100644 index 00000000..51b49697 --- /dev/null +++ b/src/dense_vector.rs @@ -0,0 +1,125 @@ +use crate::Ix1; +use ndarray::{self, ArrayBase}; + +/// A trait for types representing dense vectors, useful for expressing +/// algorithms such as sparse-dense dot product, or linear solves. +pub trait DenseVector { + /// The dimension of the vector + fn dim(&self) -> usize; + + /// Random access to an element in the vector. + /// + /// # Panics + /// + /// If the index is out of bounds + fn index(&self, idx: usize) -> &N; +} + +impl<'a, N: 'a> DenseVector for &'a [N] { + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } +} + +impl<'a, N: 'a> DenseVector for &'a mut [N] { + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } +} + +impl DenseVector for Vec { + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } +} + +impl<'a, N: 'a> DenseVector for &'a Vec { + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } +} + +impl<'a, N: 'a> DenseVector for &'a mut Vec { + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } +} + +impl DenseVector for ArrayBase +where + S: ndarray::Data, +{ + fn dim(&self) -> usize { + self.shape()[0] + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[[idx]] + } +} + +pub trait DenseVectorMut: DenseVector { + /// Random mutable access to an element in the vector. + /// + /// # Panics + /// + /// If the index is out of bounds + fn index_mut(&mut self, idx: usize) -> &mut N; +} + +impl<'a, N: 'a> DenseVectorMut for &'a mut [N] { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + +impl DenseVectorMut for Vec { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + +impl<'a, N: 'a> DenseVectorMut for &'a mut Vec { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + +impl DenseVectorMut for ArrayBase +where + S: ndarray::DataMut, +{ + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[[idx]] + } +} diff --git a/src/lib.rs b/src/lib.rs index d5e8de22..3e42355c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,7 @@ assert_eq!(a, b.to_csc()); */ pub mod array_backend; +mod dense_vector; pub mod errors; pub mod indexing; #[cfg(not(miri))] diff --git a/src/sparse/linalg/trisolve.rs b/src/sparse/linalg/trisolve.rs index ea7be741..b188042f 100644 --- a/src/sparse/linalg/trisolve.rs +++ b/src/sparse/linalg/trisolve.rs @@ -1,19 +1,18 @@ +use crate::dense_vector::DenseVectorMut; use crate::errors::{LinalgError, SingularMatrixInfo}; use crate::indexing::SpIndex; -use crate::sparse::vec; use crate::sparse::CsMatViewI; use crate::sparse::CsVecViewI; use crate::stack::{self, DStack, StackVal}; use num_traits::Num; /// Sparse triangular solves -use std::ops::IndexMut; -fn check_solver_dimensions( +fn check_solver_dimensions( lower_tri_mat: &CsMatViewI, rhs: &V, ) where N: Copy + Num, - V: vec::VecDim, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { @@ -33,17 +32,17 @@ fn check_solver_dimensions( /// /// This solve does not assume the input matrix to actually be /// triangular, instead it ignores the upper triangular part. -pub fn lsolve_csr_dense_rhs( +pub fn lsolve_csr_dense_rhs( lower_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&lower_tri_mat, rhs); + check_solver_dimensions(&lower_tri_mat, &rhs); if !lower_tri_mat.is_csr() { panic!("Storage mismatch"); } @@ -57,7 +56,7 @@ where for (row_ind, row) in lower_tri_mat.outer_iterator().enumerate() { let mut diag_val = N::zero(); - let mut x = rhs[row_ind]; + let mut x = *rhs.index(row_ind); for (col_ind, &val) in row.iter() { if col_ind == row_ind { diag_val = val; @@ -66,7 +65,7 @@ where if col_ind > row_ind { continue; } - x = x - val * rhs[col_ind]; + x -= val * *rhs.index(col_ind); } if diag_val == N::zero() { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -74,7 +73,7 @@ where reason: "diagonal element is 0", })); } - rhs[row_ind] = x / diag_val; + *rhs.index_mut(row_ind) = x / diag_val; } Ok(()) } @@ -89,17 +88,17 @@ where /// is the diagonal element (thus actual sorted lower triangular matrices work /// best). Otherwise, logarithmic search for the diagonal element /// has to be performed for each column. -pub fn lsolve_csc_dense_rhs( +pub fn lsolve_csc_dense_rhs( lower_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&lower_tri_mat, rhs); + check_solver_dimensions(&lower_tri_mat, &rhs); if !lower_tri_mat.is_csc() { panic!("Storage mismatch"); } @@ -113,18 +112,19 @@ where // L_1_1 x1 = b_1 - x0*l_1_0 for (col_ind, col) in lower_tri_mat.outer_iterator().enumerate() { - lspsolve_csc_process_col(col, col_ind, rhs)?; + lspsolve_csc_process_col(col, col_ind, &mut rhs)?; } Ok(()) } -fn lspsolve_csc_process_col( +fn lspsolve_csc_process_col( col: CsVecViewI, col_ind: usize, rhs: &mut V, ) -> Result<(), LinalgError> where - V: vec::VecDim + IndexMut, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, { if let Some(&diag_val) = col.get(col_ind) { @@ -134,15 +134,14 @@ where reason: "diagonal element is a numeric 0", })); } - let b = rhs[col_ind]; + let b = *rhs.index(col_ind); let x = b / diag_val; - rhs[col_ind] = x; + *rhs.index_mut(col_ind) = x; for (row_ind, &val) in col.iter() { if row_ind <= col_ind { continue; } - let b = rhs[row_ind]; - rhs[row_ind] = b - val * x; + *rhs.index_mut(row_ind) -= val * x; } } else { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -163,17 +162,17 @@ where /// is the diagonal element (thus actual sorted lower triangular matrices work /// best). Otherwise, logarithmic search for the diagonal element /// has to be performed for each column. -pub fn usolve_csc_dense_rhs( +pub fn usolve_csc_dense_rhs( upper_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&upper_tri_mat, rhs); + check_solver_dimensions(&upper_tri_mat, &rhs); if !upper_tri_mat.is_csc() { panic!("Storage mismatch"); } @@ -194,15 +193,14 @@ where reason: "diagonal element is a numeric 0", })); } - let b = rhs[col_ind]; + let b = *rhs.index(col_ind); let x = b / diag_val; - rhs[col_ind] = x; + *rhs.index_mut(col_ind) = x; for (row_ind, &val) in col.iter() { if row_ind >= col_ind { continue; } - let b = rhs[row_ind]; - rhs[row_ind] = b - val * x; + *rhs.index_mut(row_ind) -= val * x; } } else { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -222,17 +220,17 @@ where /// /// This solve does not assume the input matrix to actually be /// triangular, instead it ignores the upper triangular part. -pub fn usolve_csr_dense_rhs( +pub fn usolve_csr_dense_rhs( upper_tri_mat: CsMatViewI, - rhs: &mut V, + mut rhs: V, ) -> Result<(), LinalgError> where - N: Copy + Num, - V: IndexMut + vec::VecDim, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { - check_solver_dimensions(&upper_tri_mat, rhs); + check_solver_dimensions(&upper_tri_mat, &rhs); if !upper_tri_mat.is_csr() { panic!("Storage mismatch"); } @@ -245,7 +243,7 @@ where // x0 = (b_0 - u_0_1^T.x_1) / u_0_0 for (row_ind, row) in upper_tri_mat.outer_iterator().enumerate().rev() { let mut diag_val = N::zero(); - let mut x = rhs[row_ind]; + let mut x = *rhs.index(row_ind); for (col_ind, &val) in row.iter() { if col_ind == row_ind { diag_val = val; @@ -254,7 +252,7 @@ where if col_ind < row_ind { continue; } - x = x - val * rhs[col_ind]; + x -= val * *rhs.index(col_ind); } if diag_val == N::zero() { return Err(LinalgError::SingularMatrix(SingularMatrixInfo { @@ -262,7 +260,7 @@ where reason: "diagonal element is a numeric 0", })); } - rhs[row_ind] = x / diag_val; + *rhs.index_mut(row_ind) = x / diag_val; } Ok(()) } @@ -289,15 +287,16 @@ where /// * if dstack is not empty /// * if `w_workspace` is not of length n /// -pub fn lsolve_csc_sparse_rhs( +pub fn lsolve_csc_sparse_rhs( lower_tri_mat: CsMatViewI, rhs: CsVecViewI, dstack: &mut DStack>, - x_workspace: &mut [N], + mut x_workspace: V, visited: &mut [bool], ) -> Result<(), LinalgError> where - N: Copy + Num, + N: Copy + Num + std::ops::SubAssign, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { @@ -310,7 +309,7 @@ where dstack.is_left_empty() && dstack.is_right_empty(), "dstack should be empty" ); - assert!(x_workspace.len() == n, "x should be of len n"); + assert!(x_workspace.dim() == n, "x should be of len n"); // the solve works out the sparsity of the solution using depth first // search on the matrix's graph @@ -353,11 +352,11 @@ where } // solve for the non-zero values into dense workspace - rhs.scatter(x_workspace); + rhs.scatter(&mut x_workspace); for &ind in dstack.iter_right().map(stack::extract_stack_val) { println!("ind: {}", ind); let col = lower_tri_mat.outer_view(ind).expect("ind not in bounds"); - lspsolve_csc_process_col(col, ind, x_workspace)?; + lspsolve_csc_process_col(col, ind, &mut x_workspace)?; } Ok(()) } @@ -367,6 +366,7 @@ mod test { use crate::sparse::{CsMat, CsVec}; use crate::stack::{self, DStack}; + use ndarray::arr1; use std::collections::HashSet; #[test] @@ -380,11 +380,11 @@ mod test { vec![0, 1, 0, 2], vec![1, 2, 1, 1], ); - let b = vec![3, 2, 4]; + let b = arr1(&[3, 2, 4]); let mut x = b.clone(); - super::lsolve_csr_dense_rhs(l.view(), &mut x).unwrap(); - assert_eq!(x, vec![3, 1, 1]); + super::lsolve_csr_dense_rhs(l.view(), x.view_mut()).unwrap(); + assert_eq!(x, arr1(&[3, 1, 1])); } #[test] diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs index 12cbb82d..409c557a 100644 --- a/src/sparse/prod.rs +++ b/src/sparse/prod.rs @@ -1,8 +1,8 @@ +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::indexing::SpIndex; use crate::sparse::compressed::SpMatView; ///! Sparse matrix product use crate::sparse::prelude::*; -use crate::sparse::vec::DenseVector; use crate::Ix2; use ndarray::{ArrayView, ArrayViewMut, Axis}; use num_traits::Num; @@ -47,18 +47,19 @@ where /// Multiply a sparse CSC matrix with a dense vector and accumulate the result /// into another dense vector -pub fn mul_acc_mat_vec_csc( +pub fn mul_acc_mat_vec_csc( mat: CsMatViewI, in_vec: V, - res_vec: &mut [N], + mut res_vec: VRes, ) where - N: Num + Copy, + N: Num + Copy + std::ops::AddAssign, I: SpIndex, Iptr: SpIndex, V: DenseVector, + VRes: DenseVectorMut, { let mat = mat.view(); - if mat.cols() != in_vec.dim() || mat.rows() != res_vec.len() { + if mat.cols() != in_vec.dim() || mat.rows() != res_vec.dim() { panic!("Dimension mismatch"); } if !mat.is_csc() { @@ -69,24 +70,25 @@ pub fn mul_acc_mat_vec_csc( let multiplier = in_vec.index(col_ind); for (row_ind, &value) in vec.iter() { // TODO: unsafe access to value? needs bench - res_vec[row_ind] = res_vec[row_ind] + *multiplier * value; + *res_vec.index_mut(row_ind) += *multiplier * value; } } } /// Multiply a sparse CSR matrix with a dense vector and accumulate the result /// into another dense vector -pub fn mul_acc_mat_vec_csr( +pub fn mul_acc_mat_vec_csr( mat: CsMatViewI, in_vec: V, - res_vec: &mut [N], + mut res_vec: VRes, ) where - N: Num + Copy, + N: Num + Copy + std::ops::AddAssign, I: SpIndex, Iptr: SpIndex, V: DenseVector, + VRes: DenseVectorMut, { - if mat.cols() != in_vec.dim() || mat.rows() != res_vec.len() { + if mat.cols() != in_vec.dim() || mat.rows() != res_vec.dim() { panic!("Dimension mismatch"); } if !mat.is_csr() { @@ -94,13 +96,10 @@ pub fn mul_acc_mat_vec_csr( } for (row_ind, vec) in mat.outer_iterator().enumerate() { - // this unwrap is ok because we did the check before to ensure - // mat.row() == res_vec.len() and now the row_ind is within the - // range of [0, mat.row). So it should be safe. - let tv = res_vec.get_mut(row_ind).unwrap(); + let tv = res_vec.index_mut(row_ind); for (col_ind, &value) in vec.iter() { // TODO: unsafe access to value? needs bench - *tv = *tv + *in_vec.index(col_ind) * value; + *tv += *in_vec.index(col_ind) * value; } } } @@ -328,7 +327,7 @@ mod test { mat_dense2, }; use ndarray::linalg::Dot; - use ndarray::{arr2, s, Array, Array2, Dimension, ShapeBuilder}; + use ndarray::{arr1, arr2, s, Array, Array2, Dimension, ShapeBuilder}; #[test] fn test_csvec_dot_by_binary_search() { @@ -368,6 +367,31 @@ mod test { .all(|(x, y)| (*x - *y).abs() < epsilon)); } + #[test] + fn mul_csc_vec_ndarray() { + let indptr: &[usize] = &[0, 2, 4, 5, 6, 7]; + let indices: &[usize] = &[2, 3, 3, 4, 2, 1, 3]; + let data: &[f64] = &[ + 0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123, + 0.88132896, 0.72527863, + ]; + + let mat = CsMatView::new_csc((5, 5), indptr, indices, data); + let vector = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]); + let mut res_vec = Array::zeros(5); + mul_acc_mat_vec_csc(mat, vector, res_vec.view_mut()); + + let expected_output = + vec![0., 0.26439869, -0.01803924, 0.75120319, 0.11616419]; + + let epsilon = 1e-7; // TODO: get better values and increase precision + + assert!(res_vec + .iter() + .zip(expected_output.iter()) + .all(|(x, y)| (*x - *y).abs() < epsilon)); + } + #[test] fn mul_csr_vec() { let indptr: &[usize] = &[0, 3, 3, 5, 6, 7]; @@ -393,6 +417,31 @@ mod test { .all(|(x, y)| (*x - *y).abs() < epsilon)); } + #[test] + fn mul_csr_vec_ndarray() { + let indptr: &[usize] = &[0, 3, 3, 5, 6, 7]; + let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4]; + let data: &[f64] = &[ + 0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315, + 0.39244208, 0.57202407, + ]; + + let mat = CsMatView::new((5, 5), indptr, indices, data); + let vec = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]); + let mut res_vec = Array::zeros(5); + mul_acc_mat_vec_csr(mat, vec.view(), res_vec.view_mut()); + + let expected_output = + [0.22527496, 0., 0.17814121, 0.35319787, 0.51482166]; + + let epsilon = 1e-7; // TODO: get better values and increase precision + + assert!(res_vec + .iter() + .zip(expected_output.iter()) + .all(|(x, y)| (*x - *y).abs() < epsilon)); + } + #[test] fn mul_csr_csr() { let a = mat1(); diff --git a/src/sparse/vec.rs b/src/sparse/vec.rs index 7d472eb2..7342fc64 100644 --- a/src/sparse/vec.rs +++ b/src/sparse/vec.rs @@ -1,3 +1,4 @@ +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::sparse::to_dense::assign_vector_to_dense; use crate::Ix1; use ndarray::Array; @@ -345,67 +346,6 @@ where } } -/// A trait for types representing dense vectors, useful for -/// defining a fast sparse-dense dot product. -pub trait DenseVector { - /// The dimension of the vector - fn dim(&self) -> usize; - - /// Random access to an element in the vector. - /// - /// # Panics - /// - /// If the index is out of bounds - fn index(&self, idx: usize) -> &N; -} - -impl<'a, N: 'a> DenseVector for &'a [N] { - fn dim(&self) -> usize { - self.len() - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[idx] - } -} - -impl DenseVector for Vec { - fn dim(&self) -> usize { - self.len() - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[idx] - } -} - -impl<'a, N: 'a> DenseVector for &'a Vec { - fn dim(&self) -> usize { - self.len() - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[idx] - } -} - -impl DenseVector for ArrayBase -where - S: ndarray::Data, -{ - fn dim(&self) -> usize { - self.shape()[0] - } - - #[inline] - fn index(&self, idx: usize) -> &N { - &self[[idx]] - } -} - /// An iterator over the non zeros of either of two vector iterators, ordered, /// such that the sum of the vectors may be computed pub struct NnzOrZip<'a, Ite1, Ite2, N1: 'a, N2: 'a> @@ -1019,12 +959,12 @@ where } /// Fill a dense vector with our values - pub fn scatter(&self, out: &mut [N]) + pub fn scatter>(&self, out: &mut V) where N: Clone, { for (ind, val) in self.iter() { - out[ind] = val.clone(); + *out.index_mut(ind) = val.clone(); } } @@ -1913,6 +1853,17 @@ mod test { assert_eq!(vector, CsVec::new(4, vec![1, 2, 3], vec![0_i32, 1, 2])); } + #[test] + fn scatter() { + let vector = CsVec::new(4, vec![1, 2, 3], vec![1_i32, 3, 4]); + let mut res = vec![0; 4]; + vector.scatter(&mut res); + assert_eq!(res, &[0, 1, 3, 4]); + let mut res = Array::zeros(4); + vector.scatter(&mut res); + assert_eq!(res, ndarray::arr1(&[0, 1, 3, 4])); + } + #[cfg(feature = "approx")] mod approx { use crate::*; From 2c542fae89124b1a3e6c2137e9d89c6ba9d37fac Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Wed, 27 Jan 2021 21:38:04 +0100 Subject: [PATCH 2/7] Implement DenseVectorMut on slices as well --- src/dense_vector.rs | 18 ++++++++++++++++++ src/sparse/linalg/trisolve.rs | 8 ++++++-- src/sparse/vec.rs | 6 +++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/dense_vector.rs b/src/dense_vector.rs index 51b49697..a9ff0a7d 100644 --- a/src/dense_vector.rs +++ b/src/dense_vector.rs @@ -15,6 +15,17 @@ pub trait DenseVector { fn index(&self, idx: usize) -> &N; } +impl DenseVector for [N] { + fn dim(&self) -> usize { + self.len() + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[idx] + } +} + impl<'a, N: 'a> DenseVector for &'a [N] { fn dim(&self) -> usize { self.len() @@ -93,6 +104,13 @@ pub trait DenseVectorMut: DenseVector { fn index_mut(&mut self, idx: usize) -> &mut N; } +impl<'a, N: 'a> DenseVectorMut for [N] { + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[idx] + } +} + impl<'a, N: 'a> DenseVectorMut for &'a mut [N] { #[inline(always)] fn index_mut(&mut self, idx: usize) -> &mut N { diff --git a/src/sparse/linalg/trisolve.rs b/src/sparse/linalg/trisolve.rs index b188042f..caaa399f 100644 --- a/src/sparse/linalg/trisolve.rs +++ b/src/sparse/linalg/trisolve.rs @@ -1,4 +1,4 @@ -use crate::dense_vector::DenseVectorMut; +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::errors::{LinalgError, SingularMatrixInfo}; use crate::indexing::SpIndex; use crate::sparse::CsMatViewI; @@ -12,7 +12,7 @@ fn check_solver_dimensions( rhs: &V, ) where N: Copy + Num, - V: DenseVectorMut, + V: DenseVector + ?Sized, I: SpIndex, Iptr: SpIndex, { @@ -403,6 +403,10 @@ mod test { super::lsolve_csc_dense_rhs(l.view(), &mut x).unwrap(); assert_eq!(x, vec![3, 1, 1]); + + let x: &mut [i32] = &mut [3, 5, 3]; + super::lsolve_csc_dense_rhs(l.view(), &mut x[..]).unwrap(); + assert_eq!(x, &[3, 1, 1]); } #[test] diff --git a/src/sparse/vec.rs b/src/sparse/vec.rs index 7342fc64..e93a37b3 100644 --- a/src/sparse/vec.rs +++ b/src/sparse/vec.rs @@ -959,9 +959,10 @@ where } /// Fill a dense vector with our values - pub fn scatter>(&self, out: &mut V) + pub fn scatter(&self, out: &mut V) where N: Clone, + V: DenseVectorMut + ?Sized, { for (ind, val) in self.iter() { *out.index_mut(ind) = val.clone(); @@ -1862,6 +1863,9 @@ mod test { let mut res = Array::zeros(4); vector.scatter(&mut res); assert_eq!(res, ndarray::arr1(&[0, 1, 3, 4])); + let res: &mut [i32] = &mut [0; 4]; + vector.scatter(res); + assert_eq!(res, &[0, 1, 3, 4]); } #[cfg(feature = "approx")] From 16959be867eb042cd72839e01c172e2af0d0551d Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Wed, 27 Jan 2021 22:04:13 +0100 Subject: [PATCH 3/7] Use associated types for DenseVector --- src/dense_vector.rs | 121 +++++++++++++++++++++++++++++----- src/sparse/linalg/trisolve.rs | 14 ++-- src/sparse/prod.rs | 8 +-- src/sparse/vec.rs | 4 +- 4 files changed, 118 insertions(+), 29 deletions(-) diff --git a/src/dense_vector.rs b/src/dense_vector.rs index a9ff0a7d..72ec0493 100644 --- a/src/dense_vector.rs +++ b/src/dense_vector.rs @@ -1,9 +1,13 @@ use crate::Ix1; use ndarray::{self, ArrayBase}; +use num_traits::identities::Zero; /// A trait for types representing dense vectors, useful for expressing /// algorithms such as sparse-dense dot product, or linear solves. -pub trait DenseVector { +pub trait DenseVector { + type Owned; + type Scalar; + /// The dimension of the vector fn dim(&self) -> usize; @@ -12,10 +16,19 @@ pub trait DenseVector { /// # Panics /// /// If the index is out of bounds - fn index(&self, idx: usize) -> &N; + fn index(&self, idx: usize) -> &Self::Scalar; + + /// Create an owned version of this dense vector type, filled with zeros + fn zeros(dim: usize) -> Self::Owned; + + /// Copies this vector into an owned version + fn to_owned(&self) -> Self::Owned; } -impl DenseVector for [N] { +impl DenseVector for [N] { + type Owned = Vec; + type Scalar = N; + fn dim(&self) -> usize { self.len() } @@ -24,9 +37,20 @@ impl DenseVector for [N] { fn index(&self, idx: usize) -> &N { &self[idx] } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } } -impl<'a, N: 'a> DenseVector for &'a [N] { +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a [N] { + type Owned = Vec; + type Scalar = N; + fn dim(&self) -> usize { self.len() } @@ -35,9 +59,20 @@ impl<'a, N: 'a> DenseVector for &'a [N] { fn index(&self, idx: usize) -> &N { &self[idx] } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } } -impl<'a, N: 'a> DenseVector for &'a mut [N] { +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a mut [N] { + type Owned = Vec; + type Scalar = N; + fn dim(&self) -> usize { self.len() } @@ -46,9 +81,20 @@ impl<'a, N: 'a> DenseVector for &'a mut [N] { fn index(&self, idx: usize) -> &N { &self[idx] } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } } -impl DenseVector for Vec { +impl DenseVector for Vec { + type Owned = Vec; + type Scalar = N; + fn dim(&self) -> usize { self.len() } @@ -57,9 +103,20 @@ impl DenseVector for Vec { fn index(&self, idx: usize) -> &N { &self[idx] } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } } -impl<'a, N: 'a> DenseVector for &'a Vec { +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a Vec { + type Owned = Vec; + type Scalar = N; + fn dim(&self) -> usize { self.len() } @@ -68,9 +125,20 @@ impl<'a, N: 'a> DenseVector for &'a Vec { fn index(&self, idx: usize) -> &N { &self[idx] } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } } -impl<'a, N: 'a> DenseVector for &'a mut Vec { +impl<'a, N: 'a + Zero + Clone> DenseVector for &'a mut Vec { + type Owned = Vec; + type Scalar = N; + fn dim(&self) -> usize { self.len() } @@ -79,12 +147,24 @@ impl<'a, N: 'a> DenseVector for &'a mut Vec { fn index(&self, idx: usize) -> &N { &self[idx] } + + fn zeros(dim: usize) -> Self::Owned { + vec![N::zero(); dim] + } + + fn to_owned(&self) -> Self::Owned { + self.to_vec() + } } -impl DenseVector for ArrayBase +impl DenseVector for ArrayBase where S: ndarray::Data, + N: Zero + Clone, { + type Owned = ndarray::Array; + type Scalar = N; + fn dim(&self) -> usize { self.shape()[0] } @@ -93,48 +173,57 @@ where fn index(&self, idx: usize) -> &N { &self[[idx]] } + + fn zeros(dim: usize) -> Self::Owned { + ndarray::Array::zeros(dim) + } + + fn to_owned(&self) -> Self::Owned { + self.to_owned() + } } -pub trait DenseVectorMut: DenseVector { +pub trait DenseVectorMut: DenseVector { /// Random mutable access to an element in the vector. /// /// # Panics /// /// If the index is out of bounds - fn index_mut(&mut self, idx: usize) -> &mut N; + fn index_mut(&mut self, idx: usize) -> &mut Self::Scalar; } -impl<'a, N: 'a> DenseVectorMut for [N] { +impl<'a, N: 'a + Zero + Clone> DenseVectorMut for [N] { #[inline(always)] fn index_mut(&mut self, idx: usize) -> &mut N { &mut self[idx] } } -impl<'a, N: 'a> DenseVectorMut for &'a mut [N] { +impl<'a, N: 'a + Zero + Clone> DenseVectorMut for &'a mut [N] { #[inline(always)] fn index_mut(&mut self, idx: usize) -> &mut N { &mut self[idx] } } -impl DenseVectorMut for Vec { +impl DenseVectorMut for Vec { #[inline(always)] fn index_mut(&mut self, idx: usize) -> &mut N { &mut self[idx] } } -impl<'a, N: 'a> DenseVectorMut for &'a mut Vec { +impl<'a, N: 'a + Zero + Clone> DenseVectorMut for &'a mut Vec { #[inline(always)] fn index_mut(&mut self, idx: usize) -> &mut N { &mut self[idx] } } -impl DenseVectorMut for ArrayBase +impl DenseVectorMut for ArrayBase where S: ndarray::DataMut, + N: Zero + Clone, { #[inline(always)] fn index_mut(&mut self, idx: usize) -> &mut N { diff --git a/src/sparse/linalg/trisolve.rs b/src/sparse/linalg/trisolve.rs index caaa399f..093d5614 100644 --- a/src/sparse/linalg/trisolve.rs +++ b/src/sparse/linalg/trisolve.rs @@ -12,7 +12,7 @@ fn check_solver_dimensions( rhs: &V, ) where N: Copy + Num, - V: DenseVector + ?Sized, + V: DenseVector + ?Sized, I: SpIndex, Iptr: SpIndex, { @@ -38,7 +38,7 @@ pub fn lsolve_csr_dense_rhs( ) -> Result<(), LinalgError> where N: Copy + Num + std::ops::SubAssign, - V: DenseVectorMut, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { @@ -94,7 +94,7 @@ pub fn lsolve_csc_dense_rhs( ) -> Result<(), LinalgError> where N: Copy + Num + std::ops::SubAssign, - V: DenseVectorMut, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { @@ -124,7 +124,7 @@ fn lspsolve_csc_process_col( ) -> Result<(), LinalgError> where N: Copy + Num + std::ops::SubAssign, - V: DenseVectorMut, + V: DenseVectorMut, I: SpIndex, { if let Some(&diag_val) = col.get(col_ind) { @@ -168,7 +168,7 @@ pub fn usolve_csc_dense_rhs( ) -> Result<(), LinalgError> where N: Copy + Num + std::ops::SubAssign, - V: DenseVectorMut, + V: DenseVectorMut, I: SpIndex, Iptr: SpIndex, { @@ -226,7 +226,7 @@ pub fn usolve_csr_dense_rhs( ) -> Result<(), LinalgError> where N: Copy + Num + std::ops::SubAssign, - V: DenseVectorMut, + V: DenseVectorMut + DenseVector, I: SpIndex, Iptr: SpIndex, { @@ -296,7 +296,7 @@ pub fn lsolve_csc_sparse_rhs( ) -> Result<(), LinalgError> where N: Copy + Num + std::ops::SubAssign, - V: DenseVectorMut, + V: DenseVectorMut + DenseVector, I: SpIndex, Iptr: SpIndex, { diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs index 409c557a..8dddcff7 100644 --- a/src/sparse/prod.rs +++ b/src/sparse/prod.rs @@ -55,8 +55,8 @@ pub fn mul_acc_mat_vec_csc( N: Num + Copy + std::ops::AddAssign, I: SpIndex, Iptr: SpIndex, - V: DenseVector, - VRes: DenseVectorMut, + V: DenseVector, + VRes: DenseVectorMut, { let mat = mat.view(); if mat.cols() != in_vec.dim() || mat.rows() != res_vec.dim() { @@ -85,8 +85,8 @@ pub fn mul_acc_mat_vec_csr( N: Num + Copy + std::ops::AddAssign, I: SpIndex, Iptr: SpIndex, - V: DenseVector, - VRes: DenseVectorMut, + V: DenseVector, + VRes: DenseVectorMut, { if mat.cols() != in_vec.dim() || mat.rows() != res_vec.dim() { panic!("Dimension mismatch"); diff --git a/src/sparse/vec.rs b/src/sparse/vec.rs index e93a37b3..fe765e3d 100644 --- a/src/sparse/vec.rs +++ b/src/sparse/vec.rs @@ -897,7 +897,7 @@ where /// If the dimension of the vectors do not match. pub fn dot_dense(&self, rhs: T) -> N where - T: DenseVector, + T: DenseVector, N: Num + Copy + Sum, { assert_eq!(self.dim(), rhs.dim()); @@ -962,7 +962,7 @@ where pub fn scatter(&self, out: &mut V) where N: Clone, - V: DenseVectorMut + ?Sized, + V: DenseVectorMut + ?Sized, { for (ind, val) in self.iter() { *out.index_mut(ind) = val.clone(); From d76db9961b26943729f0417073743056e8513259 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Wed, 27 Jan 2021 22:16:21 +0100 Subject: [PATCH 4/7] Leverage DenseVector in permutation * dense product Now it's possible to call on ndarray data as well, with a single impl. --- src/sparse/permutation.rs | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/sparse/permutation.rs b/src/sparse/permutation.rs index 56ebffbd..4383363a 100644 --- a/src/sparse/permutation.rs +++ b/src/sparse/permutation.rs @@ -3,6 +3,7 @@ /// Both the permutation matrices and its inverse are stored use std::ops::{Deref, Mul}; +use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::indexing::SpIndex; use crate::sparse::{CompressedStorage, CsMatI, CsMatViewI}; @@ -251,21 +252,24 @@ where } } -impl<'a, 'b, N, I, IndStorage> Mul<&'a [N]> for &'b Permutation +impl<'b, V, I, IndStorage> Mul for &'b Permutation where IndStorage: 'b + Deref, - N: 'a + Copy, + V: DenseVector, + ::Owned: + DenseVectorMut + DenseVector::Scalar>, + ::Scalar: Clone, I: SpIndex, { - type Output = Vec; - fn mul(self, rhs: &'a [N]) -> Vec { - assert_eq!(self.dim, rhs.len()); - let mut res = rhs.to_vec(); + type Output = V::Owned; + fn mul(self, rhs: V) -> Self::Output { + assert_eq!(self.dim, rhs.dim()); + let mut res = rhs.to_owned(); match self.storage { Identity => res, FinitePerm { perm: ref p, .. } => { - for (pi, r) in p.iter().zip(res.iter_mut()) { - *r = rhs[pi.index_unchecked()]; + for (i, pi) in p.iter().enumerate() { + *res.index_mut(i) = rhs.index(pi.index_unchecked()).clone(); } res } @@ -273,6 +277,21 @@ where } } +impl Mul for Permutation +where + IndStorage: Deref, + V: DenseVector, + ::Owned: + DenseVectorMut + DenseVector::Scalar>, + ::Scalar: Clone, + I: SpIndex, +{ + type Output = V::Owned; + fn mul(self, rhs: V) -> Self::Output { + &self * rhs + } +} + /// Compute the square matrix resulting from the product P * A * P^T pub fn transform_mat_papt( mat: CsMatViewI, @@ -344,6 +363,10 @@ mod test { let y = &p * &x; assert_eq!(&y, &[2, 1, 3, 5, 4]); + + let x = ndarray::arr1(&[5, 1, 2, 3, 4]); + let y = p.view() * x.view(); + assert_eq!(y, ndarray::arr1(&[2, 1, 3, 5, 4])); } #[test] From 97e13433f3a89225891b9ad3f98f2216df7243e3 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Wed, 3 Feb 2021 22:21:15 +0100 Subject: [PATCH 5/7] Expose the DenseVector trait publicly (but sealed) This way dependent crates will be able to express algorithms in terms of that trait. It will remain sealed initially, as we want to gain experience before committing on its API. --- src/dense_vector.rs | 25 ++++++++++++++++++++++++- src/lib.rs | 2 ++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/dense_vector.rs b/src/dense_vector.rs index 72ec0493..723e166f 100644 --- a/src/dense_vector.rs +++ b/src/dense_vector.rs @@ -4,7 +4,10 @@ use num_traits::identities::Zero; /// A trait for types representing dense vectors, useful for expressing /// algorithms such as sparse-dense dot product, or linear solves. -pub trait DenseVector { +/// +/// This trait is sealed, and cannot be implemented outside of the `sprs` +/// crate. +pub trait DenseVector: seal::Sealed { type Owned; type Scalar; @@ -183,6 +186,11 @@ where } } +/// Trait for dense vectors that can be modified, useful for expressing +/// algorithms which compute a resulting dense vector, such as solvers. +/// +/// This trait is sealed, and cannot be implemented outside of the `sprs` +/// crate. pub trait DenseVectorMut: DenseVector { /// Random mutable access to an element in the vector. /// @@ -230,3 +238,18 @@ where &mut self[[idx]] } } + +mod seal { + pub trait Sealed {} + + impl Sealed for [N] {} + impl<'a, N: 'a> Sealed for &'a [N] {} + impl<'a, N: 'a> Sealed for &'a mut [N] {} + impl Sealed for Vec {} + impl<'a, N: 'a> Sealed for &'a Vec {} + impl<'a, N: 'a> Sealed for &'a mut Vec {} + impl> Sealed + for ndarray::ArrayBase + { + } +} diff --git a/src/lib.rs b/src/lib.rs index 3e42355c..ddefaa0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,6 +101,8 @@ pub use crate::sparse::{ TriMatIter, TriMatView, TriMatViewI, TriMatViewMut, TriMatViewMutI, }; +pub use crate::dense_vector::{DenseVector, DenseVectorMut}; + pub use crate::sparse::symmetric::is_symmetric; pub use crate::sparse::permutation::{ From a06efbf82b87762d6d646d5240a37175a611dcd9 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Wed, 3 Feb 2021 22:54:50 +0100 Subject: [PATCH 6/7] Use DenseVector in sprs-ldl --- sprs-ldl/src/lib.rs | 53 ++++++++++++++++++++------------- src/dense_vector.rs | 71 ++++++++++++++++++++++++++++++++++++++++++++ src/sparse/linalg.rs | 16 +++++----- 3 files changed, 113 insertions(+), 27 deletions(-) diff --git a/sprs-ldl/src/lib.rs b/sprs-ldl/src/lib.rs index e82e55db..175ac8d8 100644 --- a/sprs-ldl/src/lib.rs +++ b/sprs-ldl/src/lib.rs @@ -52,7 +52,6 @@ // Copyright, this License, and the Availability note are retained, // and a notice that the code was modified is included. use std::ops::Deref; -use std::ops::IndexMut; use num_traits::Num; @@ -61,6 +60,7 @@ use sprs::indexing::SpIndex; use sprs::linalg; use sprs::stack::DStack; use sprs::{is_symmetric, CsMatViewI, PermOwnedI, Permutation}; +use sprs::{DenseVector, DenseVectorMut}; use sprs::{FillInReduction, PermutationCheck, SymmetryCheck}; #[cfg(feature = "sprs_suitesparse_ldl")] @@ -380,18 +380,31 @@ impl LdlNumeric { } /// Solve the system A x = rhs - pub fn solve<'a, V>(&self, rhs: &V) -> Vec + /// + /// The type constraints look complicated, but they simply mean that + /// `rhs` should be interpretable as a dense vector, and we will return + /// a dense vector of a compatible type (but owned). + pub fn solve<'a, V>( + &self, + rhs: V, + ) -> <::Owned as DenseVector>::Owned where - N: 'a + Copy + Num, - V: Deref, + N: 'a + Copy + Num + std::ops::SubAssign + std::ops::DivAssign, + V: DenseVector, + ::Owned: DenseVectorMut + DenseVector, + for<'b> &'b ::Owned: DenseVector, + for<'b> &'b mut ::Owned: + DenseVectorMut + DenseVector, + <::Owned as DenseVector>::Owned: + DenseVectorMut + DenseVector, { - let mut x = &self.symbolic.perm * &rhs[..]; + let mut x = &self.symbolic.perm * rhs; let l = self.l(); ldl_lsolve(&l, &mut x); linalg::diag_solve(&self.diag, &mut x); ldl_ltsolve(&l, &mut x); let pinv = self.symbolic.perm.inv(); - &pinv * &x + &pinv * x } /// The diagonal factor D of the LDL^T decomposition @@ -579,34 +592,34 @@ where /// Triangular solve specialized on lower triangular matrices /// produced by ldlt (diagonal terms are omitted and assumed to be 1). -pub fn ldl_lsolve(l: &CsMatViewI, x: &mut V) +pub fn ldl_lsolve(l: &CsMatViewI, mut x: V) where - N: Clone + Copy + Num, + N: Clone + Copy + Num + std::ops::SubAssign, I: SpIndex, - V: IndexMut, + V: DenseVectorMut + DenseVector, { for (col_ind, vec) in l.outer_iterator().enumerate() { - let x_col = x[col_ind]; + let x_col = *x.index(col_ind); for (row_ind, &value) in vec.iter() { - x[row_ind] = x[row_ind] - value * x_col; + *x.index_mut(row_ind) -= value * x_col; } } } /// Triangular transposed solve specialized on lower triangular matrices /// produced by ldlt (diagonal terms are omitted and assumed to be 1). -pub fn ldl_ltsolve(l: &CsMatViewI, x: &mut V) +pub fn ldl_ltsolve(l: &CsMatViewI, mut x: V) where - N: Clone + Copy + Num, + N: Clone + Copy + Num + std::ops::SubAssign, I: SpIndex, - V: IndexMut, + V: DenseVectorMut + DenseVector, { for (outer_ind, vec) in l.outer_iterator().enumerate().rev() { - let mut x_outer = x[outer_ind]; + let mut x_outer = *x.index(outer_ind); for (inner_ind, &value) in vec.iter() { - x_outer = x_outer - value * x[inner_ind]; + x_outer -= value * *x.index(inner_ind); } - x[outer_ind] = x_outer; + *x.index_mut(outer_ind) = x_outer; } } @@ -838,15 +851,15 @@ mod test { vec![1., 2., 21., 6., 6., 2., 2., 8.], ); - let b = vec![9., 60., 18., 34.]; - let x0 = vec![1., 2., 3., 4.]; + let b = ndarray::arr1(&[9., 60., 18., 34.]); + let x0 = ndarray::arr1(&[1., 2., 3., 4.]); let ldlt = super::Ldl::new() .check_symmetry(super::SymmetryCheck::DontCheckSymmetry) .fill_in_reduction(super::FillInReduction::ReverseCuthillMcKee) .numeric(mat.view()) .unwrap(); - let x = ldlt.solve(&b); + let x = ldlt.solve(b.view()); assert_eq!(x, x0); } diff --git a/src/dense_vector.rs b/src/dense_vector.rs index 723e166f..afa1fc54 100644 --- a/src/dense_vector.rs +++ b/src/dense_vector.rs @@ -186,6 +186,58 @@ where } } +impl<'a, N, S> DenseVector for &'a ArrayBase +where + S: ndarray::Data, + N: 'a + Zero + Clone, +{ + type Owned = ndarray::Array; + type Scalar = N; + + fn dim(&self) -> usize { + self.shape()[0] + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[[idx]] + } + + fn zeros(dim: usize) -> Self::Owned { + ndarray::Array::zeros(dim) + } + + fn to_owned(&self) -> Self::Owned { + ArrayBase::to_owned(self) + } +} + +impl<'a, N, S> DenseVector for &'a mut ArrayBase +where + S: ndarray::Data, + N: 'a + Zero + Clone, +{ + type Owned = ndarray::Array; + type Scalar = N; + + fn dim(&self) -> usize { + self.shape()[0] + } + + #[inline(always)] + fn index(&self, idx: usize) -> &N { + &self[[idx]] + } + + fn zeros(dim: usize) -> Self::Owned { + ndarray::Array::zeros(dim) + } + + fn to_owned(&self) -> Self::Owned { + ArrayBase::to_owned(self) + } +} + /// Trait for dense vectors that can be modified, useful for expressing /// algorithms which compute a resulting dense vector, such as solvers. /// @@ -239,6 +291,17 @@ where } } +impl<'a, N, S> DenseVectorMut for &'a mut ArrayBase +where + S: ndarray::DataMut, + N: 'a + Zero + Clone, +{ + #[inline(always)] + fn index_mut(&mut self, idx: usize) -> &mut N { + &mut self[[idx]] + } +} + mod seal { pub trait Sealed {} @@ -252,4 +315,12 @@ mod seal { for ndarray::ArrayBase { } + impl<'a, N: 'a, S: ndarray::Data> Sealed + for &'a ndarray::ArrayBase + { + } + impl<'a, N: 'a, S: ndarray::Data> Sealed + for &'a mut ndarray::ArrayBase + { + } } diff --git a/src/sparse/linalg.rs b/src/sparse/linalg.rs index 15908b8d..b26cc12a 100644 --- a/src/sparse/linalg.rs +++ b/src/sparse/linalg.rs @@ -1,9 +1,9 @@ +use crate::{DenseVector, DenseVectorMut}; ///! Sparse linear algebra ///! ///! This module contains solvers for sparse linear systems. Currently ///! there are solver for sparse triangular systems and symmetric systems. use num_traits::Num; -use std::iter::IntoIterator; pub mod etree; pub mod ordering; @@ -12,13 +12,15 @@ pub mod trisolve; pub use self::ordering::reverse_cuthill_mckee; /// Diagonal solve -pub fn diag_solve<'a, N, I1, I2>(diag: I1, x: I2) +pub fn diag_solve<'a, N, V1, V2>(diag: V1, mut x: V2) where - N: 'a + Copy + Num, - I1: IntoIterator, - I2: IntoIterator, + N: 'a + Copy + Num + std::ops::DivAssign, + V1: DenseVector, + V2: DenseVectorMut + DenseVector, { - for (xv, dv) in x.into_iter().zip(diag.into_iter()) { - *xv = *xv / *dv; + let n = x.dim(); + assert_eq!(diag.dim(), n); + for i in 0..n { + *x.index_mut(i) /= *diag.index(i); } } From 3e1520336358c6350a8222d83aa08e4079b26119 Mon Sep 17 00:00:00 2001 From: Vincent Barrielle Date: Thu, 11 Feb 2021 23:06:32 +0100 Subject: [PATCH 7/7] Implement IntoSparseVecIter for V: DenseVector This removes some code duplication, and benchmarks show this does not affect performance. Also pointed to a place where I'm not satisfied with the current API, but where I don't see how to improve with the current rust's trait system. --- src/sparse/vec.rs | 87 +++++++++++++++++++---------------------------- 1 file changed, 35 insertions(+), 52 deletions(-) diff --git a/src/sparse/vec.rs b/src/sparse/vec.rs index fe765e3d..9ce80bcb 100644 --- a/src/sparse/vec.rs +++ b/src/sparse/vec.rs @@ -2,7 +2,6 @@ use crate::dense_vector::{DenseVector, DenseVectorMut}; use crate::sparse::to_dense::assign_vector_to_dense; use crate::Ix1; use ndarray::Array; -use ndarray::{self, ArrayBase}; use std::cmp; use std::collections::HashSet; use std::convert::AsRef; @@ -281,68 +280,48 @@ where } } -impl<'a, N: 'a> IntoSparseVecIter<'a, N> for &'a [N] { - type IterType = Enumerate>; - - fn dim(&self) -> usize { - self.len() - } - - fn into_sparse_vec_iter(self) -> Enumerate> { - self.iter().enumerate() - } - - fn is_dense(&self) -> bool { - true - } - - fn index(self, idx: usize) -> &'a N { - &self[idx] - } -} - -impl<'a, N: 'a> IntoSparseVecIter<'a, N> for &'a Vec { - type IterType = Enumerate>; - - fn dim(&self) -> usize { - self.len() - } - - fn into_sparse_vec_iter(self) -> Enumerate> { - self.iter().enumerate() - } - - fn is_dense(&self) -> bool { - true - } - - fn index(self, idx: usize) -> &'a N { - &self[idx] - } -} - -impl<'a, N: 'a, S> IntoSparseVecIter<'a, N> for &'a ArrayBase +impl<'a, N: 'a, V: ?Sized> IntoSparseVecIter<'a, N> for &'a V where - S: ndarray::Data, + V: DenseVector, { - type IterType = Enumerate>; - - fn dim(&self) -> usize { - self.shape()[0] + // FIXME we want + // type IterType = impl Iterator + #[allow(clippy::type_complexity)] + type IterType = std::iter::Map< + std::iter::Zip, std::ops::Range>, + fn((&'a V, usize)) -> (usize, &'a N), + >; + + #[inline(always)] + fn into_sparse_vec_iter(self) -> Self::IterType { + let n = DenseVector::dim(self); + // FIXME since it's not possible to have an existential type as an + // associated type yet, I'm using a trick to send the necessary + // context to a plain function, which enables specifying the type + // Needless to say, this needs to go when it's no longer necessary + #[inline(always)] + fn hack_instead_of_closure(vi: (&V, usize)) -> (usize, &N) + where + V: DenseVector, + { + (vi.1, vi.0.index(vi.1)) + } + std::iter::repeat(self) + .zip(0..n) + .map(hack_instead_of_closure) } - fn into_sparse_vec_iter( - self, - ) -> Enumerate> { - self.iter().enumerate() + fn dim(&self) -> usize { + DenseVector::dim(*self) } fn is_dense(&self) -> bool { true } + #[inline(always)] fn index(self, idx: usize) -> &'a N { - &self[[idx]] + DenseVector::index(self, idx) } } @@ -959,6 +938,10 @@ where } /// Fill a dense vector with our values + // FIXME I'm uneasy with this &mut V, can't I get rid of it with more + // trait magic? I would probably need to define what a mutable view is... + // But it's valuable. But I cannot find a way with the current trait system. + // Would probably require something link existential lifetimes. pub fn scatter(&self, out: &mut V) where N: Clone,