Skip to content

Commit

Permalink
Linalg traits (#54)
Browse files Browse the repository at this point in the history
* WIP: Documentation

* WIP: New documentation

* Fixed cargo format error

* WIP: Lapack Traits

* Added traits for Lapack Operations

* New doc (#53)

* WIP: Documentation

* WIP: New documentation

* Fixed cargo format error
  • Loading branch information
tbetcke authored Dec 18, 2023
1 parent 19f7610 commit 9793283
Show file tree
Hide file tree
Showing 7 changed files with 480 additions and 171 deletions.
23 changes: 23 additions & 0 deletions dense/src/linalg.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
//! Linear algebra routines
use crate::array::Array;
use crate::traits::*;
use rlst_common::types::Scalar;

use self::{
inverse::MatrixInverse, lu::MatrixLuDecomposition, pseudo_inverse::MatrixPseudoInverse,
qr::MatrixQrDecomposition, svd::MatrixSvd,
};
pub mod inverse;
pub mod lu;
pub mod pseudo_inverse;
Expand All @@ -13,3 +22,17 @@ pub fn assert_lapack_stride(stride: [usize; 2]) {
stride[0]
);
}

/// Marker trait for Arrays that provide
pub trait Linalg {}

impl<Item: Scalar, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Shape<2>> Linalg
for Array<Item, ArrayImpl, 2>
where
Array<Item, ArrayImpl, 2>: MatrixInverse
+ MatrixLuDecomposition
+ MatrixPseudoInverse
+ MatrixQrDecomposition
+ MatrixSvd,
{
}
41 changes: 23 additions & 18 deletions dense/src/linalg/inverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,37 @@ use rlst_common::types::{c32, c64, RlstError, RlstResult, Scalar};

use super::assert_lapack_stride;

/// Compute the matrix inverse.
///
/// The matrix inverse is defined for a two dimensional square array `arr` of
/// shape `[m, m]`.
///
/// # Example
///
/// The following command computes the inverse of an array `a`. The content
/// of `a` is replaced by the inverse.
/// ```
/// # use rlst_dense::rlst_dynamic_array2;
/// # use rlst_dense::linalg::inverse::MatrixInverse;
/// # let mut a = rlst_dynamic_array2!(f64, [3, 3]);
/// # a.fill_from_seed_equally_distributed(0);
/// a.view_mut().into_inverse_alloc().unwrap();
/// ```
/// This method allocates memory for the inverse computation.
pub trait MatrixInverse {
fn into_inverse_alloc(self) -> RlstResult<()>;
}

macro_rules! impl_inverse {
($scalar:ty, $getrf: expr, $getri:expr) => {
impl<
ArrayImpl: UnsafeRandomAccessByValue<2, Item = $scalar>
+ Stride<2>
+ Shape<2>
+ RawAccessMut<Item = $scalar>,
> Array<$scalar, ArrayImpl, 2>
> MatrixInverse for Array<$scalar, ArrayImpl, 2>
{
//! Compute the matrix inverse.
//!
//! The matrix inverse is defined for a two dimensional square array `arr` of
//! shape `[m, m]`.
//!
//! # Example
//!
//! The following command computes the inverse of an array `a`. The content
//! of `a` is replaced by the inverse.
//! ```
//! # use rlst_dense::rlst_dynamic_array2;
//! # let mut a = rlst_dynamic_array2!(f64, [3, 3]);
//! # a.fill_from_seed_equally_distributed(0);
//! a.view_mut().into_inverse_alloc().unwrap();
//! ```
//! This method allocates memory for the inverse computation.
pub fn into_inverse_alloc(mut self) -> RlstResult<()> {
fn into_inverse_alloc(mut self) -> RlstResult<()> {
assert_lapack_stride(self.stride());

let m = self.shape()[0] as i32;
Expand Down
222 changes: 176 additions & 46 deletions dense/src/linalg/lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,165 @@ use lapack::{cgetrf, cgetrs, dgetrf, dgetrs, sgetrf, sgetrs, zgetrf, zgetrs};
use num::One;
use rlst_common::types::*;

pub trait MatrixLuDecomposition {}

impl<
Item: Scalar,
ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item>
+ Shape<2>
+ Stride<2>
+ RawAccessMut<Item = Item>,
> MatrixLuDecomposition for Array<Item, ArrayImpl, 2>
where
Self: IntoLu<Item = Item, ArrayImpl = ArrayImpl>,
LuDecomposition<Item, ArrayImpl>: LuOperations<Item = Item, ArrayImpl = ArrayImpl>,
{
}

/// Compute the LU decomposition of a matrix.
///
/// The LU Decomposition of an `(m,n)` matrix `A` is defined
/// by `A = PLU`, where `P` is an `(m, m)` permutation matrix,
/// `L` is a `(m, k)` unit lower triangular matrix, and `U` is
/// an `(k, n)` upper triangular matrix.
pub trait IntoLu {
type Item: Scalar;
type ArrayImpl: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Stride<2>
+ RawAccessMut<Item = Self::Item>
+ Shape<2>;

fn into_lu(self) -> RlstResult<LuDecomposition<Self::Item, Self::ArrayImpl>>;
}

pub trait LuOperations: Sized {
type Item: Scalar;
type ArrayImpl: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Stride<2>
+ RawAccessMut<Item = Self::Item>
+ Shape<2>;

/// Create a new LU Decomposition from a given array.
fn new(arr: Array<Self::Item, Self::ArrayImpl, 2>) -> RlstResult<Self>;

/// Solve a linear system with a single right-hand side.
///
/// The right-hand side is overwritten with the solution.
fn solve_vec<
ArrayImplMut: RawAccessMut<Item = Self::Item>
+ UnsafeRandomAccessByValue<1, Item = Self::Item>
+ Shape<1>
+ Stride<1>,
>(
&self,
trans: LuTrans,
rhs: Array<Self::Item, ArrayImplMut, 1>,
) -> RlstResult<()>;

/// Solve a linear system with multiple right-hand sides.
///
/// The right-hand sides are overwritten with the solution.
fn solve_mat<
ArrayImplMut: RawAccessMut<Item = Self::Item>
+ UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Shape<2>
+ Stride<2>,
>(
&self,
trans: LuTrans,
rhs: Array<Self::Item, ArrayImplMut, 2>,
) -> RlstResult<()>;

/// Get the L matrix of the LU Decomposition.
///
/// This method resizes the input `arr` as required.
fn get_l_resize<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = Self::Item>
+ UnsafeRandomAccessByRef<2, Item = Self::Item>
+ ResizeInPlace<2>,
>(
&self,
arr: Array<Self::Item, ArrayImplMut, 2>,
);

/// Get the L matrix of the LU decomposition.
///
/// If A has the dimension `(m, n)` then the L matrix
/// has the dimension `(m, k)` with `k = min(m, n)`.
fn get_l<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = Self::Item>
+ UnsafeRandomAccessByRef<2, Item = Self::Item>,
>(
&self,
arr: Array<Self::Item, ArrayImplMut, 2>,
);

/// Get the R matrix of the LU Decomposition.
///
/// This method resizes the input `arr` as required.
fn get_u_resize<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = Self::Item>
+ UnsafeRandomAccessByRef<2, Item = Self::Item>
+ ResizeInPlace<2>,
>(
&self,
arr: Array<Self::Item, ArrayImplMut, 2>,
);

/// Get the R matrix of the LU Decomposition.
///
/// If A has the dimension `(m, n)` then the L matrix
/// has the dimension `(k, n)` with `k = min(m, n)`.
fn get_u<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = Self::Item>
+ UnsafeRandomAccessByRef<2, Item = Self::Item>,
>(
&self,
arr: Array<Self::Item, ArrayImplMut, 2>,
);

/// Get the P matrix of the LU Decomposition.
///
/// This method resizes the input `arr` as required.
fn get_p_resize<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = Self::Item>
+ UnsafeRandomAccessByRef<2, Item = Self::Item>
+ ResizeInPlace<2>,
>(
&self,
arr: Array<Self::Item, ArrayImplMut, 2>,
);

/// Get the permutation vector from the LU decomposition.
///
/// If `perm[i] = j` then the ith row of `LU` corresponds to the jth row of `A`.
fn get_perm(&self) -> Vec<usize>;

/// Get the P matrix of the LU Decomposition.
///
/// If A has the dimension `(m, n)` then the P matrix
/// has the dimension (m, m).
fn get_p<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = Self::Item>
+ UnsafeRandomAccessByRef<2, Item = Self::Item>,
>(
&self,
arr: Array<Self::Item, ArrayImplMut, 2>,
);
}

/// Transposition modes for solving linear systems via LU decomposition.
pub enum LuTrans {
/// Transpose.
Expand All @@ -32,10 +191,12 @@ macro_rules! impl_lu {
+ Stride<2>
+ Shape<2>
+ RawAccessMut<Item = $scalar>,
> LuDecomposition<$scalar, ArrayImpl>
> LuOperations for LuDecomposition<$scalar, ArrayImpl>
{
/// Create a new LU Decomposition from a given array.
pub fn new(mut arr: Array<$scalar, ArrayImpl, 2>) -> RlstResult<Self> {
type Item = $scalar;
type ArrayImpl = ArrayImpl;

fn new(mut arr: Array<$scalar, ArrayImpl, 2>) -> RlstResult<Self> {
let shape = arr.shape();
let stride = arr.stride();

Expand Down Expand Up @@ -64,10 +225,7 @@ macro_rules! impl_lu {
}
}

/// Solve a linear system with a single right-hand side.
///
/// The right-hand side is overwritten with the solution.
pub fn solve_vec<
fn solve_vec<
ArrayImplMut: RawAccessMut<Item = $scalar>
+ UnsafeRandomAccessByValue<1, Item = $scalar>
+ Shape<1>
Expand All @@ -83,10 +241,7 @@ macro_rules! impl_lu {
)
}

/// Solve a linear system with multiple right-hand sides.
///
/// The right-hand sides are overwritten with the solution.
pub fn solve_mat<
fn solve_mat<
ArrayImplMut: RawAccessMut<Item = $scalar>
+ UnsafeRandomAccessByValue<2, Item = $scalar>
+ Shape<2>
Expand Down Expand Up @@ -138,10 +293,7 @@ macro_rules! impl_lu {
}
}

/// Get the L matrix of the LU Decomposition.
///
/// This method resizes the input `arr` as required.
pub fn get_l_resize<
fn get_l_resize<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = $scalar>
Expand All @@ -159,11 +311,7 @@ macro_rules! impl_lu {
self.get_l(arr);
}

/// Get the L matrix of the LU decomposition.
///
/// If A has the dimension `(m, n)` then the L matrix
/// has the dimension `(m, k)` with `k = min(m, n)`.
pub fn get_l<
fn get_l<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = $scalar>
Expand Down Expand Up @@ -197,10 +345,7 @@ macro_rules! impl_lu {
}
}

/// Get the R matrix of the LU Decomposition.
///
/// This method resizes the input `arr` as required.
pub fn get_u_resize<
fn get_u_resize<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = $scalar>
Expand All @@ -218,11 +363,7 @@ macro_rules! impl_lu {
self.get_u(arr);
}

/// Get the R matrix of the LU Decomposition.
///
/// If A has the dimension `(m, n)` then the L matrix
/// has the dimension `(k, n)` with `k = min(m, n)`.
pub fn get_u<
fn get_u<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = $scalar>
Expand Down Expand Up @@ -252,10 +393,7 @@ macro_rules! impl_lu {
}
}

/// Get the P matrix of the LU Decomposition.
///
/// This method resizes the input `arr` as required.
pub fn get_p_resize<
fn get_p_resize<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = $scalar>
Expand Down Expand Up @@ -287,11 +425,7 @@ macro_rules! impl_lu {
perm
}

/// Get the P matrix of the LU Decomposition.
///
/// If A has the dimension `(m, n)` then the P matrix
/// has the dimension (m, m).
pub fn get_p<
fn get_p<
ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar>
+ Shape<2>
+ UnsafeRandomAccessMut<2, Item = $scalar>
Expand Down Expand Up @@ -325,15 +459,11 @@ macro_rules! impl_lu {
+ Stride<2>
+ RawAccessMut<Item = $scalar>
+ Shape<2>,
> Array<$scalar, ArrayImpl, 2>
> IntoLu for Array<$scalar, ArrayImpl, 2>
{
/// Compute the LU decomposition of a matrix.
///
/// The LU Decomposition of an `(m,n)` matrix `A` is defined
/// by `A = PLU`, where `P` is an `(m, m)` permutation matrix,
/// `L` is a `(m, k)` unit lower triangular matrix, and `U` is
/// an `(k, n)` upper triangular matrix.
pub fn into_lu(self) -> RlstResult<LuDecomposition<$scalar, ArrayImpl>> {
type Item = $scalar;
type ArrayImpl = ArrayImpl;
fn into_lu(self) -> RlstResult<LuDecomposition<$scalar, ArrayImpl>> {
assert!(!self.is_empty(), "Matrix is empty.");
LuDecomposition::<$scalar, ArrayImpl>::new(self)
}
Expand Down
Loading

0 comments on commit 9793283

Please sign in to comment.