forked from serrano-pozo-lab/glia-ihc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
performance-evaluation.Rmd
188 lines (148 loc) · 9.08 KB
/
performance-evaluation.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
---
title: "Performance Evaluation"
description: |
This R script evaluates GBM/CNN performance by bootstrapping across 500 iterations of the independent test set.
author:
- first_name: "Colin"
last_name: "Magdamo"
affiliation: Massachusetts General Hospital
affiliation_url: https://www.massgeneral.org/neurology/research/mind-data-science-lab
orcid_id: 0000-0001-8965-4630
output:
distill::distill_article:
toc: true
---
```{r setup, include = FALSE}
# Set knitr's default chunk option `eval` to FALSE for this document.
knitr::opts_chunk$set(eval = FALSE)
```
# Dependencies
Load requisite packages and define directories.
```{r load-packages, message=FALSE, warning=FALSE}
library(data.table)
library(ROCR)
library(pROC)
library(purrr)
library(glmnet)
library(dbplyr)
library(PRROC)
library(rms)
```
# Read Data
Read test set classification probabilities from the gradient boosting machine (GBM) and convolutional neural network (CNN) model outputs. To work with the bootstrapping function, each dataset must have a column called `labels` (the true labels) and a column called `predictions` (the predicted probabilities).
```{r read-data}
# Load the saved GBM results object from disk.
gbm <- readRDS('data/AD vs. CTRL GBM Results.rds')

# Take columns 2:3 of one cell type's GBM `Scores` table, rename them to the
# `labels` / `predictions` layout, and shift the integer-coded labels down by 1.
# NOTE(review): presumably `Condition` is a two-level factor, so the result is
# 0/1 -- confirm against the saved object.
extract_gbm_scores <- function(cell_type) {
  scores <- gbm[[cell_type]]$Scores[, 2:3]
  setnames(scores, old = c('Condition', 'Alzheimer'), new = c('labels', 'predictions'))
  scores[, labels := as.integer(labels) - 1]
  scores
}

# Read one CNN test-set results file and keep only the true label and the AD
# probability. A `TrueLabel` of 1 is recoded to 0; any other value becomes 1.
read_cnn_results <- function(csv_path) {
  raw_results <- fread(csv_path)
  raw_results[, .(labels = ifelse(TrueLabel == 1, 0, 1), predictions = ProbabilityAD)]
}

# Named list of experiments; the names are later used as model names.
all_experiments <- list(
  gbm_astrocyte = extract_gbm_scores('Astrocyte'),
  gbm_microglia = extract_gbm_scores('Microglia'),
  cnn_astrocyte = read_cnn_results('data/Astrocyte CNN TestSetResults.csv'),
  cnn_microglia = read_cnn_results('data/Microglia CNN TestSetResults.csv')
)

# Keep the individual tables available under their original names.
gbm_astrocyte <- all_experiments[['gbm_astrocyte']]
gbm_microglia <- all_experiments[['gbm_microglia']]
cnn_astrocyte <- all_experiments[['cnn_astrocyte']]
cnn_microglia <- all_experiments[['cnn_microglia']]
```
# Compute Model Performance
Function to compute model performance and calculate 95% confidence intervals by bootstrapping the independent test set across 500 iterations.
```{r compute-performance}
# Compute point estimates and bootstrap 95% confidence intervals for a set of
# ROCR performance measures on one model's test-set predictions.
#
# Arguments:
#   model_name          Label written to the `model_name` column of the result.
#   model_preds         Either a ROCR `prediction` object, or a table with the
#                       columns `predictions` (scores) and `labels` (true class).
#   metrics             ROCR measure names. 'auc' and 'aucpr' are treated as
#                       threshold independent; every other measure is read off
#                       at a single cutoff.
#   threshold_criteria  'max_accuracy' uses the cutoff with the highest accuracy;
#                       'from_marginal' looks the cutoff up in
#                       `marginal_thresholds`.
#   marginal_thresholds Table with `model_name` and `Threshold` columns; only
#                       read when `threshold_criteria` is 'from_marginal'.
#   n_boot              Number of bootstrap replicates (default 500).
#
# Returns a data.table with one row per metric and the columns Measure, Value,
# model_name, threshold_criteria, threshold, CI_l and CI_u. CI_l / CI_u are the
# 2.5% and 97.5% quantiles of the metric across the bootstrap replicates.
#
# Note: this function calls set.seed(), so the caller's RNG state is changed.
compute_model_performance <- function(model_name,model_preds,metrics = c('acc', 'ppv', 'npv', 'sens', 'spec', 'f','auc','aucpr'),threshold_criteria = 'max_accuracy',marginal_thresholds = NULL,n_boot = 500L){
  stopifnot(is.numeric(n_boot), length(n_boot) == 1L, n_boot >= 1)

  # split the requested measures by whether they need a cutoff
  threshold_free <- c('auc', 'aucpr')
  free_metrics <- metrics[metrics %in% threshold_free]
  cutoff_metrics <- metrics[!metrics %in% threshold_free]
  needs_threshold <- length(cutoff_metrics) > 0L

  # ROCR expects a prediction object; build one unless we were handed one
  if (inherits(model_preds, 'prediction')) {
    ROCR_prediction <- model_preds
  } else {
    ROCR_prediction <- prediction(model_preds$predictions, model_preds$labels)
  }

  # accuracy at every cutoff of a prediction object
  accuracy_by_cutoff <- function(pred) {
    acc_perf <- performance(pred, 'acc')
    data.table(cutoffs = acc_perf@x.values[[1]], accuracy = acc_perf@y.values[[1]])
  }

  # one result row; `replicate` is only added for bootstrap replicates
  metric_row <- function(measure, value, cutoff, replicate_id) {
    cols <- list(Measure = measure,
                 Value = value,
                 model_name = model_name,
                 threshold_criteria = threshold_criteria,
                 threshold = cutoff)
    if (!is.null(replicate_id)) {
      cols$replicate <- replicate_id
    }
    do.call(data.table, cols)
  }

  # all requested measures for one prediction object: threshold dependent rows
  # first, then threshold independent rows
  compute_measures <- function(pred, cutoff, replicate_id = NULL) {
    dependent <- lapply(cutoff_metrics, function(measure) {
      perf <- performance(pred, measure)
      # index the y value whose x value (cutoff) equals the selected threshold
      at_cutoff <- which(perf@x.values[[1]] == cutoff)
      metric_row(measure, perf@y.values[[1]][at_cutoff], cutoff, replicate_id)
    })
    independent <- lapply(free_metrics, function(measure) {
      metric_row(measure,
                 performance(pred, measure)@y.values[[1]],
                 'Measure Independent of Threshold',
                 replicate_id)
    })
    rbindlist(list(rbindlist(dependent), rbindlist(independent)), fill = TRUE)
  }

  # pick the cutoff, but only when some metric actually needs one
  threshold <- NULL
  if (needs_threshold) {
    acc <- accuracy_by_cutoff(ROCR_prediction)
    if (threshold_criteria == 'max_accuracy') {
      # first cutoff that yields the maximum accuracy
      threshold <- acc$cutoffs[which.max(acc$accuracy)]
    } else if (threshold_criteria == 'from_marginal') {
      if (is.null(marginal_thresholds)) {
        stop("`marginal_thresholds` must be supplied when threshold_criteria is 'from_marginal'",
             call. = FALSE)
      }
      # NOTE(review): `stratifed_result_indices` is neither an argument nor
      # defined in this function; it has to exist in the calling environment.
      colnames_strip <- names(stratifed_result_indices)
      from_marginal_model <- gsub(paste(colnames_strip,collapse = "|"),"",model_name)
      from_marginal_model <- gsub('_\\b',"",from_marginal_model)
      threshold <- marginal_thresholds[model_name == from_marginal_model,Threshold]
      # snap to the largest cutoff present in this data that is <= the marginal one
      threshold <- acc[cutoffs <= threshold,][order(-cutoffs)][1,cutoffs]
    } else {
      stop("Unknown threshold_criteria: ", threshold_criteria, call. = FALSE)
    }
  }

  # bootstrap: resample the test set n_boot times with replacement and take the
  # .025 / .975 quantiles of every measure
  set.seed(5281995)
  seeds <- sample.int(10000000, n_boot)
  scores <- ROCR_prediction@predictions[[1]]
  truth <- ROCR_prediction@labels[[1]]
  n_obs <- length(scores)

  # one bootstrap replicate: reseed, resample, recompute every measure
  boot_CIs <- function(replicate_id) {
    set.seed(seeds[[replicate_id]])
    resampled <- sample.int(n_obs, n_obs, replace = TRUE)
    boot_prediction <- prediction(scores[resampled], truth[resampled])
    # 'max_accuracy' re-selects the cutoff on the resampled data; otherwise the
    # cutoff chosen on the full data is reused.
    # NOTE(review): a reused cutoff may not occur among the resampled scores,
    # in which case the threshold dependent measures come back empty.
    boot_threshold <- threshold
    if (needs_threshold && threshold_criteria == 'max_accuracy') {
      boot_acc <- accuracy_by_cutoff(boot_prediction)
      boot_threshold <- boot_acc$cutoffs[which.max(boot_acc$accuracy)]
    }
    compute_measures(boot_prediction, boot_threshold, replicate_id)
  }

  bootstrap_runs <- rbindlist(lapply(seq_len(n_boot), boot_CIs), fill = TRUE)
  # cast so each row is a replicate and each column is a measure
  bootstrap_runs_cast <- dcast(bootstrap_runs,
                               model_name + threshold_criteria + replicate ~ Measure,
                               value.var = 'Value')
  measure_cols <- setdiff(names(bootstrap_runs_cast),
                          c('model_name', 'threshold_criteria', 'replicate'))
  # quantiles across replicates, one row per measure
  bootstrap_aggregation <- rbindlist(lapply(measure_cols, function(measure) {
    draws <- bootstrap_runs_cast[[measure]]
    data.table(Measure = measure,
               CI_l = quantile(draws, .025, na.rm = TRUE),
               CI_u = quantile(draws, .975, na.rm = TRUE))
  }))

  # point estimates on the full test set, with the bootstrap CIs joined on
  final_metrics <- compute_measures(ROCR_prediction, threshold)
  final_metrics[bootstrap_aggregation, on = .(Measure), `:=`(CI_l = i.CI_l, CI_u = i.CI_u)]
  final_metrics[]
}
```
Apply the previously defined function to each experiment.
```{r map-compute}
# Evaluate every experiment, using its list name as the model name, and
# row-bind the per-model results into one table.
all_experiments_performance_metrics <- imap_dfr(
  all_experiments,
  function(experiment_preds, experiment_name) {
    compute_model_performance(model_name = experiment_name, model_preds = experiment_preds)
  }
)
```