Skip to content

Commit

Permalink
further speed up graflex
Browse files Browse the repository at this point in the history
  • Loading branch information
suzannejin committed Sep 14, 2024
1 parent 4383d0b commit e89872e
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 49 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
## propr 5.1.1
# propr 5.1.2
---------------------
* Restructured `graflex` to speed up

# propr 5.1.1
---------------------
* Speed up `graflex` related functions

Expand Down
10 changes: 3 additions & 7 deletions R/3-shared-graflex.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,15 @@ runGraflex <- function(A, K, p=100, ncores=1) {
stop("'A' must be a square matrix.")

if (ncores == 1){

# for each knowledge network, calculate odds ratio and FDR
res <- lapply(1:ncol(K), function(k) {
Gk <- K[, k] %*% t(K[, k]) # converts the k column into an adjacency matrix (genes x genes)
graflex(A, Gk, p=p) # this calls the graflex function implemented in Rcpp C++
graflex(A, K[,k], p=p) # this calls the modified graflex function implemented in Rcpp C++
})

}else{
} else {
packageCheck("parallel")
cl <- parallel::makeCluster(ncores)
res <- parallel::parLapply(cl, 1:ncol(K), function(k) {
Gk <- K[, k] %*% t(K[, k])
graflex(A, Gk, p=p)
graflex(A, K[,k], p=p)
})
parallel::stopCluster(cl)
}
Expand Down
8 changes: 6 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,12 @@ getFDR <- function(actual, permuted) {
.Call(`_propr_getFDR`, actual, permuted)
}

graflex <- function(A, G, p = 100L) {
.Call(`_propr_graflex`, A, G, p)
# Thin R-side wrapper for the native routine `_propr_getG`: forwards `Gk`
# unchanged through .Call() and returns whatever the native routine returns.
# NOTE(review): this lives in R/RcppExports.R, which looks like Rcpp-generated
# glue -- manual edits here would presumably be overwritten on regeneration.
getG <- function(Gk) {
.Call(`_propr_getG`, Gk)
}

# Thin R-side wrapper for the native routine `_propr_graflex`: forwards the
# matrix `A`, the vector `Gk` and the permutation count `p` (default 100L)
# unchanged through .Call() and returns the native routine's result.
# NOTE(review): this lives in R/RcppExports.R, which looks like Rcpp-generated
# glue -- manual edits here would presumably be overwritten on regeneration.
graflex <- function(A, Gk, p = 100L) {
.Call(`_propr_graflex`, A, Gk, p)
}

lr2vlr <- function(lr) {
Expand Down
20 changes: 16 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,16 +367,27 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// getG
// C-callable entry point that exposes getG() to R: takes the incoming R object
// GkSEXP, views it as a `const IntegerVector&`, calls getG() and hands the
// wrapped result back as a SEXP.
// NOTE(review): this is in src/RcppExports.cpp, which looks like Rcpp-generated
// glue -- manual edits here would presumably be overwritten on regeneration.
IntegerMatrix getG(const IntegerVector& Gk);
RcppExport SEXP _propr_getG(SEXP GkSEXP) {
// BEGIN_RCPP / END_RCPP are Rcpp macros bracketing the body (per the Rcpp
// documentation they convert C++ exceptions into R errors -- confirm there).
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
// RNG scope object emitted for every exported function in this file.
Rcpp::RNGScope rcpp_rngScope_gen;
// Convert the raw SEXP argument into the C++ parameter type.
Rcpp::traits::input_parameter< const IntegerVector& >::type Gk(GkSEXP);
// Call the implementation and convert its IntegerMatrix result back to an R object.
rcpp_result_gen = Rcpp::wrap(getG(Gk));
return rcpp_result_gen;
END_RCPP
}
// graflex
NumericVector graflex(const IntegerMatrix& A, const IntegerMatrix& G, int p);
RcppExport SEXP _propr_graflex(SEXP ASEXP, SEXP GSEXP, SEXP pSEXP) {
NumericVector graflex(const IntegerMatrix& A, const IntegerVector& Gk, int p);
RcppExport SEXP _propr_graflex(SEXP ASEXP, SEXP GkSEXP, SEXP pSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const IntegerMatrix& >::type A(ASEXP);
Rcpp::traits::input_parameter< const IntegerMatrix& >::type G(GSEXP);
Rcpp::traits::input_parameter< const IntegerVector& >::type Gk(GkSEXP);
Rcpp::traits::input_parameter< int >::type p(pSEXP);
rcpp_result_gen = Rcpp::wrap(graflex(A, G, p));
rcpp_result_gen = Rcpp::wrap(graflex(A, Gk, p));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -510,6 +521,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_propr_getORperm", (DL_FUNC) &_propr_getORperm, 3},
{"_propr_permuteOR", (DL_FUNC) &_propr_permuteOR, 3},
{"_propr_getFDR", (DL_FUNC) &_propr_getFDR, 2},
{"_propr_getG", (DL_FUNC) &_propr_getG, 1},
{"_propr_graflex", (DL_FUNC) &_propr_graflex, 3},
{"_propr_lr2vlr", (DL_FUNC) &_propr_lr2vlr, 1},
{"_propr_lr2phi", (DL_FUNC) &_propr_lr2phi, 1},
Expand Down
83 changes: 48 additions & 35 deletions src/graflex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,49 @@
#include <numeric>
using namespace Rcpp;

// Build the 2x2 contingency table between the adjacency matrix A and the
// knowledge matrix G, then derive the odds ratio from it.
//
// Only the strict lower triangle (i > j) is visited, so every unordered pair
// is counted once and the diagonal is ignored. The table cells are:
//   a: pair absent from A and absent from G
//   b: pair absent from A but present in G
//   c: pair present in A but absent from G
//   d: pair present in both A and G
//
// Returns a length-8 vector: a, b, c, d, odds ratio, log odds ratio, NaN, NaN.
// The two trailing NaN slots are placeholders (presumably filled in by the
// caller -- confirm against graflex()).
//
// NOTE(review): the arithmetic counting below assumes every entry of A and G
// is exactly 0 or 1; any other value would be miscounted -- confirm callers
// guarantee this. A is also assumed square (A.ncol() bounds both indices).
// [[Rcpp::export]]
NumericVector getOR(const IntegerMatrix& A, const IntegerMatrix& G) {
  int ncol = A.ncol();
  int a = 0, b = 0, c = 0, d = 0;

  for (int j = 0; j < ncol - 1; ++j) {
    for (int i = j + 1; i < ncol; ++i) {
      int a_val = A(i, j), g_val = G(i, j);
      a += (1 - a_val) * (1 - g_val);
      b += (1 - a_val) * g_val;
      c += a_val * (1 - g_val);
      d += a_val * g_val;
    }
  }

  // Multiply in double precision: the int products a * d and b * c overflow
  // (undefined behaviour) as soon as the counts reach a few tens of thousands,
  // which already happens for matrices with a few hundred columns.
  double odds_ratio = (static_cast<double>(a) * d) / (static_cast<double>(b) * c);
  double log_odds_ratio = std::log(odds_ratio);

  return NumericVector::create(a, b, c, d, odds_ratio, log_odds_ratio, R_NaN, R_NaN);
}

// Same as getOR(), but A is read through a permutation: the pair (i, j) of G
// is compared against A(perm[i], perm[j]), i.e. the labels of A are shuffled
// while G stays fixed.
//
// Only the strict lower triangle (i > j) is visited. The table cells are:
//   a: pair absent from permuted A and absent from G
//   b: pair absent from permuted A but present in G
//   c: pair present in permuted A but absent from G
//   d: pair present in both
//
// Returns a length-8 vector: a, b, c, d, odds ratio, log odds ratio, NaN, NaN.
//
// NOTE(review): assumes perm holds valid 0-based indices into A and that all
// entries of A and G are exactly 0 or 1 -- confirm against the callers.
// [[Rcpp::export]]
NumericVector getORperm(const IntegerMatrix& A, const IntegerMatrix& G, const IntegerVector& perm) {
  int ncol = A.ncol();
  int a = 0, b = 0, c = 0, d = 0;

  for (int j = 0; j < ncol - 1; ++j) {
    int pj = perm[j];  // permuted column index, constant across the inner loop
    for (int i = j + 1; i < ncol; ++i) {
      int a_val = A(perm[i], pj), g_val = G(i, j);
      a += (1 - a_val) * (1 - g_val);
      b += (1 - a_val) * g_val;
      c += a_val * (1 - g_val);
      d += a_val * g_val;
    }
  }

  // Multiply in double precision: the int products a * d and b * c overflow
  // (undefined behaviour) as soon as the counts reach a few tens of thousands,
  // which already happens for matrices with a few hundred columns.
  double odds_ratio = (static_cast<double>(a) * d) / (static_cast<double>(b) * c);
  double log_odds_ratio = std::log(odds_ratio);

  return NumericVector::create(a, b, c, d, odds_ratio, log_odds_ratio, R_NaN, R_NaN);
}

// Function to calculate the odds ratio and other relevant info for each permutation
Expand Down Expand Up @@ -99,9 +90,31 @@ List getFDR(double actual, const NumericVector& permuted) {
);
}

// Expand the vector Gk into the square matrix G with G(r, c) = Gk[r] * Gk[c]
// (the outer product of Gk with itself). The result is symmetric, so each
// product is computed once and written to both mirrored cells.
// [[Rcpp::export]]
IntegerMatrix getG(const IntegerVector& Gk) {
  const int size = Gk.size();
  IntegerMatrix G(size, size);

  for (int col = 0; col < size; ++col) {
    const int g_col = Gk[col];
    // Start at the diagonal (row == col) and walk down the column.
    for (int row = col; row < size; ++row) {
      const int product = Gk[row] * g_col;
      G(row, col) = product;
      G(col, row) = product;
    }
  }

  return G;
}

// Function to calculate the odds ratio and FDR, given the adjacency matrix A and the knowledge graph G
// [[Rcpp::export]]
NumericVector graflex(const IntegerMatrix& A, const IntegerMatrix& G, int p = 100) {
NumericVector graflex(const IntegerMatrix& A, const IntegerVector& Gk, int p = 100) {

// Calculate Gk
IntegerMatrix G = getG(Gk);

// get the actual odds ratio
NumericVector actual = getOR(A, G);
Expand Down

0 comments on commit e89872e

Please sign in to comment.