diff --git a/dense/Cargo.toml b/dense/Cargo.toml index 0f3d6e6a..66b61af1 100644 --- a/dense/Cargo.toml +++ b/dense/Cargo.toml @@ -27,11 +27,11 @@ rlst-blis = { path = "../blis" } approx = { version = "0.5", features = ["num-complex"] } rlst-operator = { path = "../operator" } rlst-common = { path = "../common" } -rlst-lapack = { path = "../lapack" } paste = "1" rand_chacha = "0.3" rlst-blis-src = { path = "../blis-src" } rlst-netlib-lapack-src = { path = "../netlib-lapack-src" } +lapack = "0.19.*" [dev-dependencies] criterion = { version = "0.3", features = ["html_reports"] } diff --git a/dense/src/linalg.rs b/dense/src/linalg.rs index 78fa6c4c..c3099d46 100644 --- a/dense/src/linalg.rs +++ b/dense/src/linalg.rs @@ -1,5 +1,6 @@ //! Linear algebra routines pub mod lu; +pub mod qr; pub fn assert_lapack_stride(stride: [usize; 2]) { assert_eq!( @@ -8,3 +9,9 @@ pub fn assert_lapack_stride(stride: [usize; 2]) { stride[0] ); } + +pub enum Trans { + Trans, + NoTrans, + ConjTrans, +} diff --git a/dense/src/linalg/lu.rs b/dense/src/linalg/lu.rs index 145fab7a..e11aa9a9 100644 --- a/dense/src/linalg/lu.rs +++ b/dense/src/linalg/lu.rs @@ -1,266 +1,286 @@ //! LU Decomposition and linear system solves use super::assert_lapack_stride; +use super::Trans; use crate::array::Array; +use lapack::{dgetrf, dgetrs}; use num::One; use rlst_common::traits::*; use rlst_common::types::*; -use rlst_lapack::Lapack; -use rlst_lapack::Trans; -use rlst_lapack::{Getrf, Getrs}; pub struct LuDecomposition< - Item: Scalar + Lapack, + Item: Scalar, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Stride<2> + Shape<2> + RawAccessMut, > { arr: Array, ipiv: Vec, } -impl< - Item: Scalar + Lapack, - ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + Shape<2> - + RawAccessMut, - > LuDecomposition -{ - pub fn new(mut arr: Array) -> RlstResult { - let shape = arr.shape(); - let stride = arr.stride(); - - assert_lapack_stride(stride); - - let dim = std::cmp::min(shape[0], shape[1]); - let mut ipiv = vec![0; dim]; - let info = ::getrf( - shape[0] as i32, - shape[1] as i32, - arr.data_mut(), - stride[1] as i32, - ipiv.as_mut_slice(), - ); - - match info { - 0 => Ok(Self { arr, ipiv }), - _ => Err(RlstError::LapackError(info)), - } - } +macro_rules! impl_lu { + ($scalar:ty, $getrf:expr, $getrs:expr) => { + impl< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = $scalar> + + Stride<2> + + Shape<2> + + RawAccessMut, + > LuDecomposition<$scalar, ArrayImpl> + { + pub fn new(mut arr: Array<$scalar, ArrayImpl, 2>) -> RlstResult { + let shape = arr.shape(); + let stride = arr.stride(); + + assert_lapack_stride(stride); + + let dim = std::cmp::min(shape[0], shape[1]); + if dim == 0 { + return Err(RlstError::MatrixIsEmpty((shape[0], shape[1]))); + } + let mut ipiv = vec![0; dim]; + let mut info = 0; + unsafe { + $getrf( + shape[0] as i32, + shape[1] as i32, + arr.data_mut(), + stride[1] as i32, + &mut ipiv, + &mut info, + ); + } - pub fn solve_into< - ArrayImplMut: RawAccessMut - + UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2>, - >( - &self, - trans: Trans, - mut rhs: Array, - ) -> RlstResult> { - assert_eq!(self.arr.shape()[0], self.arr.shape()[1]); - let n = self.arr.shape()[0]; - assert_eq!(rhs.shape()[0], n); - - let nrhs = rhs.shape()[1]; - - let arr_stride = self.arr.stride(); - let rhs_stride = rhs.stride(); - - let lda = self.arr.stride()[1]; - let ldb = rhs.stride()[1]; - - assert_lapack_stride(arr_stride); - assert_lapack_stride(rhs_stride); - - let info = ::getrs( - trans, - n as i32, - nrhs as i32, - self.arr.data(), - lda as i32, - self.ipiv.as_slice(), - rhs.data_mut(), - ldb as i32, - ); - - match info { - 0 => Ok(rhs), - _ => Err(RlstError::LapackError(info)), - } - } + match info { + 0 => Ok(Self { arr, ipiv }), + _ => Err(RlstError::LapackError(info)), + } + } - pub fn get_l_resize< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + ResizeInPlace<2>, - >( - &self, - mut arr: Array, - ) { - let m = self.arr.shape()[0]; - let n = self.arr.shape()[1]; - let k = std::cmp::min(m, n); - - arr.resize_in_place([m, k]); - self.get_l(arr); - } + pub fn solve< + ArrayImplMut: RawAccessMut + + UnsafeRandomAccessByValue<2, Item = $scalar> + + Shape<2> + + Stride<2>, + >( + &self, + trans: Trans, + mut rhs: Array<$scalar, ArrayImplMut, 2>, + ) -> RlstResult> { + assert_eq!(self.arr.shape()[0], self.arr.shape()[1]); + let n = self.arr.shape()[0]; + assert_eq!(rhs.shape()[0], n); + + let nrhs = rhs.shape()[1]; + + let arr_stride = self.arr.stride(); + let rhs_stride = rhs.stride(); + + let lda = self.arr.stride()[1]; + let ldb = rhs.stride()[1]; + + assert_lapack_stride(arr_stride); + assert_lapack_stride(rhs_stride); + + let trans_param = match trans { + Trans::NoTrans => b'N', + Trans::Trans => b'T', + Trans::ConjTrans => b'C', + }; + + let mut info = 0; + unsafe { + $getrs( + trans_param, + n as i32, + nrhs as i32, + self.arr.data(), + lda as i32, + self.ipiv.as_slice(), + rhs.data_mut(), + ldb as i32, + &mut info, + ) + }; + + match info { + 0 => Ok(rhs), + _ => Err(RlstError::LapackError(info)), + } + } + + pub fn get_l_resize< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar> + + Shape<2> + + UnsafeRandomAccessMut<2, Item = $scalar> + + UnsafeRandomAccessByRef<2, Item = $scalar> + + ResizeInPlace<2>, + >( + &self, + mut arr: Array<$scalar, ArrayImplMut, 2>, + ) { + let m = self.arr.shape()[0]; + let n = self.arr.shape()[1]; + let k = std::cmp::min(m, n); + + arr.resize_in_place([m, k]); + self.get_l(arr); + } - pub fn get_l< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - mut arr: Array, - ) { - let m = self.arr.shape()[0]; - let n = self.arr.shape()[1]; - let k = std::cmp::min(m, n); - assert_eq!( - arr.shape(), - [m, k], - "Require matrix with shape {} x {}. Given shape is {} x {}", - m, - k, - arr.shape()[0], - arr.shape()[1] - ); - - arr.set_zero(); - for col in 0..k { - for row in col..m { - if col == row { - arr[[row, col]] = ::one(); - } else { - arr[[row, col]] = self.arr.get_value([row, col]).unwrap(); + pub fn get_l< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar> + + Shape<2> + + UnsafeRandomAccessMut<2, Item = $scalar> + + UnsafeRandomAccessByRef<2, Item = $scalar>, + >( + &self, + mut arr: Array<$scalar, ArrayImplMut, 2>, + ) { + let m = self.arr.shape()[0]; + let n = self.arr.shape()[1]; + let k = std::cmp::min(m, n); + assert_eq!( + arr.shape(), + [m, k], + "Require matrix with shape {} x {}. Given shape is {} x {}", + m, + k, + arr.shape()[0], + arr.shape()[1] + ); + + arr.set_zero(); + for col in 0..k { + for row in col..m { + if col == row { + arr[[row, col]] = <$scalar as One>::one(); + } else { + arr[[row, col]] = self.arr.get_value([row, col]).unwrap(); + } + } } } - } - } - pub fn get_r_resize< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + ResizeInPlace<2>, - >( - &self, - mut arr: Array, - ) { - let m = self.arr.shape()[0]; - let n = self.arr.shape()[1]; - let k = std::cmp::min(m, n); - - arr.resize_in_place([k, n]); - self.get_r(arr); - } + pub fn get_r_resize< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar> + + Shape<2> + + UnsafeRandomAccessMut<2, Item = $scalar> + + UnsafeRandomAccessByRef<2, Item = $scalar> + + ResizeInPlace<2>, + >( + &self, + mut arr: Array<$scalar, ArrayImplMut, 2>, + ) { + let m = self.arr.shape()[0]; + let n = self.arr.shape()[1]; + let k = std::cmp::min(m, n); + + arr.resize_in_place([k, n]); + self.get_r(arr); + } - pub fn get_r< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - mut arr: Array, - ) { - let m = self.arr.shape()[0]; - let n = self.arr.shape()[1]; - let k = std::cmp::min(m, n); - assert_eq!( - arr.shape(), - [k, n], - "Require matrix with shape {} x {}. Given shape is {} x {}", - k, - n, - arr.shape()[0], - arr.shape()[1] - ); - - arr.set_zero(); - for col in 0..n { - for row in 0..=std::cmp::min(col, k - 1) { - arr[[row, col]] = self.arr.get_value([row, col]).unwrap(); + pub fn get_r< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar> + + Shape<2> + + UnsafeRandomAccessMut<2, Item = $scalar> + + UnsafeRandomAccessByRef<2, Item = $scalar>, + >( + &self, + mut arr: Array<$scalar, ArrayImplMut, 2>, + ) { + let m = self.arr.shape()[0]; + let n = self.arr.shape()[1]; + let k = std::cmp::min(m, n); + assert_eq!( + arr.shape(), + [k, n], + "Require matrix with shape {} x {}. Given shape is {} x {}", + k, + n, + arr.shape()[0], + arr.shape()[1] + ); + + arr.set_zero(); + for col in 0..n { + for row in 0..=std::cmp::min(col, k - 1) { + arr[[row, col]] = self.arr.get_value([row, col]).unwrap(); + } + } } - } - } - pub fn get_p_resize< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + ResizeInPlace<2>, - >( - &self, - mut arr: Array, - ) { - let m = self.arr.shape()[0]; - - arr.resize_in_place([m, m]); - self.get_p(arr); - } + pub fn get_p_resize< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar> + + Shape<2> + + UnsafeRandomAccessMut<2, Item = $scalar> + + UnsafeRandomAccessByRef<2, Item = $scalar> + + ResizeInPlace<2>, + >( + &self, + mut arr: Array<$scalar, ArrayImplMut, 2>, + ) { + let m = self.arr.shape()[0]; + + arr.resize_in_place([m, m]); + self.get_p(arr); + } - fn get_perm(&self) -> Vec { - let m = self.arr.shape()[0]; - // let n = self.arr.shape()[1]; - // let k = std::cmp::min(m, n); - let ipiv: Vec = self.ipiv.iter().map(|&elem| (elem as usize) - 1).collect(); + fn get_perm(&self) -> Vec { + let m = self.arr.shape()[0]; + // let n = self.arr.shape()[1]; + // let k = std::cmp::min(m, n); + let ipiv: Vec = self.ipiv.iter().map(|&elem| (elem as usize) - 1).collect(); - let mut perm = (0..m).collect::>(); + let mut perm = (0..m).collect::>(); - for (index, &elem) in ipiv.iter().enumerate() { - perm.swap(index, elem); - } + for (index, &elem) in ipiv.iter().enumerate() { + perm.swap(index, elem); + } - perm - } + perm + } - pub fn get_p< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - mut arr: Array, - ) { - let m = self.arr.shape()[0]; - assert_eq!( - arr.shape(), - [m, m], - "Require matrix with shape {} x {}. Given shape is {} x {}", - m, - m, - arr.shape()[0], - arr.shape()[1] - ); - - let perm = self.get_perm(); - - arr.set_zero(); - for col in 0..m { - arr[[perm[col], col]] = ::one(); + pub fn get_p< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = $scalar> + + Shape<2> + + UnsafeRandomAccessMut<2, Item = $scalar> + + UnsafeRandomAccessByRef<2, Item = $scalar>, + >( + &self, + mut arr: Array<$scalar, ArrayImplMut, 2>, + ) { + let m = self.arr.shape()[0]; + assert_eq!( + arr.shape(), + [m, m], + "Require matrix with shape {} x {}. Given shape is {} x {}", + m, + m, + arr.shape()[0], + arr.shape()[1] + ); + + let perm = self.get_perm(); + + arr.set_zero(); + for col in 0..m { + arr[[perm[col], col]] = <$scalar as One>::one(); + } + } } - } -} -impl< - Item: Scalar + Lapack, - ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2>, - > Array -{ - pub fn into_lu(self) -> RlstResult> { - LuDecomposition::new(self) - } + impl< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = $scalar> + + Stride<2> + + RawAccessMut + + Shape<2>, + > Array<$scalar, ArrayImpl, 2> + { + pub fn into_lu(self) -> RlstResult> { + LuDecomposition::new(self) + } + } + }; } +impl_lu!(f64, dgetrf, dgetrs); + #[cfg(test)] mod test { diff --git a/dense/src/linalg/qr.rs b/dense/src/linalg/qr.rs new file mode 100644 index 00000000..d347e561 --- /dev/null +++ b/dense/src/linalg/qr.rs @@ -0,0 +1,28 @@ +//! Interface to QR Decomposition + +// use super::assert_lapack_stride; +// use crate::array::Array; +// use num::One; +// use rlst_common::traits::*; +// use rlst_common::types::*; +// use rlst_lapack::{Dgeqp3, Ormqr}; + +// pub struct QRDecomposition< +// Item: Scalar, +// ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Stride<2> + Shape<2> + RawAccessMut, +// > { +// arr: Array, +// tau: Vec, +// jpvt: Vec, +// } + +// impl< +// Item: Scalar + Dgeqp3 + Ormqr, +// ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> +// + Stride<2> +// + Shape<2> +// + RawAccessMut, +// > QRDecomposition +// { +// pub fn new(arr: Array, work: Option<) +// } diff --git a/lapack/src/geqp3.rs b/lapack/src/geqp3.rs new file mode 100644 index 00000000..31bd6fb5 --- /dev/null +++ b/lapack/src/geqp3.rs @@ -0,0 +1,98 @@ +//! Interface to dgeqp3 + +use lapack::{cgeqp3, dgeqp3, sgeqp3, zgeqp3}; +use num::Zero; +use rlst_common::types::*; + +pub trait Geqp3: Scalar { + fn geqp3(m: i32, n: i32, a: &mut [Self], lda: i32, jpvt: &mut [i32], tau: &mut [Self]) -> i32; +} + +macro_rules! impl_geqp3_real { + ($scalar:ty, $geqp3:expr) => { + impl Geqp3 for $scalar { + fn geqp3( + m: i32, + n: i32, + a: &mut [Self], + lda: i32, + jpvt: &mut [i32], + tau: &mut [Self], + ) -> i32 { + assert!(m >= 0); + assert!(n >= 0); + assert!(lda >= 0); + assert_eq!(a.len() as i32, lda * n); + assert_eq!(jpvt.len() as i32, n); + assert_eq!(tau.len() as i32, std::cmp::min(m, n)); + + let mut info = 0; + let lwork = -1; + let mut work_query = [::zero()]; + unsafe { $geqp3(m, n, a, lda, jpvt, tau, &mut work_query, lwork, &mut info) }; + assert_eq!(info, 0); + let lwork = work_query[0].re() as i32; + + let mut work = vec![::zero(); lwork as usize]; + unsafe { $geqp3(m, n, a, lda, jpvt, tau, &mut work, lwork, &mut info) }; + info + } + } + }; +} + +macro_rules! impl_geqp3_complex { + ($scalar:ty, $geqp3:expr) => { + impl Geqp3 for $scalar { + fn geqp3( + m: i32, + n: i32, + a: &mut [Self], + lda: i32, + jpvt: &mut [i32], + tau: &mut [Self], + ) -> i32 { + assert!(m >= 0); + assert!(n >= 0); + assert!(lda >= 0); + assert_eq!(a.len() as i32, lda * n); + assert_eq!(jpvt.len() as i32, n); + assert_eq!(tau.len() as i32, std::cmp::min(m, n)); + let mut rwork = vec![<::Real as Zero>::zero(); 2 * n as usize]; + + let mut info = 0; + let lwork = -1; + let mut work_query = [::zero()]; + unsafe { + $geqp3( + m, + n, + a, + lda, + jpvt, + tau, + &mut work_query, + lwork, + &mut rwork, + &mut info, + ) + }; + assert_eq!(info, 0); + let lwork = work_query[0].re() as i32; + + let mut work = vec![::zero(); lwork as usize]; + unsafe { + $geqp3( + m, n, a, lda, jpvt, tau, &mut work, lwork, &mut rwork, &mut info, + ) + }; + info + } + } + }; +} + +impl_geqp3_real!(f32, sgeqp3); +impl_geqp3_real!(f64, dgeqp3); +impl_geqp3_complex!(c32, cgeqp3); +impl_geqp3_complex!(c64, zgeqp3); diff --git a/lapack/src/getrf.rs b/lapack/src/getrf.rs index c08c50b4..6b71da8b 100644 --- a/lapack/src/getrf.rs +++ b/lapack/src/getrf.rs @@ -11,7 +11,7 @@ macro_rules! impl_getrf { impl Getrf for $scalar { fn getrf(m: i32, n: i32, a: &mut [Self], lda: i32, ipiv: &mut [i32]) -> i32 { assert!(m >= 0); - assert!(m >= 0); + assert!(n >= 0); assert!(lda >= std::cmp::max(m, 1)); assert_eq!(a.len() as i32, lda * n); assert_eq!(ipiv.len() as i32, std::cmp::min(m, n)); diff --git a/lapack/src/getrs.rs b/lapack/src/getrs.rs index 477f215a..d2f38d20 100644 --- a/lapack/src/getrs.rs +++ b/lapack/src/getrs.rs @@ -1,11 +1,10 @@ //! Getrf -use crate::Trans; use lapack::{cgetrs, dgetrs, sgetrs, zgetrs}; use rlst_common::types::*; pub trait Getrs: Scalar { fn getrs( - trans: Trans, + trans: u8, n: i32, nrhs: i32, a: &[Self], @@ -20,7 +19,7 @@ macro_rules! impl_getrs { ($scalar:ty, $getrs:expr) => { impl Getrs for $scalar { fn getrs( - trans: Trans, + trans: u8, n: i32, nrhs: i32, a: &[Self], @@ -36,6 +35,7 @@ macro_rules! impl_getrs { assert_eq!(ipiv.len() as i32, n); assert_eq!(b.len() as i32, ldb * nrhs); assert!(ldb >= std::cmp::max(1, n)); + assert!(trans == b'N' || trans == b'T' || trans == b'C'); for &elem in ipiv { assert!(elem >= 1); @@ -44,7 +44,7 @@ macro_rules! impl_getrs { let mut info = 0; - unsafe { $getrs(trans as u8, n, nrhs, a, lda, ipiv, b, ldb, &mut info) } + unsafe { $getrs(trans, n, nrhs, a, lda, ipiv, b, ldb, &mut info) } info } diff --git a/lapack/src/lib.rs b/lapack/src/lib.rs index e27a3fbc..e004b386 100644 --- a/lapack/src/lib.rs +++ b/lapack/src/lib.rs @@ -4,21 +4,19 @@ use rlst_common::types::Scalar; +pub mod geqp3; pub mod getrf; pub mod getrs; +pub mod ormqr; +pub mod unmqr; +pub use geqp3::Geqp3; pub use getrf::Getrf; pub use getrs::Getrs; +pub use ormqr::Ormqr; +pub use unmqr::Unmqr; -#[derive(Clone, Copy)] -#[repr(u8)] -pub enum Trans { - NoTranspose = b'N', - Transpose = b'T', - ConjugateTranspose = b'C', -} +// // Collective Lapack wrapper trait +// pub trait Lapack: Scalar + Getrf + Getrs + Unmqr + Ormqr {} -// Collective Lapack wrapper trait -pub trait Lapack: Scalar + Getrf + Getrs {} - -impl Lapack for T {} +// impl Lapack for T {} diff --git a/lapack/src/ormqr.rs b/lapack/src/ormqr.rs new file mode 100644 index 00000000..6874120e --- /dev/null +++ b/lapack/src/ormqr.rs @@ -0,0 +1,253 @@ +//! Implementation of ormqr + +use lapack::cunmqr; +use lapack::dormqr; +use lapack::sormqr; +use lapack::zunmqr; +use num::Zero; +use rlst_common::types::*; + +pub trait Ormqr: Scalar { + fn ormqr( + side: u8, + trans: u8, + m: i32, + n: i32, + k: i32, + a: &mut [Self], + lda: i32, + tau: &mut [Self], + c: &mut [Self], + ldc: i32, + work: Option<&mut [Self]>, + ) -> i32; + + fn ormqr_query_work(side: u8, trans: u8, m: i32, n: i32, k: i32) -> i32; +} + +macro_rules! impl_ormqr_complex { + ($scalar:ty, $ormqr:expr) => { + impl Ormqr for $scalar { + fn ormqr( + side: u8, + trans: u8, + m: i32, + n: i32, + k: i32, + a: &mut [Self], + lda: i32, + tau: &mut [Self], + c: &mut [Self], + ldc: i32, + work: Option<&mut [Self]>, + ) -> i32 { + assert!(side == b'L' || side == b'R'); + assert!(trans == b'C' || trans == b'N'); + assert!(m >= 0); + assert!(n >= 0); + + assert!(if side == b'L' { + k >= 0 && k <= m + } else { + k >= 0 && k <= n + }); + assert!(if side == b'L' { + lda >= std::cmp::max(1, m) + } else { + lda >= std::cmp::max(1, n) + }); + assert_eq!(a.len() as i32, lda * k); + assert_eq!(tau.len() as i32, k); + assert!(ldc >= std::cmp::max(1, m)); + assert_eq!(c.len() as i32, ldc * n); + let mut my_work = Vec::::new(); + let work = if let Some(work) = work { + assert!(if side == b'L' { + work.len() as i32 >= std::cmp::max(1, n) + } else { + work.len() as i32 >= std::cmp::max(1, m) + }); + work + } else { + let len = ::ormqr_query_work(side, trans, m, n, k) as usize; + my_work.resize(len, ::zero()); + &mut my_work + }; + + let mut info = 0; + unsafe { + $ormqr( + side, + trans, + m, + n, + k, + a, + lda, + tau, + c, + ldc, + work, + work.len() as i32, + &mut info, + ) + }; + info + } + + fn ormqr_query_work(side: u8, trans: u8, m: i32, n: i32, k: i32) -> i32 { + let a = [::zero()]; + let tau = [::zero()]; + let mut c = [::zero()]; + let mut work_query = [::zero()]; + let mut info = 0; + + assert!(side == b'L' || side == b'R'); + assert!(trans == b'T' || trans == b'N'); + assert!(m >= 0); + assert!(n >= 0); + assert!(if side == b'L' { + k >= 0 && k <= m + } else { + k >= 0 && k <= n + }); + let lda = if side == b'L' { m } else { n }; + unsafe { + $ormqr( + side, + trans, + m, + n, + k, + &a, + lda, + &tau, + &mut c, + m, + &mut work_query, + -1, + &mut info, + ); + } + assert_eq!(info, 0); + work_query[0].re() as i32 + } + } + }; +} + +macro_rules! impl_ormqr_real { + ($scalar:ty, $ormqr:expr) => { + impl Ormqr for $scalar { + fn ormqr( + side: u8, + trans: u8, + m: i32, + n: i32, + k: i32, + a: &mut [Self], + lda: i32, + tau: &mut [Self], + c: &mut [Self], + ldc: i32, + work: Option<&mut [Self]>, + ) -> i32 { + assert!(side == b'L' || side == b'R'); + assert!(trans == b'T' || trans == b'N'); + assert!(m >= 0); + assert!(n >= 0); + + assert!(if side == b'L' { + k >= 0 && k <= m + } else { + k >= 0 && k <= n + }); + assert!(if side == b'L' { + lda >= std::cmp::max(1, m) + } else { + lda >= std::cmp::max(1, n) + }); + assert_eq!(a.len() as i32, lda * k); + assert_eq!(tau.len() as i32, k); + assert!(ldc >= std::cmp::max(1, m)); + assert_eq!(c.len() as i32, ldc * n); + let mut my_work = Vec::::new(); + let work = if let Some(work) = work { + assert!(if side == b'L' { + work.len() as i32 >= std::cmp::max(1, n) + } else { + work.len() as i32 >= std::cmp::max(1, m) + }); + work + } else { + let len = ::ormqr_query_work(side, trans, m, n, k) as usize; + my_work.resize(len, ::zero()); + &mut my_work + }; + + let mut info = 0; + unsafe { + $ormqr( + side, + trans, + m, + n, + k, + a, + lda, + tau, + c, + ldc, + work, + work.len() as i32, + &mut info, + ) + }; + info + } + + fn ormqr_query_work(side: u8, trans: u8, m: i32, n: i32, k: i32) -> i32 { + let a = [::zero()]; + let tau = [::zero()]; + let mut c = [::zero()]; + let mut work_query = [::zero()]; + let mut info = 0; + + assert!(side == b'L' || side == b'R'); + assert!(trans == b'T' || trans == b'N'); + assert!(m >= 0); + assert!(n >= 0); + assert!(if side == b'L' { + k >= 0 && k <= m + } else { + k >= 0 && k <= n + }); + let lda = if side == b'L' { m } else { n }; + unsafe { + $ormqr( + side, + trans, + m, + n, + k, + &a, + lda, + &tau, + &mut c, + m, + &mut work_query, + -1, + &mut info, + ); + } + assert_eq!(info, 0); + work_query[0].re() as i32 + } + } + }; +} + +impl_ormqr_real!(f32, sormqr); +impl_ormqr_real!(f64, dormqr); +impl_ormqr_complex!(c32, cunmqr); +impl_ormqr_complex!(c64, zunmqr); diff --git a/lapack/src/unmqr.rs b/lapack/src/unmqr.rs new file mode 100644 index 00000000..56f9d2b0 --- /dev/null +++ b/lapack/src/unmqr.rs @@ -0,0 +1,50 @@ +//! Implementation of ormqr + +use lapack::cunmqr; +use lapack::dormqr; +use lapack::sormqr; +use lapack::zunmqr; +use num::Zero; +use rlst_common::types::*; + +use crate::Ormqr; + +pub trait Unmqr: Scalar { + fn unmqr( + side: u8, + trans: u8, + m: i32, + n: i32, + k: i32, + a: &mut [Self], + lda: i32, + tau: &mut [Self], + c: &mut [Self], + ldc: i32, + work: Option<&mut [Self]>, + ) -> i32; + + fn unmqr_query_work(side: u8, trans: u8, m: i32, n: i32, k: i32) -> i32; +} + +impl Unmqr for T { + fn unmqr( + side: u8, + trans: u8, + m: i32, + n: i32, + k: i32, + a: &mut [Self], + lda: i32, + tau: &mut [Self], + c: &mut [Self], + ldc: i32, + work: Option<&mut [Self]>, + ) -> i32 { + ::ormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work) + } + + fn unmqr_query_work(side: u8, trans: u8, m: i32, n: i32, k: i32) -> i32 { + ::ormqr_query_work(side, trans, m, n, k) + } +}