-
Notifications
You must be signed in to change notification settings - Fork 133
/
Copy pathtfidf_test.go
151 lines (117 loc) · 3.89 KB
/
tfidf_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package text
import (
"fmt"
"os"
"sort"
"testing"
"github.com/cdipaolo/goml/base"
"github.com/stretchr/testify/assert"
)
// init guarantees that the /tmp/.goml/ directory used for persistence
// testing exists before any test in this package runs. MkdirAll
// succeeds without doing anything when the directory is already there;
// any other failure panics, which aborts the test binary.
func init() {
	if err := os.MkdirAll("/tmp/.goml", os.ModePerm); err != nil {
		panic(err.Error())
	}
}
// TestExampleTFIDFShouldPass1 trains a NaiveBayes model on a tiny corpus,
// reinterprets it as a TFIDF model, and checks that:
//   - "I" (three occurrences in the query document) scores a higher TFIDF
//     than "lady" (one occurrence in the query document);
//   - MostImportantWords returns exactly the requested number of entries,
//     ordered so that sort.Reverse of the result is sorted (descending).
//
// Any error reported by the learner fails the test.
func TestExampleTFIDFShouldPass1(t *testing.T) {
	// Capacity 40 exceeds the number of datapoints sent below, so the
	// sends never block waiting on the learner goroutine.
	stream := make(chan base.TextDatapoint, 40)
	errs := make(chan error)

	// The model is created with 3 classes, so datapoint classes are
	// expected in {0,1,2} (in general, {0,...,n-1}). The datapoints below
	// only set X.
	model := NewNaiveBayes(stream, 3, base.OnlyWordsAndNumbers)
	go model.OnlineLearn(errs)

	for _, sentence := range []string{
		"I love the city",
		"I hate Los Angeles",
		"My mother is not a nice lady lady lady lady",
	} {
		stream <- base.TextDatapoint{X: sentence}
	}
	close(stream)

	// OnlineLearn closes errs once training is done. Report every error as
	// a test failure, but keep draining so the learner never blocks on send.
	for err := range errs {
		t.Errorf("error passed during training: %v", err)
	}

	// Reinterpret the trained NaiveBayes model as a TFIDF model.
	tf := TFIDF(*model)

	const doc = "I don't think my mother is not a nice lady and I know you're wrong and I can prove it!!!!"
	greater := tf.TFIDF("I", doc)
	lesser := tf.TFIDF("lady", doc)
	assert.True(t, greater > lesser, "TFIDF for 'I' (%v) should be greater than TFIDF for 'lady' (%v)", greater, lesser)

	freq := tf.MostImportantWords("I don't think my mother is not a nice lady and I know you're wrong!", 4)
	assert.Len(t, freq, 4, "Length of Frequencies (%v) should be 4", freq)
	assert.True(t, sort.IsSorted(sort.Reverse(freq)), "Frequencies (%v) should be sorted!", freq)

	fmt.Printf("Freq: %v\nI: %v\tlady: %v\n\n", freq, greater, lesser)
}
// TestAreaTFIDFShouldPass1 trains a NaiveBayes model on short sentences
// about India, China and Japan, reinterprets it as a TFIDF model, and
// checks that:
//   - "sushi" (twice in the query document) scores a higher TFIDF than
//     "buddhist" (once in the query document);
//   - MostImportantWords returns exactly the requested number of entries,
//     ordered so that sort.Reverse of the result is sorted (descending).
//
// Any error reported by the learner fails the test.
func TestAreaTFIDFShouldPass1(t *testing.T) {
	// Capacity 40 exceeds the number of datapoints sent below, so the
	// sends never block waiting on the learner goroutine.
	stream := make(chan base.TextDatapoint, 40)
	errs := make(chan error)

	// The model is created with 3 classes, so datapoint classes are
	// expected in {0,1,2} (in general, {0,...,n-1}). The datapoints below
	// only set X.
	model := NewNaiveBayes(stream, 3, base.OnlyWordsAndNumbers)
	go model.OnlineLearn(errs)

	for _, sentence := range []string{
		"Indian cities look alright",
		"New Delhi, a city in India, gets very hot",
		"Indian food is oftentimes based on vegetables",
		"China is a large country",
		"Chinese food tastes good",
		"Chinese, as a country, has a lot of people in it",
		"Japan makes sushi and cars",
		"Many Japanese people are Buddhist",
		"Japanese architecture looks nice",
	} {
		stream <- base.TextDatapoint{X: sentence}
	}
	close(stream)

	// OnlineLearn closes errs once training is done. Report every error as
	// a test failure, but keep draining so the learner never blocks on send.
	for err := range errs {
		t.Errorf("error passed during training: %v", err)
	}

	// Reinterpret the trained NaiveBayes model as a TFIDF model.
	tf := TFIDF(*model)

	const doc = "sushi is my favorite buddhist related food and sushi is fun"
	greater := tf.TFIDF("sushi", doc)
	lesser := tf.TFIDF("buddhist", doc)
	assert.True(t, greater > lesser, "TFIDF for 'buddhist' (%v) should be less than TFIDF for 'sushi' (%v)", lesser, greater)

	freq := tf.MostImportantWords("Sushi is really great and sushi is awesome and sushi sushi sushi sushi sushi sushi!!", 4)
	assert.Len(t, freq, 4, "Length of Frequencies (%v) should be 4", freq)
	assert.True(t, sort.IsSorted(sort.Reverse(freq)), "Frequencies (%v) should be sorted!", freq)

	fmt.Printf("Freq: %v\nSushi: %v\tBuddhist: %v\n\n", freq, greater, lesser)
}