Updated distributed csr interface (#112)
tbetcke authored Jan 10, 2025
1 parent e27fd4b commit edfd2fc
Showing 4 changed files with 45 additions and 37 deletions.
9 changes: 7 additions & 2 deletions examples/cg_distributed.rs
@@ -42,8 +42,13 @@ pub fn main() {
// The constructor takes care of the fact that the aij entries are only defined on rank 0.
// It sends the entries around according to the index layout and constructs the parallel
// distributed matrix.
- let distributed_mat =
-     DistributedCsrMatrix::from_aij(&index_layout, &index_layout, &rows, &cols, &data);
+ let distributed_mat = DistributedCsrMatrix::from_aij(
+     index_layout.clone(),
+     index_layout.clone(),
+     &rows,
+     &cols,
+     &data,
+ );

// We can now wrap the matrix into an operator.
let op = Operator::from(distributed_mat);
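A quick before/after note on the call pattern above: `from_aij` now takes the index layouts by value, so a layout that is reused afterwards is cloned at the call site. A minimal sketch, reusing the `index_layout`, `rows`, `cols` and `data` variables that this example already sets up:

    // Before #112: the matrix borrowed the layouts.
    // let distributed_mat =
    //     DistributedCsrMatrix::from_aij(&index_layout, &index_layout, &rows, &cols, &data);

    // After #112: the matrix owns its layouts; clone them if they are needed again later.
    let distributed_mat = DistributedCsrMatrix::from_aij(
        index_layout.clone(), // domain layout, by value
        index_layout.clone(), // range layout, by value
        &rows,
        &cols,
        &data,
    );
    let op = Operator::from(distributed_mat);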
13 changes: 9 additions & 4 deletions examples/distributed_csr_matmul.rs
@@ -51,8 +51,13 @@ fn main() {
data.push(item);
});

- dist_mat =
-     DistributedCsrMatrix::from_aij(&domain_layout, &range_layout, &rows, &cols, &data);
+ dist_mat = DistributedCsrMatrix::from_aij(
+     domain_layout.clone(),
+     range_layout.clone(),
+     &rows,
+     &cols,
+     &data,
+ );

// dist_mat = DistributedCsrMatrix::from_serial_root(
// sparse_mat,
@@ -79,8 +84,8 @@ fn main() {
// Create a distributed matrix on the non-root node (compare to `from_serial_root`).
//dist_mat = DistributedCsrMatrix::from_serial(0, &domain_layout, &range_layout, &world);
dist_mat = DistributedCsrMatrix::from_aij(
- &domain_layout,
- &range_layout,
+ domain_layout.clone(),
+ range_layout.clone(),
&Vec::default(),
&Vec::default(),
&Vec::default(),
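The non-root branch of this example follows the same by-value pattern: ranks other than 0 pass empty aij arrays and still end up with their local block, since `from_aij` redistributes the rank-0 entries according to the layouts. A condensed sketch that folds the example's two assignment sites into one expression; `rank`, `domain_layout` and `range_layout` are assumed to be set up as earlier in the example:

    let dist_mat = if rank == 0 {
        // Rank 0 holds all aij entries and hands them to the constructor.
        DistributedCsrMatrix::from_aij(
            domain_layout.clone(),
            range_layout.clone(),
            &rows,
            &cols,
            &data,
        )
    } else {
        // Other ranks contribute no entries; the constructor sends them their part.
        DistributedCsrMatrix::from_aij(
            domain_layout.clone(),
            range_layout.clone(),
            &Vec::default(),
            &Vec::default(),
            &Vec::default(),
        )
    };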
21 changes: 9 additions & 12 deletions src/operator/interface/distributed_sparse_operator.rs
@@ -16,13 +16,12 @@ use super::DistributedArrayVectorSpace;

/// CSR matrix operator
pub struct DistributedCsrMatrixOperatorImpl<
- 'a,
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
Item: RlstScalar + Equivalence,
C: Communicator,
> {
- csr_mat: DistributedCsrMatrix<'a, DomainLayout, RangeLayout, Item, C>,
+ csr_mat: DistributedCsrMatrix<DomainLayout, RangeLayout, Item, C>,
domain: DistributedArrayVectorSpace<DomainLayout, Item>,
range: DistributedArrayVectorSpace<RangeLayout, Item>,
}
@@ -32,7 +31,7 @@ impl<
RangeLayout: IndexLayout<Comm = C>,
Item: RlstScalar + Equivalence,
C: Communicator,
- > std::fmt::Debug for DistributedCsrMatrixOperatorImpl<'_, DomainLayout, RangeLayout, Item, C>
+ > std::fmt::Debug for DistributedCsrMatrixOperatorImpl<DomainLayout, RangeLayout, Item, C>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DistributedCsrMatrixOperator")
@@ -43,16 +42,15 @@ impl<
}

impl<
- 'a,
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
Item: RlstScalar + Equivalence,
C: Communicator,
- > DistributedCsrMatrixOperatorImpl<'a, DomainLayout, RangeLayout, Item, C>
+ > DistributedCsrMatrixOperatorImpl<DomainLayout, RangeLayout, Item, C>
{
/// Create a new CSR matrix operator
pub fn new(
- csr_mat: DistributedCsrMatrix<'a, DomainLayout, RangeLayout, Item, C>,
+ csr_mat: DistributedCsrMatrix<DomainLayout, RangeLayout, Item, C>,
domain: DistributedArrayVectorSpace<DomainLayout, Item>,
range: DistributedArrayVectorSpace<RangeLayout, Item>,
) -> Self {
@@ -82,7 +80,7 @@ impl<
RangeLayout: IndexLayout<Comm = C>,
Item: RlstScalar + Equivalence,
C: Communicator,
- > OperatorBase for DistributedCsrMatrixOperatorImpl<'_, DomainLayout, RangeLayout, Item, C>
+ > OperatorBase for DistributedCsrMatrixOperatorImpl<DomainLayout, RangeLayout, Item, C>
{
type Domain = DistributedArrayVectorSpace<DomainLayout, Item>;
type Range = DistributedArrayVectorSpace<RangeLayout, Item>;
@@ -101,7 +99,7 @@ impl<
RangeLayout: IndexLayout<Comm = C>,
Item: RlstScalar + Equivalence,
C: Communicator,
- > AsApply for DistributedCsrMatrixOperatorImpl<'_, DomainLayout, RangeLayout, Item, C>
+ > AsApply for DistributedCsrMatrixOperatorImpl<DomainLayout, RangeLayout, Item, C>
{
fn apply_extended(
&self,
@@ -116,15 +114,14 @@
}

impl<
- 'a,
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
Item: RlstScalar + Equivalence,
C: Communicator,
- > From<DistributedCsrMatrix<'a, DomainLayout, RangeLayout, Item, C>>
-     for Operator<DistributedCsrMatrixOperatorImpl<'a, DomainLayout, RangeLayout, Item, C>>
+ > From<DistributedCsrMatrix<DomainLayout, RangeLayout, Item, C>>
+     for Operator<DistributedCsrMatrixOperatorImpl<DomainLayout, RangeLayout, Item, C>>
{
- fn from(csr_mat: DistributedCsrMatrix<'a, DomainLayout, RangeLayout, Item, C>) -> Self {
+ fn from(csr_mat: DistributedCsrMatrix<DomainLayout, RangeLayout, Item, C>) -> Self {
let domain_layout = csr_mat.domain_layout();
let range_layout = csr_mat.range_layout();
let domain = DistributedArrayVectorSpace::new(domain_layout.clone());
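A short note on what removing the lifetime buys downstream: the operator wrapper no longer carries `'a`, so it can be stored or returned freely once built from an owned matrix. Sketch only; `distributed_mat` stands for any matrix constructed as in the examples above:

    // Before #112: Operator<DistributedCsrMatrixOperatorImpl<'a, DomainLayout, RangeLayout, Item, C>>
    // After  #112: Operator<DistributedCsrMatrixOperatorImpl<DomainLayout, RangeLayout, Item, C>>
    //
    // The conversion clones the matrix's own layouts to build the domain and
    // range vector spaces, as in the `From` impl above.
    let op = Operator::from(distributed_mat);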
39 changes: 20 additions & 19 deletions src/sparse/sparse_mat/distributed_csr_mat.rs
@@ -21,7 +21,6 @@ use super::tools::normalize_aij;

/// Distributed CSR matrix
pub struct DistributedCsrMatrix<
- 'a,
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
T: RlstScalar + Equivalence,
@@ -31,26 +30,25 @@ pub struct DistributedCsrMatrix<
local_matrix: CsrMatrix<T>,
global_indices: Vec<usize>,
local_dof_count: usize,
- domain_layout: &'a DomainLayout,
- range_layout: &'a RangeLayout,
+ domain_layout: DomainLayout,
+ range_layout: RangeLayout,
domain_ghosts: GhostCommunicator<usize>,
}

impl<
- 'a,
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
T: RlstScalar + Equivalence,
C: Communicator,
- > DistributedCsrMatrix<'a, DomainLayout, RangeLayout, T, C>
+ > DistributedCsrMatrix<DomainLayout, RangeLayout, T, C>
{
/// Create new
pub fn new(
indices: Vec<usize>,
indptr: Vec<usize>,
data: Vec<T>,
- domain_layout: &'a DomainLayout,
- range_layout: &'a RangeLayout,
+ domain_layout: DomainLayout,
+ range_layout: RangeLayout,
) -> Self {
assert!(std::ptr::addr_eq(domain_layout.comm(), range_layout.comm()));
let comm = domain_layout.comm();
@@ -140,19 +138,19 @@ impl<
}

/// Domain layout
- pub fn domain_layout(&self) -> &'a DomainLayout {
-     self.domain_layout
+ pub fn domain_layout(&self) -> &DomainLayout {
+     &self.domain_layout
}

/// Range layout
- pub fn range_layout(&self) -> &'a RangeLayout {
-     self.range_layout
+ pub fn range_layout(&self) -> &RangeLayout {
+     &self.range_layout
}

/// Create a new distributed CSR matrix from an aij format.
pub fn from_aij(
- domain_layout: &'a DomainLayout,
- range_layout: &'a RangeLayout,
+ domain_layout: DomainLayout,
+ range_layout: RangeLayout,
rows: &[usize],
cols: &[usize],
data: &[T],
@@ -230,10 +228,13 @@
/// Create from root
pub fn from_serial(
root: usize,
- domain_layout: &'a DomainLayout,
- range_layout: &'a RangeLayout,
- comm: &'a C,
+ domain_layout: DomainLayout,
+ range_layout: RangeLayout,
) -> Self {
+ assert!(std::ptr::addr_eq(domain_layout.comm(), range_layout.comm()));
+
+ let comm = domain_layout.comm();
+
let root_process = comm.process_at_rank(root as i32);

let my_index_range = range_layout.local_range();
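For callers of `from_serial`, the separate communicator argument disappears: it is now recovered from the domain layout, and the assertion above checks that both layouts share the same communicator. A hedged usage sketch; `domain_layout` and `range_layout` are placeholders for layouts built as in the examples:

    // Before #112 (see the commented-out call in the matmul example above):
    // dist_mat = DistributedCsrMatrix::from_serial(0, &domain_layout, &range_layout, &world);

    // After #112: the communicator comes from domain_layout.comm().
    let dist_mat = DistributedCsrMatrix::from_serial(0, domain_layout, range_layout);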
@@ -286,8 +287,8 @@
/// Create from root
pub fn from_serial_root(
csr_mat: CsrMatrix<T>,
- domain_layout: &'a DomainLayout,
- range_layout: &'a RangeLayout,
+ domain_layout: DomainLayout,
+ range_layout: RangeLayout,
) -> Self {
assert!(std::ptr::addr_eq(domain_layout.comm(), range_layout.comm()));
let comm = domain_layout.comm();
@@ -450,7 +451,7 @@ impl<
DomainLayout: IndexLayout<Comm = C>,
RangeLayout: IndexLayout<Comm = C>,
C: Communicator,
- > Shape<2> for DistributedCsrMatrix<'_, DomainLayout, RangeLayout, T, C>
+ > Shape<2> for DistributedCsrMatrix<DomainLayout, RangeLayout, T, C>
{
fn shape(&self) -> [usize; 2] {
[
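Taken together, the changes in this file move the matrix to owned layouts: the constructors consume them, the accessors borrow them back, and communicator or index-range queries go through those borrows. A minimal sketch using only methods that appear in this diff; the layout and aij variables are placeholders:

    // Constructors take the layouts by value and store them in the matrix.
    let mat = DistributedCsrMatrix::from_aij(domain_layout, range_layout, &rows, &cols, &data);

    // Accessors return borrows of the owned fields.
    let comm = mat.domain_layout().comm();             // communicator shared by both layouts
    let local_rows = mat.range_layout().local_range(); // local index range on this rank
    let [m, n] = mat.shape();                          // via the Shape<2> impl above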
