Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize SUNMatrixWrapper functions #1538

Merged
merged 8 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 90 additions & 12 deletions include/amici/sundials_matrix_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

#include <vector>

#include <assert.h>

#include "amici/vector.h"

namespace amici {
Expand Down Expand Up @@ -118,13 +120,25 @@ class SUNMatrixWrapper {
* @brief Get the number of rows
* @return number of rows
*/
sunindextype rows() const;
sunindextype rows() const {
assert(!matrix_ ||
(matrix_id() == SUNMATRIX_SPARSE ?
num_rows_ == SM_ROWS_S(matrix_) :
num_rows_ == SM_ROWS_D(matrix_)));
return num_rows_;
}

/**
* @brief Get the number of columns
* @return number of columns
*/
sunindextype columns() const;
sunindextype columns() const {
assert(!matrix_ ||
(matrix_id() == SUNMATRIX_SPARSE ?
num_columns_ == SM_COLUMNS_S(matrix_) :
num_columns_ == SM_COLUMNS_D(matrix_)));
return num_columns_;
}

/**
* @brief Get the number of specified non-zero elements (sparse matrices only)
Expand Down Expand Up @@ -162,70 +176,134 @@ class SUNMatrixWrapper {
* @param idx data index
* @return idx-th data entry
*/
realtype get_data(sunindextype idx) const;
realtype get_data(sunindextype idx) const{
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(SM_DATA_S(matrix_) == data_);
return data_[idx];
}

/**
* @brief Get data entry for a dense matrix
* @param irow row
* @param icol col
* @return A(irow,icol)
*/
realtype get_data(sunindextype irow, sunindextype icol) const;
realtype get_data(sunindextype irow, sunindextype icol) const{
assert(matrix_);
assert(matrix_id() == SUNMATRIX_DENSE);
assert(irow < rows());
assert(icol < columns());
return SM_ELEMENT_D(matrix_, irow, icol);
}

/**
* @brief Set data entry for a sparse matrix
* @param idx data index
* @param data data for idx-th entry
*/
void set_data(sunindextype idx, realtype data);
void set_data(sunindextype idx, realtype data) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(SM_DATA_S(matrix_) == data_);
data_[idx] = data;
}

/**
* @brief Set data entry for a dense matrix
* @param irow row
* @param icol col
* @param data data for idx-th entry
*/
void set_data(sunindextype irow, sunindextype icol, realtype data);
void set_data(sunindextype irow, sunindextype icol, realtype data) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_DENSE);
assert(irow < rows());
assert(icol < columns());
SM_ELEMENT_D(matrix_, irow, icol) = data;
}

/**
* @brief Get the index value of a sparse matrix
* @param idx data index
* @return row (CSC) or column (CSR) for idx-th data entry
*/
sunindextype get_indexval(sunindextype idx) const;
sunindextype get_indexval(sunindextype idx) const {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(indexvals_ == SM_INDEXVALS_S(matrix_));
return indexvals_[idx];
}

/**
* @brief Set the index value of a sparse matrix
* @param idx data index
* @param val row (CSC) or column (CSR) for idx-th data entry
*/
void set_indexval(sunindextype idx, sunindextype val);
void set_indexval(sunindextype idx, sunindextype val) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(indexvals_ == SM_INDEXVALS_S(matrix_));
indexvals_[idx] = val;
}

/**
* @brief Set the index values of a sparse matrix
* @param vals rows (CSC) or columns (CSR) for data entries
*/
void set_indexvals(const gsl::span<const sunindextype> vals);
void set_indexvals(const gsl::span<const sunindextype> vals) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(static_cast<sunindextype>(vals.size()) == capacity());
assert(indexvals_ == SM_INDEXVALS_S(matrix_));
std::copy_n(vals.begin(), capacity(), indexvals_);
}

/**
* @brief Get the index pointer of a sparse matrix
* @param ptr_idx pointer index
* @return index where the ptr_idx-th column (CSC) or row (CSR) starts
*/
sunindextype get_indexptr(sunindextype ptr_idx) const;
sunindextype get_indexptr(sunindextype ptr_idx) const {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(ptr_idx <= num_indexptrs());
assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
return indexptrs_[ptr_idx];
}

/**
* @brief Set the index pointer of a sparse matrix
* @param ptr_idx pointer index
* @param ptr data-index where the ptr_idx-th column (CSC) or row (CSR) starts
*/
void set_indexptr(sunindextype ptr_idx, sunindextype ptr);
void set_indexptr(sunindextype ptr_idx, sunindextype ptr) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(ptr_idx <= num_indexptrs());
assert(ptr <= capacity());
assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
indexptrs_[ptr_idx] = ptr;
if (ptr_idx == num_indexptrs())
num_nonzeros_ = ptr;
}

/**
* @brief Set the index pointers of a sparse matrix
* @param ptrs starting data-indices where the columns (CSC) or rows (CSR) start
*/
void set_indexptrs(const gsl::span<const sunindextype> ptrs);
void set_indexptrs(const gsl::span<const sunindextype> ptrs) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(static_cast<sunindextype>(ptrs.size()) == num_indexptrs() + 1);
assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
std::copy_n(ptrs.begin(), num_indexptrs() + 1, indexptrs_);
num_nonzeros_ = indexptrs_[num_indexptrs()];
}

/**
* @brief Get the type of sparse matrix
Expand Down
Loading