## iterative Random Forests (iRF)

The R package `iRF` implements iterative Random Forests, a method for
iteratively growing an ensemble of weighted decision trees and detecting
high-order feature interactions by analyzing feature usage on decision paths.
This version uses source code from the R package `randomForest` by Andy Liaw
and Matthew Wiener, the original Fortran code by Leo Breiman and Adele Cutler,
and source code from the R package `FSInteract` by Hyun Jik Kim and Rajen D.
Shah.

`iRF` can be loaded using the `library` command:
```{r, message=FALSE}
library(iRF)
```

`iRF` adds two new features to Breiman's original Random Forest workflow:

## A. Weighted random forest

Unlike Breiman's original random forest, which uses uniform random sampling to
select `mtry` candidate variables at each node split, the `randomForest`
function in `iRF` allows non-uniform sampling using a given vector of
nonnegative weights (e.g., feature importances from a previous model fit). In
particular, given a vector of weights $\mathbf{w} = (w_1, \ldots, w_p)$, the
`mtry` variables are selected so that $P(\text{variable } j \text{ is selected})
= w_j/\sum_{k=1}^p w_k$, for $j = 1, \ldots, p$.
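
To illustrate the idea, here is a minimal sketch of this kind of non-uniform
candidate selection (not the package internals; the weight vector below is
hypothetical):

```{r weighted_sampling_sketch}
# sketch of non-uniform candidate selection with a hypothetical weight vector
p.demo <- 10                           # small number of features for illustration
w <- c(5, 3, 1, rep(0.1, p.demo - 3))  # hypothetical importance weights
mtry.demo <- 3
# variable j is drawn with probability w_j / sum(w)
sample(seq_len(p.demo), size=mtry.demo, prob=w / sum(w))
```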

Based on this weighting scheme, we can iteratively grow weighted random forests,
where Gini importances from the previous random forest fit are used as weights
in the current iteration.

```{r seed_random}
# set seed for random number generation
set.seed(47)
```

### Binary Classification
We simulate some data for a binary classification exercise. Out of $250$
features, only the first three $\{1, 2, 3\}$ are important.
```{r simulate_data}
# simulate data for classification
n <- 500
p <- 250
X <- matrix(rnorm(n * p), nrow=n)
Y <- (X[,1] > 0 & X[,2] > 0 & X[,3] > 0)
Y <- as.factor(as.numeric(Y))
train.id <- 1:(n / 2)
test.id <- setdiff(1:n, train.id)
```

Next, we iteratively grow weighted random forests, using feature importances
from the previous iteration as weights.
```{r fit_irf}
sel.prob <- rep(1/p, p)
# iteratively grow RF, use Gini importance of features as weights
rf <- list()
for (iter in 1:4){
  rf[[iter]] <- randomForest(x=X[train.id,], y=Y[train.id],
                             xtest=X[test.id,], ytest=Y[test.id],
                             mtry.select.prob=sel.prob)
  # update selection probabilities for the next iteration
  sel.prob <- rf[[iter]]$importance
}
```

### ROC curve
We can measure the performance of different iterations of RF on the test set:
```{r print_auc, fig.height = 6, fig.width = 6, message=FALSE}
library(AUC)
plot(0:1, 0:1, type='l', lty=2, xlab='FPR', ylab='TPR', main='ROC Curve')
for (iter in 1:4){
  # performance on test set
  cat(paste('iter = ', iter, ':: '))
  roc.info <- roc(rf[[iter]]$test$votes[,2], Y[test.id])
  lines(roc.info$fpr, roc.info$tpr, type='l', col=iter, lwd=2)
  cat(paste('AUROC: ', round(100*auc(roc.info), 2), '%\n', sep=''))
}
legend('bottomright', legend=paste('iter:', 1:iter), col=1:iter, lwd=2, bty='n')
```

### Variable Importance
The outputs of `iRF::randomForest` are objects of class `randomForest` and can
be used directly with other functions in the R package `randomForest`, e.g., to
visualize variable importance measures:

```{r varimp, fig.width=16, fig.height=5}
par(mfrow=c(1,4))
for (iter in 1:4)
  varImpPlot(rf[[iter]], n.var=10, main=paste('Variable Importance (iter:', iter, ')'))
```

### Regression

For regression, the usage is similar:

```{r regress}
# change to a continuous response
y <- as.numeric(Y) + rnorm(n, sd=0.25)
# iteratively grow RF, use feature importances as weights
rf <- list()
sel.prob <- rep(1/p, p)
for (iter in 1:4){
  cat(paste('iter = ', iter, ':: '))
  rf[[iter]] <- randomForest(x=X[train.id,], y=y[train.id],
                             xtest=X[test.id,], ytest=y[test.id],
                             mtry.select.prob=sel.prob)
  # update selection probabilities for the next iteration
  sel.prob <- rf[[iter]]$importance / sum(rf[[iter]]$importance)
  # performance on test set: proportion of variance explained
  ytest <- y[test.id]
  var.explained <- 1 - mean((rf[[iter]]$test$predicted - ytest) ^ 2) / var(ytest)
  cat(paste('% var explained: ', round(100 * var.explained, 2), '%\n', sep=''))
}
```

### Parallel implementation
The weighted random forests can be grown in parallel across multiple cores, using
the `doParallel` library:

```{r setup_parallel, results="hide", message=FALSE}
# set up cores for parallel implementation
library(doParallel)
n.cores <- detectCores()
registerDoParallel(n.cores)
n.tree.per.core <- 30
rfpar <- foreach(n.tree=rep(n.tree.per.core, n.cores),
                 .combine=combine, .multicombine=TRUE, .packages='iRF') %dopar% {
  randomForest(x=X[train.id,], y=Y[train.id], ntree=n.tree)
}
```
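
The combined object returned by `foreach` behaves like a single `randomForest`
fit. As a minimal sketch (using only standard `randomForest` functionality), it
can be used for prediction on the held-out data:

```{r parallel_predict}
# minimal sketch: predict with the combined forest on the test set
pred <- predict(rfpar, newdata=X[test.id,])
table(predicted=pred, observed=Y[test.id])
```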

## B. Detect high-order feature interactions in a stable fashion
`iRF` detects high-order interactions among features by analyzing feature usage
on the decision paths of large leaf nodes in a random forest. In particular,
given a (weighted) random forest fit, `iRF` (i) passes the training data through
the fitted forest and records the features used on the associated decision
paths; (ii) applies a weighted version of the random intersection trees (RIT)
algorithm proposed by [Shah and Meinshausen
(2014)](http://jmlr.org/papers/v15/shah14a.html) to find high-order feature
combinations prevalent on the decision paths, where weights can be specified by
the user (e.g., leaf node sizes); and (iii) performs the above two steps on many
bootstrap replicates of the training set to assess the stability of the selected
features and their interactions.

### Feature usage on decision paths of large leaf nodes
Consider the classification example from before. We again iteratively grow
weighted random forests, this time saving the forest components for further
processing.

```{r fit_irf2}
sel.prob <- rep(1/p, p)
# iteratively grow RF, use Gini importance of features as weights
rf <- list()
for (iter in 1:4){
  rf[[iter]] <- randomForest(x=X[train.id,], y=Y[train.id],
                             xtest=X[test.id,], ytest=Y[test.id],
                             mtry.select.prob=sel.prob,
                             keep.forest=TRUE)
  # update selection probabilities for the next iteration
  sel.prob <- rf[[iter]]$importance / sum(rf[[iter]]$importance)
}
```

To read feature usage on nodes, use the function `readForest`. This function can
be run in parallel with the argument `n.core`, provided `doParallel` has been set
up for parallel processing.

```{r large_leaf, message=FALSE}
rforest <- readForest(rfobj=rf[[3]], x=X[train.id,])
```

The `tree.info` data frame provides metadata for each leaf node.

```{r tree_info, message=FALSE}
head(rforest$tree.info, n=10)
```
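
As a quick illustration, a minimal sketch using the `prediction` and `size.node`
columns of `tree.info` (the same columns used in the RIT example below) tabulates
the leaf nodes and their sizes by predicted class:

```{r leaf_summary, message=FALSE}
# minimal sketch: leaf-node counts and total training observations per predicted class
table(rforest$tree.info$prediction)
tapply(rforest$tree.info$size.node, rforest$tree.info$prediction, sum)
```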

The `readForest` function optionally computes a sparse matrix `node.feature`
that encodes the splitting features and directions along the decision path of
each leaf node. This encoding can be represented as a binary vector of length
$2p$, where an entry $j \in \{1, \ldots, p\}$ indicates that a decision path
contains the left child of a node that splits on feature $j$, and an entry
$j \in \{p+1, \ldots, 2p\}$ indicates that a decision path contains the right
child of a node that splits on feature $j-p$.

```{r node_feature, message=FALSE}
rforest$node.feature[1:10, c(1:5, p + 1:5)]
```
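
For illustration, here is a minimal sketch (not part of the package API) that
decodes one row of `node.feature` back into feature indices and split directions,
assuming the $2p$-column layout described above:

```{r decode_path, message=FALSE}
# minimal sketch: decode one decision path, assuming columns 1..p mark left
# children and columns p+1..2p mark right children of splits on feature j
path <- as.numeric(rforest$node.feature[1, ])
active <- which(path != 0)
data.frame(feature=ifelse(active <= p, active, active - p),
           direction=ifelse(active <= p, 'left', 'right'))
```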

### Finding feature interactions using random intersection trees (RIT)

To find prevalent sets of features and their high-order combinations used to
define these nodes, use the random intersection trees (RIT) function. The
following command runs RIT on all class 1 leaf nodes, sampling each leaf node
with probability proportional to the number of observations in the node:

```{r rit}
# leaf nodes predicted to be class 1
class1.nodes <- (rforest$tree.info$prediction - 1) == 1
# weight each leaf node by the number of training observations it contains
wt <- rforest$tree.info$size.node[class1.nodes]
RIT(rforest$node.feature[class1.nodes,], weights=wt,
    depth=5, branch=2, n_trees=100)
```

### Selecting stable interactions

The function `iRF` combines all of the above steps and uses bootstrap samples of
the training data to assess the stability of the selected interactions. Setting
`select.iter = TRUE` chooses the optimal iteration number by maximizing prediction
accuracy on out-of-bag samples and searches for interactions in the corresponding
random forest.

```{r irf, message=FALSE}
fit <- iRF(x=X[train.id,],
           y=Y[train.id],
           xtest=X[test.id,],
           ytest=Y[test.id],
           n.iter=5,
           n.core=n.cores,
           select.iter=TRUE,
           n.bootstrap=10
           )
```

The output of `iRF` is a list containing up to four entries. If `select.iter =
FALSE`, each entry will be a list with one element per iteration.

(1) `rf.list`: a `randomForest` object.

(2) `interaction`: the interaction stability scores. The $+$ and $-$ signs
indicate whether the interaction is characterized by high or low levels of the
indicated feature. For instance, `1+_2+_3+` describes an interaction defined by
high levels of features $1$, $2$, and $3$.

```{r stability}
head(fit$interaction)
```

(3) `importance`: a data table of importance metrics associated with each
interaction. The `prev0` and `prev1` entries indicate interaction prevalence for
class-$0$ and class-$1$ leaf nodes, respectively. `sta.prev` indicates the
proportion of times (across `n.bootstrap` samples) an interaction is more
prevalent than expected under independent feature selection. `diff` indicates the
difference in prevalence between class-$0$ and class-$1$ leaf nodes, and
`sta.diff` indicates the proportion of times (across `n.bootstrap` samples) an
interaction is more prevalent among class-$1$ leaf nodes. `prec` describes the
precision of leaf nodes for which an interaction is active, and `sta.prec`
indicates the proportion of times (across `n.bootstrap` samples) that an
interaction is more precise than any of its subsets. For a more detailed
description of these importance metrics, see our paper:
https://arxiv.org/abs/1810.07287.

(4) `iter.select`: the selected iteration.

```{r prevalence, message=FALSE}
library(dplyr)
head(fit$importance)
```
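
Since `dplyr` is loaded, the interactions can, for example, be ranked by their
stability scores. The sketch below assumes the `sta.diff` and `diff` columns
described above; adjust the column names if your version of `iRF` differs:

```{r rank_interactions, message=FALSE}
# minimal sketch: order interactions by the stability of their class enrichment
fit$importance %>%
  arrange(desc(sta.diff), desc(diff)) %>%
  head()
```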

The selected interactions can be visualized using simple R functions like `dotchart`:

```{r dotchart, fig.height=6, fig.width=6}
toplot <- fit$importance$diff
names(toplot) <- fit$importance$int
dotchart(rev(toplot[1:min(20, length(toplot))]),
         xlab='Prevalence enrichment', xlim=c(0, 1),
         main='Prevalent Features/Interactions \n on Decision paths')
```