Commit eca8bce

- Features used can be extracted from a penalised Cox model.
- Final model can be used with the predict function.
- prepareData can filter correlated features.
- easyHard helps to associate variables with per-sample prediction accuracy.
Dario Strbenac committed Dec 21, 2024
1 parent 39dd7d4 commit eca8bce
Showing 12 changed files with 193 additions and 49 deletions.
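
In brief, the features listed above can be exercised as in this minimal, hypothetical R sketch (object names such as measurements, outcome and newMeasurements are assumptions, and the exact correlation-filtering argument of prepareData is not shown in this excerpt):

prepareData(measurements, outcome)               # New: can also filter highly correlated features.
result <- crossValidate(measurements, outcome)   # Cross-validated modelling, as before.
easyHard(measurements, result)                   # New: covariates associated with per-sample accuracy.
predict(finalModel(result), newMeasurements)     # New: final model reused on unseen samples.
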
8 changes: 4 additions & 4 deletions DESCRIPTION
@@ -3,8 +3,8 @@ Type: Package
Title: A framework for cross-validated classification problems, with
applications to differential variability and differential
distribution testing
Version: 3.10.4
Date: 2024-12-13
Version: 3.10.5
Date: 2024-12-20
Authors@R:
c(
person(given = "Dario", family = "Strbenac", email = "[email protected]", role = c("aut", "cre")),
@@ -20,9 +20,9 @@ VignetteBuilder: knitr
Encoding: UTF-8
biocViews: Classification, Survival
Depends: R (>= 4.1.0), generics, methods, S4Vectors, MultiAssayExperiment, BiocParallel, survival
Imports: grid, genefilter, utils, dplyr, tidyr, rlang, ranger, ggplot2 (>= 3.0.0), ggpubr, reshape2, ggupset
Imports: grid, genefilter, utils, dplyr, tidyr, rlang, ranger, ggplot2 (>= 3.0.0), ggpubr, reshape2, ggupset, broom, dcanr
Suggests: limma, edgeR, car, Rmixmod, gridExtra (>= 2.0.0), cowplot,
BiocStyle, pamr, PoiClaClu, parathyroidSE, knitr, htmltools, gtable,
BiocStyle, pamr, PoiClaClu, knitr, htmltools, gtable,
scales, e1071, rmarkdown, IRanges, robustbase, glmnet, class, randomForestSRC,
MatrixModels, xgboost, data.tree, ggnewscale
Description: The software formalises a framework for classification and survival model evaluation
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -33,6 +33,7 @@ export(crissCrossPlot)
export(crissCrossValidate)
export(crossValidate)
export(distribution)
export(easyHard)
export(edgesToHubNetworks)
export(featureSetSummary)
export(finalModel)
@@ -80,6 +81,7 @@ exportMethods(calcExternalPerformance)
exportMethods(chosenFeatureNames)
exportMethods(crossValidate)
exportMethods(distribution)
exportMethods(easyHard)
exportMethods(featureSetSummary)
exportMethods(finalModel)
exportMethods(interactorDifferences)
@@ -111,6 +113,8 @@ import(reshape2)
importFrom(S4Vectors,as.data.frame)
importFrom(S4Vectors,do.call)
importFrom(S4Vectors,mcols)
importFrom(broom,tidy)
importFrom(dcanr,cor.pairs)
importFrom(dplyr,mutate)
importFrom(dplyr,n)
importFrom(generics,train)
96 changes: 94 additions & 2 deletions R/calcPerformance.R
@@ -32,7 +32,8 @@
#' @param result An object of class \code{\link{ClassifyResult}}.
#' @param resultsList A list of modelling results. Each element must be of type \code{\link{ClassifyResult}}.
#' @param performanceTypes Default: \code{"auto"} A character vector. If \code{"auto"}, Balanced Accuracy will be used
#' for a classification task and C-index for a time-to-event task.
#' for a classification task and C-index for a time-to-event task. If using \code{easyHard}, the default is
#' \code{"Sample Accuracy"} for a classification task and \code{"Sample C-index"} for a time-to-event task.
#' Must be one of the following options:
#' \itemize{
#' \item{\code{"Error"}: Ordinary error rate.}
@@ -130,6 +131,11 @@ setMethod("calcExternalPerformance", c("factor", "tabular"), # table has class p
)
})

calcCVperformance <- function(results, ...)
{
lapply(results, calcCVperformance, ...) # Apply to each ClassifyResult in the list, forwarding options such as performanceType.
}
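# Usage sketch (hypothetical result objects): the list convenience wrapper enables
# calcCVperformance(list(resultClassify, resultSurvival), "Sample Accuracy")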

#' @rdname calcPerformance
#' @usage NULL
#' @export
@@ -427,4 +433,90 @@ performanceTable <- function(resultsList, performanceTypes = "auto", aggregate =
})
DataFrame(tidyr::pivot_wider(as.data.frame(result@characteristics), names_from = characteristic, values_from = value), performances, check.names = FALSE)
}))
}
}

#' @rdname calcPerformance
#' @usage NULL
#' @export
setGeneric("easyHard", function(measurements, result, assay, performanceType, ...)
standardGeneric("easyHard"))

#' @rdname calcPerformance
#' @exportMethod easyHard
#' @importFrom broom tidy
#' @param assay For \code{easyHard} only. The assay in which to search for associations with the per-sample metric.
#' @param useFeatures For \code{easyHard} only. Default: \code{NULL} (i.e. use all provided features). A vector of features of the specified assay to consider.
#' This allows variables such as spike-in RNAs, sample IDs and sample acquisition dates, which are not relevant to outcome prediction, to be excluded.
#' @param fitMode For \code{easyHard} only. Default: \code{"single"}. Either \code{"single"} or \code{"full"}. If \code{"single"},
#' an ordinary GLM is fitted to each covariate separately. If \code{"full"}, elastic net regression is used to automatically select the covariates with non-zero coefficients.
#' @return For \code{easyHard}, a \code{\link{DataFrame}} summarising the fitted logistic regression model.

setMethod("easyHard", "MultiAssayExperimentOrList",
function(measurements, result, assay = "clinical", useFeatures = NULL, performanceType = "auto",
fitMode = c("single", "full"))
{
if(!requireNamespace("glmnet", quietly = TRUE))
stop("The package 'glmnet' could not be found. Please install it.")

if(!assay %in% names(measurements)) stop("'assay' is not one of the names of 'measurements'.")
fitMode <- match.arg(fitMode)

if(is(measurements, "MultiAssayExperiment"))
{
if(assay == "clinical")
assay <- colData(measurements)
else assay <- t(measurements[, , assay]) # Ensure that features are in columns.
} else {assay <- measurements[[assay]]}
if(!is.null(useFeatures)) assay <- assay[, useFeatures]
if(performanceType == "auto")
{
if("risk" %in% colnames(predictions(result)))
{
performanceType <- "Sample C-index"
} else {performanceType <- "Sample Accuracy"}
}
if(!performanceType %in% names(performance(result)))
{
warning(paste(performanceType, "not found in result. Calculating it now."))
result <- calcCVperformance(result, performanceType)
}
samplePerformance <- performance(result)[[performanceType]]
if(any(is.na(samplePerformance)))
{
keep <- !is.na(samplePerformance)
assay <- assay[keep, ]
samplePerformance <- samplePerformance[keep]
}
assay <- assay[names(samplePerformance), ] # Guarantee that row order matches the per-sample performance values.
assayOHE <- MatrixModels::model.Matrix(~ 0 + ., data = assay)

if(fitMode == "single")
{
as(do.call(rbind, lapply(colnames(assay), function(featureID)
{
covariate <- assay[, featureID]
fitted <- glm(samplePerformance ~ covariate, family = binomial, weights = rep(100, length(samplePerformance)))
summaryDF <- broom::tidy(fitted)
if(is.factor(covariate))
{
summaryDF[2:nrow(summaryDF), "term"] <- paste(featureID, levels(covariate)[2:length(levels(covariate))], sep = ": ")

} else {summaryDF[, "term"] <- featureID}
summaryDF[2:nrow(summaryDF), ]
})), "DataFrame")
} else { # Penalised regression.
samplePerformanceM <- matrix(c(1 - samplePerformance, samplePerformance), ncol = 2)
fitted <- glmnet::glmnet(assayOHE, samplePerformanceM, family = "binomial")
lambdaConsider <- colSums(as.matrix(fitted[["beta"]])) != 0
bestLambda <- fitted[["lambda"]][lambdaConsider][which.min(sapply(fitted[["lambda"]][lambdaConsider], function(lambda) # Largest lambda with the minimum total absolute prediction error.
{
predictions <- predict(fitted, assayOHE, s = lambda, type = "response")
sum(abs(samplePerformanceM[, 2] - predictions))
}))[1]]
useVariables <- abs(fitted[["beta"]][, fitted[["lambda"]] == bestLambda]) > 0.00001
useVariables <- colnames(assay)[unique(assayOHE@assign[useVariables])]
dataForModel <- data.frame(assay[, useVariables, drop = FALSE], performance = samplePerformanceM[, 2]) # Refit an ordinary GLM using only the variables retained by elastic net.
fitted <- glm(performance ~ . + 0, data = dataForModel, family = binomial(), weights = rep(100, nrow(dataForModel)))
broom::tidy(fitted)
}
})
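
A hedged sketch of calling the new method (measurements and result are assumed to come from earlier prepareData and crossValidate steps):

easyHard(measurements, result, assay = "clinical", fitMode = "single")   # One GLM per covariate.
easyHard(measurements, result, assay = "clinical", fitMode = "full")     # One elastic net over all covariates.
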
6 changes: 3 additions & 3 deletions R/crissCrossValidate.R
@@ -190,7 +190,7 @@ crissCrossPlot <- function(crissCrossResult, includeValues = FALSE){

ggheatmap <- ggplot(melted_cormat, aes(Var1, Var2, fill = value)) +
geom_tile(color = "white") +
scale_fill_gradient2(high = "red", mid = "white", low = "blue",
scale_fill_gradient2(high = "#e25563ff", mid = "white", low = "#094bacff",
midpoint = 0.5, limit = c(0,1), space = "Lab",
name=as.character(scalebar_title)) +
theme_bw() + xlab("Training Dataset") + ylab("Testing Dataset") +
@@ -205,7 +205,7 @@ crissCrossPlot <- function(crissCrossResult, includeValues = FALSE){
melted_cormat_1 <- melt(real, na.rm = TRUE)
ggheatmap_1 <- ggplot(melted_cormat_1, aes(Var1, Var2, fill = value)) +
geom_tile(color = "white") +
scale_fill_gradient2(high = "red", mid = "white", low = "blue",
scale_fill_gradient2(high = "#e25563ff", mid = "white", low = "#094bacff",
midpoint = 0.5, limit = c(0,1), space = "Lab",
name=as.character(scalebar_title)) +
theme_bw() + xlab("Features Extracted") + ylab("Dataset Tested") +
@@ -218,7 +218,7 @@ crissCrossPlot <- function(crissCrossResult, includeValues = FALSE){
melted_cormat_2 <- melt(random, na.rm = TRUE)
ggheatmap_2 <- ggplot(melted_cormat_2, aes(Var1, Var2, fill = value)) +
geom_tile(color = "white") +
scale_fill_gradient2(high = "red", mid = "white", low = "blue",
scale_fill_gradient2(high = "#e25563ff", mid = "white", low = "#094bacff",
midpoint = 0.5, limit = c(0,1), space = "Lab",
name=as.character(scalebar_title)) +
theme_bw() + xlab("Features Extracted") + ylab("Dataset Tested") +
15 changes: 10 additions & 5 deletions R/crossValidate.R
@@ -747,6 +747,8 @@ CV <- function(measurements, outcome, x, outcomeTrain, measurementsTest, outcome
if(is.character(classifyResults)) stop(classifyResults)
fullResult <- runTest(measurements, outcome, measurements, outcome, crossValParams = crossValParams, modellingParams = modellingParams, characteristics = characteristics, .iteration = 1)
classifyResults@finalModel <- fullResult$models
class(classifyResults@finalModel) <- c("trainedByClassifyR", class(classifyResults@finalModel))
attr(classifyResults@finalModel, "predictFunction") <- modellingParams@trainParams@classifier
}

classifyResults
@@ -797,7 +799,7 @@ train.DataFrame <- function(x, outcomeTrain, selectionMethod = "auto", nFeatures
if(isTuneCross && !extraParams[["tuneCross"]][["performanceType"]] %in% c("auto", .ClassifyRenvir[["performanceTypes"]]))
stop(paste("performanceType for tuning must be one of", paste(c("auto", .ClassifyRenvir[["performanceTypes"]]), collapse = ", "), "but is", extraParams[["tuneCross"]][["performanceType"]]))

isCategorical <- is.character(outcome) && (length(outcome) == 1 || length(outcome) == nrow(measurements)) || is.factor(outcome)
isCategorical <- is.character(outcomeTrain) && (length(outcomeTrain) == 1 || length(outcomeTrain) == nrow(measurements)) || is.factor(outcomeTrain)
if(isTuneCross && extraParams[["tuneCross"]][["performanceType"]] == "auto")
if(isCategorical) extraParams[["tuneCross"]][["performanceType"]] <- "Balanced Accuracy" else extraParams[["tuneCross"]][["performanceType"]] <- "C-index"
if(length(selectionMethod) == 1 && selectionMethod == "auto")
@@ -838,7 +840,7 @@ train.DataFrame <- function(x, outcomeTrain, selectionMethod = "auto", nFeatures
measurementsUse <- measurements
}

classifierParams <- .classifierKeywordToParams(keyword = classifierForAssay)
classifierParams <- .classifierKeywordToParams(classifierForAssay, extraParams[["train"]][["tuneParams"]])
if(!is.null(extraParams) && "train" %in% names(extraParams))
{
for(paramIndex in seq_along(extraParams[["train"]]))
@@ -922,6 +924,7 @@ train.DataFrame <- function(x, outcomeTrain, selectionMethod = "auto", nFeatures
if(multiViewMethod == "merge"){
measurementsUse <- measurements[, S4Vectors::mcols(measurements)[["assay"]] %in% assayIDs, drop = FALSE]
model <- .doTrain(measurementsUse, outcomeTrain, NULL, NULL, crossValParams, modellingParams, verbose = verbose)[["model"]]
attr(model, "predictFunction") <- modellingParams@trainParams@classifier
class(model) <- c("trainedByClassifyR", class(model))
}

@@ -948,6 +951,7 @@ train.DataFrame <- function(x, outcomeTrain, selectionMethod = "auto", nFeatures
getFeatures = prevalFeatures),
predictParams = PredictParams(prevalPredictInterface, characteristics = paramsAssays$clinical@predictParams@characteristics))
model <- .doTrain(measurementsUse, outcomeTrain, NULL, NULL, crossValParams, modellingParams, verbose = verbose)[["model"]]
attr(model, "predictFunction") <- modellingParams@trainParams@classifier
class(model) <- c("trainedByClassifyR", class(model))
}

@@ -965,6 +969,7 @@ train.DataFrame <- function(x, outcomeTrain, selectionMethod = "auto", nFeatures
getFeatures = PCAfeatures),
predictParams = PredictParams(pcaPredictInterface, characteristics = paramsClinical$clinical@predictParams@characteristics))
model <- .doTrain(measurementsUse, outcomeTrain, NULL, NULL, crossValParams, modellingParams, verbose = verbose)[["model"]]
attr(model, "predictFunction") <- modellingParams@trainParams@classifier
class(model) <- c("trainedByClassifyR", class(model))
}
if(missing(models) || is.null(models)) return(model) else return(models)
@@ -1050,12 +1055,12 @@ predict.trainedByClassifyR <- function(object, newData, outcome, ...)
}, newData, names(newData))
newData <- do.call(cbind, newData)
} else if(is(newData, "MultiAssayExperiment"))
{
newData <- prepareData(newData, outcome)
{
newData <- prepareData(newData, outcome)[["measurements"]]
}

predictFunctionUse <- attr(object, "predictFunction")
class(object) <- rev(class(object)) # Now want the predict method of the specific model to be picked, so put model class first.
class(object) <- class(object)[-1] # Remove the trainedByClassifyR class so that the predict method of the specific model class is dispatched.
if (is(object, "listOfModels"))
mapply(function(model, assay) predictFunctionUse(model, assay), object, newData, MoreArgs = list(...), SIMPLIFY = FALSE)
else do.call(predictFunctionUse, list(object, newData, ...)) # Object is itself a trained model and it is assumed that a predict method is defined for it.
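
A hedged usage sketch for the predict method above (object names are assumptions; the predictFunction attribute stored at training time performs the actual prediction):

trainedModel <- finalModel(result)   # Final model stored on the result by crossValidate.
predict(trainedModel, testData)      # testData may be a data frame, a list of data frames or a MultiAssayExperiment.
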
4 changes: 3 additions & 1 deletion R/interfaceCoxnet.R
@@ -8,14 +8,16 @@ coxnetTrainInterface <- function(measurementsTrain, survivalTrain, lambda = NULL
message(Sys.time(), ": Fitting coxnet model to data.")

measurementsTrain <- data.frame(measurementsTrain, check.names = FALSE)
measurementsMatrix <- glmnet::makeX(as(measurementsTrain, "data.frame"))
measurementsMatrix <- MatrixModels::model.Matrix(~ 0 + ., data = measurementsTrain)

# The response variable is an object of class Surv.
fit <- glmnet::cv.glmnet(measurementsMatrix, survivalTrain, family = "cox", type = "C", ...)
fitted <- fit$glmnet.fit

offset <- -mean(predict(fitted, measurementsMatrix, s = fit$lambda.min, type = "link"))
attr(fitted, "tune") <- list(lambda = fit$lambda.min, offset = offset)
attr(fitted, "featureNames") <- colnames(measurementsMatrix)
attr(fitted, "featureGroups") <- measurementsMatrix@assign

class(fitted) <- class(fitted)[1] # Get rid of glmnet which messes with dispatch.
fitted
16 changes: 11 additions & 5 deletions R/interfacePenalisedGLM.R
@@ -32,9 +32,10 @@ attr(penalisedGLMtrainInterface, "name") <- "penalisedGLMtrainInterface"

# model is of class multnet
penalisedGLMpredictInterface <- function(model, measurementsTest, lambda, ..., returnType = c("both", "class", "score"), verbose = 3)
{ # ... just consumes emitted tuning variables from .doTrain which are unused.
{
# ... just consumes emitted tuning variables from .doTrain which are unused.
returnType <- match.arg(returnType)
# One-hot encoding needed.
# One-hot encoding needed.
measurementsTest <- MatrixModels::model.Matrix(~ 0 + ., data = measurementsTest)

# Ensure that testing data has same columns names in same order as training data.
@@ -51,7 +52,7 @@ penalisedGLMpredictInterface <- function(model, measurementsTest, lambda, ..., r


measurementsTest <- measurementsTest[, rownames(model[["beta"]][[1]])]

classPredictions <- factor(as.character(predict(model, measurementsTest, s = lambda, type = "class")), levels = model[["classnames"]])
classScores <- predict(model, measurementsTest, s = lambda, type = "response")[, , 1]

Expand All @@ -78,8 +79,13 @@ penalisedFeatures <- function(model)
{
# Floating point numbers cannot be reliably tested for exact equality, so use a small tolerance.
whichCoefficientColumn <- which(abs(model[["lambda"]] - attr(model, "tune")[["lambda"]]) < 0.00001)[1]
coefficientsUsed <- sapply(model[["beta"]], function(classCoefficients) classCoefficients[, whichCoefficientColumn])
featureScores <- rowSums(abs(coefficientsUsed))
if(is.list(model[["beta"]])) # Categorical data
{
coefficientsUsed <- sapply(model[["beta"]], function(classCoefficients) classCoefficients[, whichCoefficientColumn])
featureScores <- rowSums(abs(coefficientsUsed))
} else { # survival data
featureScores <- abs(model[["beta"]][, whichCoefficientColumn]) # Coefficient magnitudes, so that strongly negative Cox coefficients rank correctly.
}
featureGroups <- attr(model, "featureGroups")[match(names(featureScores), attr(model, "featureNames"))]
groupScores <- unname(by(featureScores, featureGroups, max))
rankedFeaturesIndices <- order(groupScores, decreasing = TRUE)
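# Worked toy sketch of the group scoring above (values are assumptions): dummy columns
# created from one factor share a group and the group takes its best column's score.
# featureScores <- c(1.2, 0.9, 0.4); featureGroups <- c(1, 1, 2)
# unname(by(featureScores, featureGroups, max))   # Group scores: 1.2 and 0.4.
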
2 changes: 1 addition & 1 deletion R/precisionPathways.R
@@ -345,7 +345,7 @@ bubblePlot.PrecisionPathways <- function(precisionPathways, pathwayColours = NUL
performance <- precisionPathways[["performance"]]
performance <- cbind(Sequence = rownames(performance), performance)
ggplot2::ggplot(performance, aes(x = accuracy, y = cost, colour = Sequence, size = 4)) + ggplot2::geom_point() +
ggplot2::scale_color_manual(values = pathwayColours) + ggplot2::labs(x = "Balanced Accuracy", y = "Total Cost") + ggplot2::guides(size = FALSE)
ggplot2::scale_color_manual(values = pathwayColours) + ggplot2::labs(x = "Balanced Accuracy", y = "Total Cost") + ggplot2::guides(size = "none")
}

#' @export