Skip to content

Commit

Permalink
Merge pull request #272 from vbarrielle/dense_vector
Browse files Browse the repository at this point in the history
Draft: Support generic dense vectors in more APIs
  • Loading branch information
vbarrielle authored Feb 15, 2021
2 parents b5a8f66 + 3e15203 commit e857e1d
Show file tree
Hide file tree
Showing 8 changed files with 577 additions and 219 deletions.
53 changes: 33 additions & 20 deletions sprs-ldl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
// Copyright, this License, and the Availability note are retained,
// and a notice that the code was modified is included.
use std::ops::Deref;
use std::ops::IndexMut;

use num_traits::Num;

Expand All @@ -61,6 +60,7 @@ use sprs::indexing::SpIndex;
use sprs::linalg;
use sprs::stack::DStack;
use sprs::{is_symmetric, CsMatViewI, PermOwnedI, Permutation};
use sprs::{DenseVector, DenseVectorMut};
use sprs::{FillInReduction, PermutationCheck, SymmetryCheck};

#[cfg(feature = "sprs_suitesparse_ldl")]
Expand Down Expand Up @@ -380,18 +380,31 @@ impl<N, I: SpIndex> LdlNumeric<N, I> {
}

/// Solve the system A x = rhs
pub fn solve<'a, V>(&self, rhs: &V) -> Vec<N>
///
/// The type constraints look complicated, but they simply mean that
/// `rhs` should be interpretable as a dense vector, and we will return
/// a dense vector of a compatible type (but owned).
pub fn solve<'a, V>(
&self,
rhs: V,
) -> <<V as DenseVector>::Owned as DenseVector>::Owned
where
N: 'a + Copy + Num,
V: Deref<Target = [N]>,
N: 'a + Copy + Num + std::ops::SubAssign + std::ops::DivAssign,
V: DenseVector<Scalar = N>,
<V as DenseVector>::Owned: DenseVectorMut + DenseVector<Scalar = N>,
for<'b> &'b <V as DenseVector>::Owned: DenseVector<Scalar = N>,
for<'b> &'b mut <V as DenseVector>::Owned:
DenseVectorMut + DenseVector<Scalar = N>,
<<V as DenseVector>::Owned as DenseVector>::Owned:
DenseVectorMut + DenseVector<Scalar = N>,
{
let mut x = &self.symbolic.perm * &rhs[..];
let mut x = &self.symbolic.perm * rhs;
let l = self.l();
ldl_lsolve(&l, &mut x);
linalg::diag_solve(&self.diag, &mut x);
ldl_ltsolve(&l, &mut x);
let pinv = self.symbolic.perm.inv();
&pinv * &x
&pinv * x
}

/// The diagonal factor D of the LDL^T decomposition
Expand Down Expand Up @@ -579,34 +592,34 @@ where

/// Triangular solve specialized on lower triangular matrices
/// produced by ldlt (diagonal terms are omitted and assumed to be 1).
pub fn ldl_lsolve<N, I, V: ?Sized>(l: &CsMatViewI<N, I>, x: &mut V)
pub fn ldl_lsolve<N, I, V>(l: &CsMatViewI<N, I>, mut x: V)
where
N: Clone + Copy + Num,
N: Clone + Copy + Num + std::ops::SubAssign,
I: SpIndex,
V: IndexMut<usize, Output = N>,
V: DenseVectorMut + DenseVector<Scalar = N>,
{
for (col_ind, vec) in l.outer_iterator().enumerate() {
let x_col = x[col_ind];
let x_col = *x.index(col_ind);
for (row_ind, &value) in vec.iter() {
x[row_ind] = x[row_ind] - value * x_col;
*x.index_mut(row_ind) -= value * x_col;
}
}
}

/// Triangular transposed solve specialized on lower triangular matrices
/// produced by ldlt (diagonal terms are omitted and assumed to be 1).
pub fn ldl_ltsolve<N, I, V: ?Sized>(l: &CsMatViewI<N, I>, x: &mut V)
pub fn ldl_ltsolve<N, I, V>(l: &CsMatViewI<N, I>, mut x: V)
where
N: Clone + Copy + Num,
N: Clone + Copy + Num + std::ops::SubAssign,
I: SpIndex,
V: IndexMut<usize, Output = N>,
V: DenseVectorMut + DenseVector<Scalar = N>,
{
for (outer_ind, vec) in l.outer_iterator().enumerate().rev() {
let mut x_outer = x[outer_ind];
let mut x_outer = *x.index(outer_ind);
for (inner_ind, &value) in vec.iter() {
x_outer = x_outer - value * x[inner_ind];
x_outer -= value * *x.index(inner_ind);
}
x[outer_ind] = x_outer;
*x.index_mut(outer_ind) = x_outer;
}
}

Expand Down Expand Up @@ -838,15 +851,15 @@ mod test {
vec![1., 2., 21., 6., 6., 2., 2., 8.],
);

let b = vec![9., 60., 18., 34.];
let x0 = vec![1., 2., 3., 4.];
let b = ndarray::arr1(&[9., 60., 18., 34.]);
let x0 = ndarray::arr1(&[1., 2., 3., 4.]);

let ldlt = super::Ldl::new()
.check_symmetry(super::SymmetryCheck::DontCheckSymmetry)
.fill_in_reduction(super::FillInReduction::ReverseCuthillMcKee)
.numeric(mat.view())
.unwrap();
let x = ldlt.solve(&b);
let x = ldlt.solve(b.view());
assert_eq!(x, x0);
}

Expand Down
Loading

0 comments on commit e857e1d

Please sign in to comment.