Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Support generic dense vectors in more APIs #272

Merged
merged 7 commits into from
Feb 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions sprs-ldl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
// Copyright, this License, and the Availability note are retained,
// and a notice that the code was modified is included.
use std::ops::Deref;
use std::ops::IndexMut;

use num_traits::Num;

Expand All @@ -61,6 +60,7 @@ use sprs::indexing::SpIndex;
use sprs::linalg;
use sprs::stack::DStack;
use sprs::{is_symmetric, CsMatViewI, PermOwnedI, Permutation};
use sprs::{DenseVector, DenseVectorMut};
use sprs::{FillInReduction, PermutationCheck, SymmetryCheck};

#[cfg(feature = "sprs_suitesparse_ldl")]
Expand Down Expand Up @@ -380,18 +380,31 @@ impl<N, I: SpIndex> LdlNumeric<N, I> {
}

/// Solve the system A x = rhs
pub fn solve<'a, V>(&self, rhs: &V) -> Vec<N>
///
/// The type constraints look complicated, but they simply mean that
/// `rhs` should be interpretable as a dense vector, and we will return
/// a dense vector of a compatible type (but owned).
pub fn solve<'a, V>(
&self,
rhs: V,
) -> <<V as DenseVector>::Owned as DenseVector>::Owned
where
N: 'a + Copy + Num,
V: Deref<Target = [N]>,
N: 'a + Copy + Num + std::ops::SubAssign + std::ops::DivAssign,
V: DenseVector<Scalar = N>,
<V as DenseVector>::Owned: DenseVectorMut + DenseVector<Scalar = N>,
for<'b> &'b <V as DenseVector>::Owned: DenseVector<Scalar = N>,
for<'b> &'b mut <V as DenseVector>::Owned:
DenseVectorMut + DenseVector<Scalar = N>,
<<V as DenseVector>::Owned as DenseVector>::Owned:
DenseVectorMut + DenseVector<Scalar = N>,
{
let mut x = &self.symbolic.perm * &rhs[..];
let mut x = &self.symbolic.perm * rhs;
let l = self.l();
ldl_lsolve(&l, &mut x);
linalg::diag_solve(&self.diag, &mut x);
ldl_ltsolve(&l, &mut x);
let pinv = self.symbolic.perm.inv();
&pinv * &x
&pinv * x
}

/// The diagonal factor D of the LDL^T decomposition
Expand Down Expand Up @@ -579,34 +592,34 @@ where

/// Triangular solve specialized on lower triangular matrices
/// produced by ldlt (diagonal terms are omitted and assumed to be 1).
pub fn ldl_lsolve<N, I, V: ?Sized>(l: &CsMatViewI<N, I>, x: &mut V)
pub fn ldl_lsolve<N, I, V>(l: &CsMatViewI<N, I>, mut x: V)
where
N: Clone + Copy + Num,
N: Clone + Copy + Num + std::ops::SubAssign,
I: SpIndex,
V: IndexMut<usize, Output = N>,
V: DenseVectorMut + DenseVector<Scalar = N>,
{
for (col_ind, vec) in l.outer_iterator().enumerate() {
let x_col = x[col_ind];
let x_col = *x.index(col_ind);
for (row_ind, &value) in vec.iter() {
x[row_ind] = x[row_ind] - value * x_col;
*x.index_mut(row_ind) -= value * x_col;
}
}
}

/// Triangular transposed solve specialized on lower triangular matrices
/// produced by ldlt (diagonal terms are omitted and assumed to be 1).
pub fn ldl_ltsolve<N, I, V: ?Sized>(l: &CsMatViewI<N, I>, x: &mut V)
pub fn ldl_ltsolve<N, I, V>(l: &CsMatViewI<N, I>, mut x: V)
where
N: Clone + Copy + Num,
N: Clone + Copy + Num + std::ops::SubAssign,
I: SpIndex,
V: IndexMut<usize, Output = N>,
V: DenseVectorMut + DenseVector<Scalar = N>,
{
for (outer_ind, vec) in l.outer_iterator().enumerate().rev() {
let mut x_outer = x[outer_ind];
let mut x_outer = *x.index(outer_ind);
for (inner_ind, &value) in vec.iter() {
x_outer = x_outer - value * x[inner_ind];
x_outer -= value * *x.index(inner_ind);
}
x[outer_ind] = x_outer;
*x.index_mut(outer_ind) = x_outer;
}
}

Expand Down Expand Up @@ -838,15 +851,15 @@ mod test {
vec![1., 2., 21., 6., 6., 2., 2., 8.],
);

let b = vec![9., 60., 18., 34.];
let x0 = vec![1., 2., 3., 4.];
let b = ndarray::arr1(&[9., 60., 18., 34.]);
let x0 = ndarray::arr1(&[1., 2., 3., 4.]);

let ldlt = super::Ldl::new()
.check_symmetry(super::SymmetryCheck::DontCheckSymmetry)
.fill_in_reduction(super::FillInReduction::ReverseCuthillMcKee)
.numeric(mat.view())
.unwrap();
let x = ldlt.solve(&b);
let x = ldlt.solve(b.view());
assert_eq!(x, x0);
}

Expand Down
Loading