-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0a3a636
commit 3fec4d9
Showing
9 changed files
with
321 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,3 +27,4 @@ | |
|
||
\.idea/workspace\.xml | ||
.idea | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package randomforest | ||
|
||
import ( | ||
"fmt" | ||
"math/big" | ||
"math/rand" | ||
) | ||
|
||
/*
Boruta is a smart algorithm for selecting important features with a Random Forest. It was originally developed in the R language.
X [][]float64 - data for the random forest. At least three features (columns) are required.
Class []int - classes for the random forest (0,1,..)
trees int - number of trees used by the Boruta algorithm. A very large number of trees is not needed. (50-200)
cycles int - number of cycles (20-50) of the Boruta algorithm.
threshold float64 - threshold for selecting features (0.05)
recursive bool - the algorithm repeats the process until all remaining features are important
verbose bool - print the progress of the Boruta algorithm.
*/
func BorutaDefault(x [][]float64, class []int) []int { | ||
return Boruta(x, class, 100, 20, 0.05, true, true) | ||
} | ||
|
||
func Boruta(x [][]float64, class []int, trees int, cycles int, threshold float64, recursive bool, verbose bool) []int { | ||
//keep mapping of features | ||
featMap := make(map[int]int, 0) | ||
for i := 0; i < len(x[0]); i++ { | ||
featMap[i] = i | ||
} | ||
|
||
c2 := 0 | ||
for { | ||
c2++ | ||
features := len(featMap) | ||
//copy x to working x | ||
wx := make([][]float64, len(x)) | ||
for i := 0; i < len(x); i++ { | ||
wx[i] = make([]float64, features) | ||
for j := 0; j < features; j++ { | ||
wx[i][j] = x[i][featMap[j]] | ||
} | ||
} | ||
|
||
//add shadow columns to wx | ||
for i := 0; i < len(wx); i++ { | ||
for j := 0; j < features; j++ { | ||
wx[i] = append(wx[i], wx[i][j]) | ||
} | ||
} | ||
|
||
tips := make(map[int]int, 0) | ||
for cycle := 0; cycle < cycles; cycle++ { | ||
if verbose { | ||
fmt.Println("Cycle:", cycle+1, "/", c2) | ||
} | ||
//shufle | ||
for i := 0; i < features; i++ { | ||
column := features + i | ||
for j := 0; j < len(wx); j++ { | ||
k := rand.Intn(len(wx)) | ||
wx[j][column], wx[k][column] = wx[k][column], wx[j][column] | ||
} | ||
} | ||
//forest | ||
forest := Forest{Data: ForestData{X: wx, Class: class}} | ||
forest.Train(trees) | ||
//save tips | ||
bestShadow := 0.0 | ||
for i := features; i < 2*features; i++ { | ||
if forest.FeatureImportance[i] > bestShadow { | ||
bestShadow = forest.FeatureImportance[i] | ||
} | ||
} | ||
c := 0 | ||
for i := 0; i < features; i++ { | ||
if forest.FeatureImportance[i] > bestShadow { | ||
tips[i]++ | ||
c++ | ||
} | ||
} | ||
if verbose { | ||
fmt.Println("selected tips:", c, "/", features) | ||
} | ||
} | ||
//select remaining features | ||
tipThreshold := bionimalThreshold(cycles, threshold) | ||
newFeatMap := make(map[int]int, 0) | ||
c := 0 | ||
for i := 0; i < features; i++ { | ||
if tips[i] >= tipThreshold { | ||
newFeatMap[c] = featMap[i] | ||
c++ | ||
} | ||
} | ||
if verbose { | ||
fmt.Println("Threshold count:", tipThreshold) | ||
fmt.Println("Threshold features", len(newFeatMap), "/", len(featMap)) | ||
} | ||
if len(newFeatMap) == len(featMap) || len(newFeatMap) < 3 || !recursive { | ||
result := make([]int, 0) | ||
for _, v := range newFeatMap { | ||
result = append(result, v) | ||
} | ||
return result | ||
} | ||
featMap = newFeatMap | ||
if verbose { | ||
result := make([]int, 0) | ||
for _, v := range newFeatMap { | ||
result = append(result, v) | ||
} | ||
fmt.Println("Selected feautures", result) | ||
} | ||
} | ||
|
||
} | ||
|
||
// bionimalThreshold returns the smallest j in [0, n) for which the cumulative
// share of binomial coefficients, (C(n,0)+...+C(n,j)) / 2^n, reaches
// threshold. If no such j exists it returns n. For n <= 0 it returns 0.
//
// The sums are computed with arbitrary-precision integers: C(n,k) and 2^n do
// not fit into an int64 once n reaches the mid-sixties, so plain int
// arithmetic would silently overflow for larger cycle counts.
func bionimalThreshold(n int, threshold float64) int {
	if n <= 0 {
		return 0
	}
	// The sum of all binomial coefficients of n is 2^n.
	total := new(big.Int).Lsh(big.NewInt(1), uint(n))
	cumulative := new(big.Int)
	term := new(big.Int)
	ratio := new(big.Rat)
	for j := 0; j < n; j++ {
		cumulative.Add(cumulative, term.Binomial(int64(n), int64(j)))
		// Float64 yields the nearest float64 to cumulative/total.
		share, _ := ratio.SetFrac(cumulative, total).Float64()
		if share >= threshold {
			return j
		}
	}
	return n
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"math/rand" | ||
|
||
randomforest "github.com/malaschitz/randomForest" | ||
"github.com/malaschitz/randomForest/examples/img" | ||
"github.com/petar/GoMNIST" | ||
) | ||
|
||
/*
Using Boruta on MNIST.
With a threshold of 5%, 495 of 784 features were selected.
The random forest result was the same as in minst.go: 96.2%.
With a threshold of 10%, 493 of 784 features were selected, with the same result.
*/
func main() { | ||
size := 60000 | ||
xsize := 28 * 28 | ||
labels, err := GoMNIST.ReadLabelFile("examples/train-labels-idx1-ubyte.gz") | ||
if err != nil { | ||
panic(err) | ||
} | ||
_, _, imgs, err := GoMNIST.ReadImageFile("examples/train-images-idx3-ubyte.gz") | ||
if err != nil { | ||
panic(err) | ||
} | ||
if len(labels) != size || len(imgs) != size { | ||
panic("Wrong size") | ||
} | ||
x := make([][]float64, size) | ||
l := make([]int, size) | ||
for i := 0; i < size; i++ { | ||
x[i] = make([]float64, xsize) | ||
for j := 0; j < xsize; j++ { | ||
x[i][j] = float64(imgs[i][j]) | ||
l[i] = int(labels[i]) | ||
} | ||
} | ||
//sample for testing | ||
//x, l = sample(x, l, 200) | ||
// | ||
borutaFeatuters := randomforest.BorutaDefault(x, l) | ||
//borutaFeatuters := randomforest.Boruta(x, l, 100, 20, 0.05, true, true) | ||
image := make([]byte, xsize) | ||
for _, v := range borutaFeatuters { | ||
image[v] = 255 | ||
} | ||
//save redsults as image | ||
img.WriteImage(image, "boruta") | ||
// | ||
// try forest with selected features | ||
// | ||
xSmall := make([][]float64, len(x)) | ||
for i := 0; i < len(x); i++ { | ||
xSmall[i] = make([]float64, len(borutaFeatuters)) | ||
for j := 0; j < len(borutaFeatuters); j++ { | ||
xSmall[i][j] = x[i][borutaFeatuters[j]] | ||
} | ||
} | ||
forest := randomforest.Forest{} | ||
forest.Data = randomforest.ForestData{X: xSmall, Class: l} | ||
forest.Train(100) | ||
|
||
//read test data | ||
tsize := 10000 | ||
tlabels, err := GoMNIST.ReadLabelFile("examples/t10k-labels-idx1-ubyte.gz") | ||
if err != nil { | ||
panic(err) | ||
} | ||
_, _, timgs, err := GoMNIST.ReadImageFile("examples/t10k-images-idx3-ubyte.gz") | ||
if err != nil { | ||
panic(err) | ||
} | ||
if len(tlabels) != tsize || len(timgs) != tsize { | ||
panic("Wrong size") | ||
} | ||
//calculate difference | ||
x = make([][]float64, tsize) | ||
for i := 0; i < tsize; i++ { | ||
x[i] = make([]float64, len(borutaFeatuters)) | ||
for j := 0; j < len(borutaFeatuters); j++ { | ||
x[i][j] = float64(timgs[i][borutaFeatuters[j]]) | ||
} | ||
} | ||
p := 0 | ||
for i := 0; i < tsize; i++ { | ||
vote := forest.Vote(x[i]) | ||
bestI := -1 | ||
bestV := 0.0 | ||
for j, v := range vote { | ||
if v > bestV { | ||
bestV = v | ||
bestI = j | ||
} | ||
} | ||
if int(tlabels[i]) == bestI { | ||
p++ | ||
} | ||
} | ||
fmt.Printf("Trees: %d Results: %5.1f%%\n", 100, 100.0*float64(p)/float64(tsize)) | ||
// Selected 495 features from 784 | ||
// Trees: 100 Results: 96.2% | ||
} | ||
|
||
// sample draws count rows from x, with replacement, together with the matching
// entries of y. The returned rows share their backing arrays with x.
func sample(x [][]float64, y []int, count int) (xx [][]float64, yy []int) {
	xx, yy = make([][]float64, count), make([]int, count)
	for i := range xx {
		pick := rand.Intn(len(x))
		xx[i], yy[i] = x[pick], y[pick]
	}
	return xx, yy
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package img | ||
|
||
import ( | ||
"image" | ||
"image/png" | ||
"log" | ||
"os" | ||
) | ||
|
||
// WriteImage encodes data as a 28x28 grayscale PNG and writes it to
// "./<name>.png". data is read row by row, one byte per pixel; bytes beyond
// the first 28*28 are ignored. Any failure (too little data, create, encode or
// close error) terminates the program via log.Fatal.
func WriteImage(data []byte, name string) {
	const side = 28
	// Fewer bytes than pixels would make the encoder read past the end of data.
	if len(data) < side*side {
		log.Fatalf("image data has %d bytes, need at least %d", len(data), side*side)
	}
	img := image.NewGray(image.Rect(0, 0, side, side))
	img.Pix = data
	out, err := os.Create("./" + name + ".png")
	if err != nil {
		log.Fatal(err)
	}
	if err := png.Encode(out, img); err != nil {
		// The encode error is the one worth reporting; a close error here
		// would only hide it.
		_ = out.Close()
		log.Fatal(err)
	}
	// Close explicitly: log.Fatal skips deferred calls, and a failed close can
	// mean the file was not fully written.
	if err := out.Close(); err != nil {
		log.Fatal(err)
	}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.