Skip to content

Commit

Permalink
Multiple interface updates for LU decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Nov 6, 2023
1 parent f614824 commit dc7ba40
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 56 deletions.
3 changes: 0 additions & 3 deletions algorithms/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
//! Collection of Linear Solver Algorithms and Interfaces
#![cfg_attr(feature = "strict", deny(warnings))]

extern crate rlst_blis_src;
extern crate rlst_netlib_lapack_src;

//pub mod dense;
// pub mod iterative_solvers;
// pub mod lapack;
Expand Down
6 changes: 3 additions & 3 deletions common/src/traits/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,17 @@ pub trait PermuteRows {
/// Multiply First * Second and sum into Self
pub trait MultInto<First, Second> {
type Item: Scalar;
fn mult_into(&mut self, alpha: Self::Item, arr_a: First, arr_b: Second, beta: Self::Item);
fn mult_into(self, alpha: Self::Item, arr_a: First, arr_b: Second, beta: Self::Item) -> Self;
}

/// Multiply First * Second and sum into Self. Allow to resize Self if necessary
pub trait MultIntoResize<First, Second> {
type Item: Scalar;
fn mult_into_resize(
&mut self,
self,
alpha: Self::Item,
arr_a: First,
arr_b: Second,
beta: Self::Item,
);
) -> Self;
}
17 changes: 8 additions & 9 deletions dense/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@ cauchy = "0.4"
rand = "0.8"
itertools = "0.10"
rand_distr = "0.4"
rlst-blis = { path = "../blis"}
approx = { version = "0.5", features=["num-complex"] }
rlst-operator = {path = "../operator"}
rlst-common = {path = "../common"}
rlst-lapack = {path = "../lapack"}
rlst-blis = { path = "../blis" }
approx = { version = "0.5", features = ["num-complex"] }
rlst-operator = { path = "../operator" }
rlst-common = { path = "../common" }
rlst-lapack = { path = "../lapack" }
paste = "1"
rand_chacha = "0.3"
rlst-blis-src = { path = "../blis-src" }
rlst-netlib-lapack-src = { path = "../netlib-lapack-src" }

[dev-dependencies]
criterion = { version = "0.3", features = ["html_reports"] }

[package.metadata.docs.rs]
rustdoc-args = [ "--html-in-header", "katex-header.html" ]



rustdoc-args = ["--html-in-header", "katex-header.html"]
7 changes: 7 additions & 0 deletions dense/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,10 @@ impl<
self.0.resize_in_place(shape)
}
}

/// Create an empty array of given type and dimension.
pub fn empty_array<Item: Scalar, const NDIM: usize>() -> DynamicArray<Item, NDIM> {
let shape = [0; NDIM];
let container = VectorContainer::new(0);
Array::new(BaseArray::new(container, shape))
}
46 changes: 28 additions & 18 deletions dense/src/array/mult_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ impl<
type Item = Item;

fn mult_into(
&mut self,
mut self,
alpha: Item,
arr_a: Array<Item, ArrayImplFirst, 2>,
arr_b: Array<Item, ArrayImplSecond, 2>,
beta: Item,
) {
) -> Self {
let transa = TransMode::NoTrans;
let transb = TransMode::NoTrans;
crate::matrix_multiply::matrix_multiply(transa, transb, alpha, &arr_a, &arr_b, beta, self)
crate::matrix_multiply::matrix_multiply(
transa, transb, alpha, &arr_a, &arr_b, beta, &mut self,
);
self
}
}

Expand All @@ -64,12 +67,12 @@ impl<
type Item = Item;

fn mult_into(
&mut self,
mut self,
alpha: Item,
arr_a: Array<Item, ArrayImplFirst, 2>,
arr_b: Array<Item, ArrayImplSecond, 1>,
beta: Item,
) {
) -> Self {
let transa = TransMode::NoTrans;
let transb = TransMode::NoTrans;

Expand All @@ -84,7 +87,8 @@ impl<
&arr_with_padded_dim,
beta,
&mut self_with_padded_dim,
)
);
self
}
}

Expand All @@ -111,12 +115,12 @@ impl<
type Item = Item;

fn mult_into(
&mut self,
mut self,
alpha: Item,
arr_a: Array<Item, ArrayImplFirst, 1>,
arr_b: Array<Item, ArrayImplSecond, 2>,
beta: Item,
) {
) -> Self {
let transa = TransMode::NoTrans;
let transb = TransMode::NoTrans;

Expand All @@ -131,7 +135,8 @@ impl<
&arr_b,
beta,
&mut self_with_padded_dim,
)
);
self
}
}

Expand Down Expand Up @@ -161,12 +166,12 @@ impl<
type Item = Item;

fn mult_into_resize(
&mut self,
mut self,
alpha: Item,
arr_a: Array<Item, ArrayImplFirst, 2>,
arr_b: Array<Item, ArrayImplSecond, 2>,
beta: Item,
) {
) -> Self {
let transa = TransMode::NoTrans;
let transb = TransMode::NoTrans;

Expand All @@ -175,7 +180,10 @@ impl<
self.resize_in_place(expected_shape);
}

crate::matrix_multiply::matrix_multiply(transa, transb, alpha, &arr_a, &arr_b, beta, self)
crate::matrix_multiply::matrix_multiply(
transa, transb, alpha, &arr_a, &arr_b, beta, &mut self,
);
self
}
}

Expand Down Expand Up @@ -203,12 +211,12 @@ impl<
type Item = Item;

fn mult_into_resize(
&mut self,
mut self,
alpha: Item,
arr_a: Array<Item, ArrayImplFirst, 2>,
arr_b: Array<Item, ArrayImplSecond, 1>,
beta: Item,
) {
) -> Self {
let transa = TransMode::NoTrans;
let transb = TransMode::NoTrans;

Expand All @@ -229,7 +237,8 @@ impl<
&arr_with_padded_dim,
beta,
&mut self_with_padded_dim,
)
);
self
}
}

Expand Down Expand Up @@ -257,12 +266,12 @@ impl<
type Item = Item;

fn mult_into_resize(
&mut self,
mut self,
alpha: Item,
arr_a: Array<Item, ArrayImplFirst, 1>,
arr_b: Array<Item, ArrayImplSecond, 2>,
beta: Item,
) {
) -> Self {
let transa = TransMode::NoTrans;
let transb = TransMode::NoTrans;

Expand All @@ -283,6 +292,7 @@ impl<
&arr_b,
beta,
&mut self_with_padded_dim,
)
);
self
}
}
3 changes: 3 additions & 0 deletions dense/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
//! - [Examples](crate::examples)
#![cfg_attr(feature = "strict", deny(warnings))]

extern crate rlst_blis_src;
extern crate rlst_netlib_lapack_src;

pub mod base_array;
pub mod data_container;
pub mod linalg;
Expand Down
Loading

0 comments on commit dc7ba40

Please sign in to comment.