Skip to content

Commit

Permalink
Boruta algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
malaschitz committed Jul 22, 2021
1 parent 0a3a636 commit 3fec4d9
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@

\.idea/workspace\.xml
.idea
.DS_Store
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,25 @@ forest.AddDataRow(data, res, 1000, 10, 2000)
// AddDataRow : add new row, trim oldest row if there is more than 1000 rows, calculate a new 10 trees, but remove oldest trees if there is more than 2000 trees.
```

# Boruta Algorithm for feature selection

Boruta algorithm was developed as package for language R.
It is one of most effective feature selection algorithm.
There is [paper](https://www.jstatsoft.org/article/view/v036i11) in Journal of Statistical Software.

Boruta algorithm use random forest for selection important features.

```go
xData := ... //data
yData := ... //labels
selectedFeatures := randomforest.BorutaDefault(xData, yData)
// or randomforest.BorutaDefault(xData, yData, 100, 20, 0.05, true, true)
```

In _/examples_ is example with [MNIST database](https://en.wikipedia.org/wiki/MNIST_database).
On picture are selected features (495 from 784) from images.

![boruta 05](boruta05.png)



133 changes: 133 additions & 0 deletions boruta.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package randomforest

import (
"fmt"
"math/big"
"math/rand"
)

/*
Boruta is smart algorithm for select important features with Random Forest. It was developed in language R.
X [][]float64 - data for random forest. At least three features (columns) are required.
Class []int - classes for random forest (0,1,..)
trees int - number of trees used by Boruta algorithm. Is not need too big number of trees. (50-200)
cycles int - number of cycles (20-50) of Boruta algorithm.
threshold float64 - threshold for select feauters (0.05)
recursive bool - algorithm repeat process until all features are important
verbose bool - will print process of boruta algorithm.
*/
func BorutaDefault(x [][]float64, class []int) []int {
return Boruta(x, class, 100, 20, 0.05, true, true)
}

func Boruta(x [][]float64, class []int, trees int, cycles int, threshold float64, recursive bool, verbose bool) []int {
//keep mapping of features
featMap := make(map[int]int, 0)
for i := 0; i < len(x[0]); i++ {
featMap[i] = i
}

c2 := 0
for {
c2++
features := len(featMap)
//copy x to working x
wx := make([][]float64, len(x))
for i := 0; i < len(x); i++ {
wx[i] = make([]float64, features)
for j := 0; j < features; j++ {
wx[i][j] = x[i][featMap[j]]
}
}

//add shadow columns to wx
for i := 0; i < len(wx); i++ {
for j := 0; j < features; j++ {
wx[i] = append(wx[i], wx[i][j])
}
}

tips := make(map[int]int, 0)
for cycle := 0; cycle < cycles; cycle++ {
if verbose {
fmt.Println("Cycle:", cycle+1, "/", c2)
}
//shufle
for i := 0; i < features; i++ {
column := features + i
for j := 0; j < len(wx); j++ {
k := rand.Intn(len(wx))
wx[j][column], wx[k][column] = wx[k][column], wx[j][column]
}
}
//forest
forest := Forest{Data: ForestData{X: wx, Class: class}}
forest.Train(trees)
//save tips
bestShadow := 0.0
for i := features; i < 2*features; i++ {
if forest.FeatureImportance[i] > bestShadow {
bestShadow = forest.FeatureImportance[i]
}
}
c := 0
for i := 0; i < features; i++ {
if forest.FeatureImportance[i] > bestShadow {
tips[i]++
c++
}
}
if verbose {
fmt.Println("selected tips:", c, "/", features)
}
}
//select remaining features
tipThreshold := bionimalThreshold(cycles, threshold)
newFeatMap := make(map[int]int, 0)
c := 0
for i := 0; i < features; i++ {
if tips[i] >= tipThreshold {
newFeatMap[c] = featMap[i]
c++
}
}
if verbose {
fmt.Println("Threshold count:", tipThreshold)
fmt.Println("Threshold features", len(newFeatMap), "/", len(featMap))
}
if len(newFeatMap) == len(featMap) || len(newFeatMap) < 3 || !recursive {
result := make([]int, 0)
for _, v := range newFeatMap {
result = append(result, v)
}
return result
}
featMap = newFeatMap
if verbose {
result := make([]int, 0)
for _, v := range newFeatMap {
result = append(result, v)
}
fmt.Println("Selected feautures", result)
}
}

}

func bionimalThreshold(n int, threshold float64) int {
sum := 0
s := make([]int, n+1)
bi := big.Int{}
for i := 0; i <= n; i++ {
bn := int(bi.Binomial(int64(n), int64(i)).Int64())
sum = sum + bn
s[i] = sum
}
for j := 0; j < n; j++ {
if float64(s[j])/float64(sum) >= threshold {
return j
}
}
return n
}
Binary file added boruta.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added boruta05.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
120 changes: 120 additions & 0 deletions examples/boruta.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package main

import (
"fmt"
"math/rand"

randomforest "github.com/malaschitz/randomForest"
"github.com/malaschitz/randomForest/examples/img"
"github.com/petar/GoMNIST"
)

/*
Using boruta for mnist
With threshold 5% was selected 495 features from 784.
Result of random forest was the same - 96.2% as in minst.go
With threshold 10% was selected 493 features from 784 with the same result.
*/
func main() {
size := 60000
xsize := 28 * 28
labels, err := GoMNIST.ReadLabelFile("examples/train-labels-idx1-ubyte.gz")
if err != nil {
panic(err)
}
_, _, imgs, err := GoMNIST.ReadImageFile("examples/train-images-idx3-ubyte.gz")
if err != nil {
panic(err)
}
if len(labels) != size || len(imgs) != size {
panic("Wrong size")
}
x := make([][]float64, size)
l := make([]int, size)
for i := 0; i < size; i++ {
x[i] = make([]float64, xsize)
for j := 0; j < xsize; j++ {
x[i][j] = float64(imgs[i][j])
l[i] = int(labels[i])
}
}
//sample for testing
//x, l = sample(x, l, 200)
//
borutaFeatuters := randomforest.BorutaDefault(x, l)
//borutaFeatuters := randomforest.Boruta(x, l, 100, 20, 0.05, true, true)
image := make([]byte, xsize)
for _, v := range borutaFeatuters {
image[v] = 255
}
//save redsults as image
img.WriteImage(image, "boruta")
//
// try forest with selected features
//
xSmall := make([][]float64, len(x))
for i := 0; i < len(x); i++ {
xSmall[i] = make([]float64, len(borutaFeatuters))
for j := 0; j < len(borutaFeatuters); j++ {
xSmall[i][j] = x[i][borutaFeatuters[j]]
}
}
forest := randomforest.Forest{}
forest.Data = randomforest.ForestData{X: xSmall, Class: l}
forest.Train(100)

//read test data
tsize := 10000
tlabels, err := GoMNIST.ReadLabelFile("examples/t10k-labels-idx1-ubyte.gz")
if err != nil {
panic(err)
}
_, _, timgs, err := GoMNIST.ReadImageFile("examples/t10k-images-idx3-ubyte.gz")
if err != nil {
panic(err)
}
if len(tlabels) != tsize || len(timgs) != tsize {
panic("Wrong size")
}
//calculate difference
x = make([][]float64, tsize)
for i := 0; i < tsize; i++ {
x[i] = make([]float64, len(borutaFeatuters))
for j := 0; j < len(borutaFeatuters); j++ {
x[i][j] = float64(timgs[i][borutaFeatuters[j]])
}
}
p := 0
for i := 0; i < tsize; i++ {
vote := forest.Vote(x[i])
bestI := -1
bestV := 0.0
for j, v := range vote {
if v > bestV {
bestV = v
bestI = j
}
}
if int(tlabels[i]) == bestI {
p++
}
}
fmt.Printf("Trees: %d Results: %5.1f%%\n", 100, 100.0*float64(p)/float64(tsize))
// Selected 495 features from 784
// Trees: 100 Results: 96.2%
}

//create samples from data
func sample(x [][]float64, y []int, count int) (xx [][]float64, yy []int) {
xx = make([][]float64, count)
yy = make([]int, count)
for i := 0; i < count; i++ {
k := rand.Intn(len(x))
xx[i] = x[k]
yy[i] = y[k]
}
return
}
21 changes: 21 additions & 0 deletions examples/img/write.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package img

import (
"image"
"image/png"
"log"
"os"
)

func WriteImage(data []byte, name string) {
img := image.NewGray(image.Rect(0, 0, 28, 28))
img.Pix = data
out, err := os.Create("./" + name + ".png")
if err != nil {
log.Fatal(err)
}
err = png.Encode(out, img)
if err != nil {
log.Fatal(err)
}
}
23 changes: 16 additions & 7 deletions examples/mnist_test.go → examples/mnist.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package main

/*
Removed from test package, because execution time is too long.
A it was problem where was package publihed on AwesomeGO.
*/

import (
"fmt"
"math/rand"
Expand All @@ -15,14 +20,14 @@ import (
func ExampleMNIST() {
//read data
rand.Seed(1)
TREES := 10
TREES := 100
size := 60000
xsize := 28 * 28
labels, err := GoMNIST.ReadLabelFile("train-labels-idx1-ubyte.gz")
labels, err := GoMNIST.ReadLabelFile("examples/train-labels-idx1-ubyte.gz")
if err != nil {
panic(err)
}
_, _, imgs, err := GoMNIST.ReadImageFile("train-images-idx3-ubyte.gz")
_, _, imgs, err := GoMNIST.ReadImageFile("examples/train-images-idx3-ubyte.gz")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -51,11 +56,11 @@ func ExampleMNIST() {

//read test data
tsize := 10000
tlabels, err := GoMNIST.ReadLabelFile("t10k-labels-idx1-ubyte.gz")
tlabels, err := GoMNIST.ReadLabelFile("examples/t10k-labels-idx1-ubyte.gz")
if err != nil {
panic(err)
}
_, _, timgs, err := GoMNIST.ReadImageFile("t10k-images-idx3-ubyte.gz")
_, _, timgs, err := GoMNIST.ReadImageFile("examples/t10k-images-idx3-ubyte.gz")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -89,6 +94,10 @@ func ExampleMNIST() {
//writeImage(timgs[i], fmt.Sprintf("img%06d_%d_%d", i, tlabels[i], bestLabel))
}
}
fmt.Printf("Trees: %d Results: %5.0f%%\n", TREES, 100.0*float64(p)/float64(tsize))
//Output: Trees: 10 Results: 95%
fmt.Printf("Trees: %d Results: %5.1f%%\n", TREES, 100.0*float64(p)/float64(tsize))
//Output: Trees: 10 Results: 96.0%
}

func main() {
ExampleMNIST()
}
Loading

0 comments on commit 3fec4d9

Please sign in to comment.