-
Notifications
You must be signed in to change notification settings - Fork 0
/
bh_randomforest.R
66 lines (54 loc) · 2.08 KB
/
bh_randomforest.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#' Train a random forest classifier, or classify new data with a trained one
#'
#' The meaning of `Data` and `label` depends on `option`:
#'
#' * `option = "train"`: `Data` is the training feature matrix / data frame
#'   (its columns are renamed `feature1`, `feature2`, ...) and `label` holds
#'   the training labels (coerced to a factor).
#' * `option = "test"`: `Data` is the list returned by a previous
#'   `"train"` call and `label` is the feature matrix / data frame to
#'   classify (same number of columns as the training data is assumed).
#'
#' @param Data Training features (`"train"`) or trained model list (`"test"`).
#' @param label Training labels (`"train"`) or test features (`"test"`).
#' @param option Either `"train"` or `"test"`.
#' @return For `"train"`: `list(model = <randomForest>)`. For `"test"`:
#'   `list(classifiedlabel, Probvalues)`, where `Probvalues` is the class
#'   probability matrix from `predict(..., type = "prob")` and
#'   `classifiedlabel` is, per row, the column index of the largest
#'   probability (an index into the class columns, not the class name).
bh_randomforest <- function(Data, label, option) {
  # Validate first so that a bad call fails fast, before any package is
  # attached or any cluster is started.
  if (!is.character(option) || length(option) != 1L ||
      !option %in% c("train", "test")) {
    stop("`option` must be \"train\" or \"test\"", call. = FALSE)
  }

  library(randomForest)

  if (identical(option, "train")) {
    if (is.null(dim(Data)) || ncol(Data) < 1L) {
      stop("`Data` must be a matrix or data frame with at least one column",
           call. = FALSE)
    }

    # Pre-process the data and build the formula for the RF classifier.
    # `label` is not a column of `Data`; the formula finds it through its
    # environment (this function's frame), so the name must stay `label`.
    label <- as.factor(label)
    feature_names <- paste0("feature", seq_len(ncol(Data)))
    colnames(Data) <- feature_names
    f <- as.formula(paste("label~", paste0(feature_names, collapse = "+")))
    Data <- data.frame(Data)

    # Optional tuning of the number of variables tried at each split:
    # mtry <- tuneRF(Data, label, ntreeTry = 200,
    #                stepFactor = 1.5, improve = 0.01, trace = TRUE, plot = TRUE)

    # Seeds the master process only; workers are seeded separately below.
    set.seed(0)

    library(doParallel)
    library(parallel)

    # Leave one core free, but always keep at least one worker:
    # detectCores() may return 1 or NA.
    n_cores <- detectCores()
    n_workers <- if (is.na(n_cores)) 1L else max(1L, n_cores - 1L)

    cl <- makeCluster(n_workers)
    # Always release the workers, even if training fails, and do not leave
    # a stopped cluster registered as the foreach backend.
    on.exit({
      stopCluster(cl)
      foreach::registerDoSEQ()
    }, add = TRUE)
    registerDoParallel(cl)

    # Give each worker its own deterministic RNG stream so that repeated
    # runs on the same number of workers can reproduce the same forest.
    clusterSetRNGStream(cl, iseed = 0)

    # Grow about 500 trees in total, split evenly across the workers
    # (rounded up), then merge the sub-forests into one.
    trees_per_worker <- ceiling(500 / n_workers)
    Rf <- foreach(
      ntree = rep(trees_per_worker, n_workers),
      .combine = randomForest::combine,
      .packages = "randomForest"
    ) %dopar% {
      randomForest(f, data = Data, ntree = ntree,
                   keep.forest = TRUE, importance = TRUE)
    }

    print(Rf)
    list(model = Rf)
  } else {
    # In "test" mode the arguments are re-purposed:
    # `Data` is the trained model list and `label` is the test data.
    Rf <- Data
    TestData <- label
    if (is.null(dim(TestData)) || ncol(TestData) < 1L) {
      stop("test data must be a matrix or data frame with at least one column",
           call. = FALSE)
    }

    # Use the same column naming scheme as in training.
    colnames(TestData) <- paste0("feature", seq_len(ncol(TestData)))
    TestData <- data.frame(TestData)

    probvalues <- predict(Rf$model, newdata = TestData, type = "prob")
    # Break ties deterministically (max.col() defaults to random ties).
    classifiedlabel <- max.col(probvalues, ties.method = "first")

    list(classifiedlabel = classifiedlabel, Probvalues = probvalues)
  }
}