Skip to content

Commit

Permalink
Merge 67e0db8 into 255ac68
Browse files Browse the repository at this point in the history
  • Loading branch information
RiboRings authored Jan 23, 2025
2 parents 255ac68 + 67e0db8 commit ad5bd14
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 64 deletions.
148 changes: 97 additions & 51 deletions R/mediate.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
#' (Default: \code{"holm"})
#'
#' @param add.metadata \code{Logical scalar}. Should the model metadata be
#' returned. (Default: \code{FALSE})
#' returned. (Default: \code{TRUE})
#'
#' @param sort \code{Logical scalar}. Should the results be sorted by decreasing
#' significance in terms of ACME_pval. (Default: \code{FALSE})
#'
#' @param verbose \code{Logical scalar}. Should execution messages be printed.
#' (Default: \code{TRUE})
Expand All @@ -61,22 +64,35 @@
#' \code{addMediation} returns an updated
#' \code{\link[SummarizedExperiment:SummarizedExperiment-class]{SummarizedExperiment}}
#' instance with the same \code{data.frame} stored in the metadata as
#' "mediation". Its columns include:
#' "mediation" or as specified in the \code{name} argument. Its columns include:
#'
#' \describe{
#' \item{Mediator}{the mediator variable}
#' \item{ACME_estimate}{the Average Causal Mediation Effect (ACME) estimate}
#' \item{ADE_estimate}{the Average Direct Effect (ADE) estimate}
#' \item{Total_estimate}{the Total Effect estimate}
#' \item{ACME_pval}{the adjusted p-value for the ACME estimate}
#' \item{ADE_pval}{the adjusted p-value for the ADE estimate}
#' \item{Total_pval}{the adjusted p-value for the Total Effect estimate}
#' \item{ACME_CI_lower}{the 2.5% CI for the ACME estimate}
#' \item{ACME_CI_upper}{the 2.5% CI for the ACME estimate}
#' \item{ADE_CI_lower}{the 2.5% CI for the ADE estimate}
#' \item{ADE_CI_upper}{the 97.5% CI for the ADE estimate}
#' \item{Total_CI_lower}{the 2.5 CI for the Total Effect estimate}
#' \item{Total_CI_upper}{the 97.5 CI for the Total Effect estimate}
#' }
#'
#' The original output of \code{\link[mediation:mediate]{mediate}} for each
#' analysed mediator is stored as the "model_metadata" attribute of the
#' resulting \code{data.frame} and can be accessed via the \code{attr} function.
#'
#' @name getMediation
#'
#' @examples
#' \dontrun{
#' # Import libraries
#' library(mia)
#' library(miaViz)
#' library(scater)
#'
#' # Load dataset
Expand All @@ -97,11 +113,10 @@
#' covariates = c("sex", "age"),
#' treat.value = "Scandinavia",
#' control.value = "CentralEurope",
#' boot = TRUE, sims = 100,
#' add.metadata = TRUE)
#' boot = TRUE, sims = 100)
#'
#' # Visualise model statistics for 1st mediator
#' plot(attr(med_df, "metadata")[[1]])
#' plotMediation(med_df)
#'
#' # Apply clr transformation to counts assay
#' tse <- transformAssay(tse,
Expand Down Expand Up @@ -143,6 +158,9 @@
#'
#' # Show results for first 5 mediators
#' head(metadata(tse)$reddim_mediation, 5)
#'
#' # Access model metadata
#' attr(metadata(tse)$reddim_mediation, "model_metadata")
#' }
#'
NULL
Expand All @@ -154,7 +172,7 @@ setMethod("addMediation", signature = c(x = "SummarizedExperiment"),
function(x, outcome, treatment, name = "mediation",
mediator = NULL, assay.type = NULL, dimred = NULL,
family = gaussian(), covariates = NULL, p.adj.method = "holm",
add.metadata = FALSE, verbose = TRUE, ...) {
add.metadata = TRUE, verbose = TRUE, ...) {

med_df <- getMediation(
x, outcome, treatment,
Expand All @@ -176,7 +194,7 @@ setMethod("getMediation", signature = c(x = "SummarizedExperiment"),
function(x, outcome, treatment,
mediator = NULL, assay.type = NULL, dimred = NULL,
family = gaussian(), covariates = NULL, p.adj.method = "holm",
add.metadata = FALSE, verbose = TRUE, ...) {
add.metadata = TRUE, sort = FALSE, verbose = TRUE, ...) {
###################### Input check ########################
if( !outcome %in% names(colData(x)) ){
stop(outcome, " not found in colData(x).", call. = FALSE)
Expand All @@ -195,7 +213,7 @@ setMethod("getMediation", signature = c(x = "SummarizedExperiment"),
}

# Check that arguments can be passed to mediate and filter samples
x <- .check.mediate.args(
x <- .check_mediate_args(
x, outcome, treatment, mediator, covariates, verbose, ...
)

Expand Down Expand Up @@ -242,14 +260,11 @@ setMethod("getMediation", signature = c(x = "SummarizedExperiment"),
mediators <- rownames(mat)
}

# Create template list of results
results <- list(
Mediator = c(), ACME_estimate = c(), ADE_estimate = c(),
ACME_pval = c(), ADE_pval = c(), Model = list()
)

# Create template list of models
models <- list()
# Set initial index
i <- 0

for( mediator in mediators ){
# Update index
i <- i + 1
Expand All @@ -264,18 +279,22 @@ setMethod("getMediation", signature = c(x = "SummarizedExperiment"),
family = family, mat = mat,
covariates = covariates, ...
)
# Update list of results
results <- .update.results(results, med_out, mediator)
# Update list of models
models <- c(models, list(med_out))
}

# Name models by mediators
names(models) <- mediators
# Combine results into dataframe
med_df <- .make.output(results, p.adj.method, add.metadata)
med_df <- .make_output(models, p.adj.method, add.metadata, sort)

return(med_df)
}
)

# Check that arguments can be passed to mediate and remove unused samples
#' @importFrom stats na.omit
.check.mediate.args <- function(
.check_mediate_args <- function(
x, outcome, treatment, mediator, covariates, verbose = TRUE, ...) {

# Create dataframe from selected columns of colData
Expand Down Expand Up @@ -344,6 +363,7 @@ setMethod("getMediation", signature = c(x = "SummarizedExperiment"),
#' @importFrom stats lm formula glm
.run_mediate <- function(x, outcome, treatment, mediator = NULL, mat = NULL,
family = gaussian(), covariates = NULL, ...) {

# Create initial dataframe with outcome and treatment variables
df <- data.frame(
Outcome = colData(x)[[outcome]], Treatment = colData(x)[[treatment]])
Expand Down Expand Up @@ -389,48 +409,74 @@ setMethod("getMediation", signature = c(x = "SummarizedExperiment"),
treat = "Treatment", mediator = "Mediator",
covariates = covariates, ...
)
return(med_out)
}


# Update list of results
.update.results <- function(results, med_out, mediator) {
# Update model variables
results[["Mediator"]] <- c(results[["Mediator"]], mediator)
# Update stats of ACME (average causal mediation effect)
results[["ACME_estimate"]] <- c(results[["ACME_estimate"]], med_out$d.avg)
results[["ACME_pval"]] <- c(results[["ACME_pval"]], med_out$d.avg.p)
# Update stats of ADE (average direct effect)
results[["ADE_estimate"]] <- c(results[["ADE_estimate"]], med_out$z.avg)
results[["ADE_pval"]] <- c(results[["ADE_pval"]], med_out$z.avg.p)
# Add current model to metadata
results[["Model"]][[length(results[["Model"]]) + 1]] <- med_out
return(results)
return(med_out)
}

# Combine results into output dataframe
.make.output <- function(results, p.adj.method, add.metadata) {

# Create dataframe with model variables, effect sizes and p-values
med_df <- do.call(data.frame, results[seq_len(length(results) - 1)])
.make_output <- function(models, p.adj.method, add.metadata, sort) {
# Create empty data.frame to store model metadata
med_df <- data.frame(matrix(nrow = length(models), ncol = 56))
# Use mediators and model properties as row and column names, respectively
rownames(med_df) <- names(models)
colnames(med_df) <- names(models[[1]])
# Iterate over mediators or rows
for( row in seq_len(length(models)) ){
# Iterate over model metadata or columns
for( col in colnames(med_df) ){
entry <- models[[row]][[col]]
# If entry is null, replace with empty string
if( is.null(entry) ){
entry <- ""
}
# If entry is not a scalar, convert to list
if( length(entry) > 1){
entry <- list(entry)
}
# Store entry into data.frame
med_df[[row, col]] <- I(entry)
}
}

# Convert rownames to column
med_df["Mediator"] <- rownames(med_df)
rownames(med_df) <- NULL
# Select columns to keep
med_df <- med_df[ , c("Mediator", "d.avg", "z.avg", "tau.coef", "d.avg.p",
"z.avg.p", "tau.p", "d.avg.ci", "z.avg.ci", "tau.ci")]
# Rename columns
colnames(med_df) <- c("Mediator", "ACME_estimate", "ADE_estimate",
"Total_estimate", "ACME_pval", "ADE_pval", "Total_pval", "ACME_CI",
"ADE_CI", "Total_CI")

# Compute adjusted p-values and add them to dataframe
med_df[["ACME_pval"]] <- p.adjust(
med_df[["ACME_pval"]],
method = p.adj.method
)
med_df[["ADE_pval"]] <- p.adjust(
med_df[["ADE_pval"]],
method = p.adj.method
pval_cols <- endsWith(colnames(med_df), "pval")
med_df[ , pval_cols] <- apply(
med_df[ , pval_cols], MARGIN = 2,
p.adjust, method = p.adj.method
)

if( add.metadata ){
# models for every mediator are saved into the metadata attribute
attr(med_df, "metadata") <- results[["Model"]]
# Find CI columns
ci_cols <- colnames(med_df)[endsWith(colnames(med_df), "CI")]
for( col in ci_cols ){
# Retrieve CI columns
ci_list <- unlist(med_df[ , col])
upper_cond <- seq_len(length(ci_list)) %% 2 == 0
names(ci_list) <- NULL
# Split lower and upper CIs
med_df[ , paste(col, "lower", sep = "_")] <- ci_list[!upper_cond]
med_df[ , paste(col, "upper", sep = "_")] <- ci_list[upper_cond]
med_df[ , col] <- NULL
}

# Order output dataframe by ACME p-values
med_df <- med_df[order(med_df[["ACME_pval"]]), ]
if( add.metadata ){
# Store model for each mediator into the model_metadata attribute
attr(med_df, "model_metadata") <- models
}
if( sort ){
# Order output dataframe by ACME p-values
med_df <- med_df[order(med_df[["ACME_pval"]]), ]
}

return(med_df)
}
}
33 changes: 26 additions & 7 deletions man/getMediation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions tests/testthat/test-mediate.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ test_that("getMediation", {

expect_error(
getMediation(tse, outcome = "bmi_group", treatment = "nationality", assay.type = "wrong_name"),
"wrong_name not found in assays(x).", fixed = TRUE
"'assay.type' must be a valid name of assays(x)", fixed = TRUE
)

expect_error(
Expand Down Expand Up @@ -69,10 +69,10 @@ test_that("getMediation", {
treat.value = "Scandinavia", control.value = "CentralEurope",
boot = TRUE, sims = 1)

expect_equal(attr(med_df, "metadata")[[1]]$d.avg, med_out$d.avg)
expect_equal(attr(med_df, "metadata")[[1]]$d.avg.p, med_out$d.avg.p)
expect_equal(attr(med_df, "metadata")[[1]]$z.avg, med_out$z.avg)
expect_equal(attr(med_df, "metadata")[[1]]$z.avg.p, med_out$z.avg.p)
expect_equal(attr(med_df, "model_metadata")[[1]][["d.avg"]], med_out$d.avg)
expect_equal(attr(med_df, "model_metadata")[[1]][["d.avg.p"]], med_out$d.avg.p)
expect_equal(attr(med_df, "model_metadata")[[1]][["z.avg"]], med_out$z.avg)
expect_equal(attr(med_df, "model_metadata")[[1]][["z.avg.p"]], med_out$z.avg.p)

### Batch 3: check output format and dimensionality with respect to SE ###
med_df <- getMediation(tse, outcome = "bmi_group", treatment = "nationality", assay.type = "counts",
Expand All @@ -81,6 +81,9 @@ test_that("getMediation", {

expect_named(tse, med_df[["Mediator"]])

expect_named(med_df, c("Mediator", "ACME_estimate", "ADE_estimate", "ACME_pval", "ADE_pval"))
expect_named(med_df, c("Mediator", "ACME_estimate", "ADE_estimate",
"Total_estimate", "ACME_pval", "ADE_pval", "Total_pval", "ACME_CI_lower",
"ACME_CI_upper", "ADE_CI_lower", "ADE_CI_upper", "Total_CI_lower",
"Total_CI_upper"))

})

0 comments on commit ad5bd14

Please sign in to comment.