-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
231 lines (192 loc) · 5.79 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
// Subayes is a Bayesian command-line classifier built around github.com/jbrukh/bayesian.
package main
// 2023/06 : cat : subayes : mail subject classification using bayesian filter
//
// v0.1 : working draft.
// v0.2 : minlen words, better split func, default bayes class, +main_test.go
// v0.3 : -E options for explaining and showing scores
// v1.0 tag for go doc
// v1.1 : ignore numbers
// v1.2 : lowerCase for items
// v1.3 : read learning data from stdin if the data filename is ""
//
// TODO :
// - how to remove item from db ?
import (
"bufio"
"errors"
"flag"
"fmt"
"io"
"os"
"regexp"
"strings"
// Credits to "github.com/jbrukh/bayesian"
"github.com/jbrukh/bayesian"
)
// Command-line settings shared across the package. main binds each of them
// to a flag before calling flag.Parse.
var (
	// db is the path of the classes data store (-db).
	db string
	// data is the file read when learning (-d); stdin is used when empty.
	data string

	// learnSpam selects learning mode for the Spam class (-learnSpam).
	learnSpam bool
	// learnHam selects learning mode for the Ham class (-learnHam).
	learnHam bool

	// explain prints per-word class scores on stderr (-E).
	explain bool
	// verbose enables informational and warning messages on stderr (-v).
	verbose bool

	// lowerCase lower-cases items before they are de-duplicated (-l).
	lowerCase = true
)
// main parses the command-line flags and runs in one of two modes:
//
//   - learning (-learnHam or -learnSpam, mutually exclusive): the input data
//     is ingested into the selected class and the class is saved under db;
//   - classification (default): both classes are loaded from db, then stdin
//     is read line by line and each line is printed prefixed by its class.
//
// Any error terminates the program through errcheck.
func main() {
	const (
		spam bayesian.Class = "Spam"
		ham  bayesian.Class = "Ham"
	)
	var minlength int // minimal word length, set by -m (default 4)

	// db is the classes data store path.
	flag.StringVar(&db, "db", "db", " db path")
	// data is the file to be read when learning.
	flag.StringVar(&data, "d", "", "read input from data filename (stdin if empty)")
	flag.IntVar(&minlength, "m", 4, "word min length")
	flag.BoolVar(&lowerCase, "l", lowerCase, "lowerCase items")
	// Choose between learning Spam or Ham (writes the db/class files).
	flag.BoolVar(&learnSpam, "learnSpam", false, "learn Spam subjects")
	flag.BoolVar(&learnHam, "learnHam", false, "learn Ham subjects")
	flag.BoolVar(&explain, "E", false, "explain words scores")
	flag.BoolVar(&verbose, "v", false, "verbose")
	flag.Parse()

	k := bayesian.NewClassifier(ham, spam)
	switch {
	case learnHam && learnSpam:
		errcheck(errors.New("please choose learn Ham or Spam, not Both"))
	case learnHam:
		errcheck(learn(k, db, data, ham, minlength))
		showClassesCount(k)
	case learnSpam:
		errcheck(learn(k, db, data, spam, minlength))
		showClassesCount(k)
	default:
		// Classification mode: both classes must be readable.
		errcheck(k.ReadClassFromFile(spam, db))
		errcheck(k.ReadClassFromFile(ham, db))
		// NOTE(review): the original author questioned whether the TF-IDF
		// conversion is needed here; confirm against the bayesian package docs.
		k.ConvertTermsFreqToTfIdf()
		showClassesCount(k)

		scanner := bufio.NewScanner(os.Stdin)
		for scanner.Scan() { // one subject per line
			text := scanner.Text()
			// Lines whose byte length does not exceed minlength (e.g. a bare
			// "Re: ") are skipped: nothing is printed on stdout for them.
			if len(text) <= minlength {
				if verbose {
					// The trailing newline keeps successive warnings apart.
					fmt.Fprintf(os.Stderr, "Warning short string : \"%s\"\n", text)
				}
				continue
			}
			words := removeDuplicate(split(text), minlength)
			// ham is the default class used when no word survives filtering.
			fmt.Printf("%v: %s\n", classify(k, words, ham), text)
		}
		errcheck(scanner.Err())
	}
}
// learn ingests a data source into one bayesian class and saves that class
// to the classifier db.
//
// The class is first loaded from xdb so that new data is added to what was
// learned before; a class file that does not exist yet is not an error, it
// will be created on save. The data is read from the file named by input, or
// from stdin when input is empty, then tokenised with split, filtered with
// removeDuplicate(minilength) and learned as a single document of the class.
//
// Errors are returned to the caller (wrapped with context) instead of
// terminating the process from inside this helper.
func learn(c *bayesian.Classifier, xdb string, input string, class bayesian.Class, minilength int) (err error) {
	// Only "file does not exist" is acceptable here: any other failure
	// (permissions, unreadable data, ...) must not be silently overwritten.
	if err = c.ReadClassFromFile(class, xdb); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("reading class %q from %q: %w", class, xdb, err)
	}
	showClassesCount(c)

	var in []byte
	if input != "" {
		in, err = os.ReadFile(input)
	} else {
		in, err = io.ReadAll(os.Stdin)
	}
	if err != nil {
		return fmt.Errorf("reading learning data: %w", err)
	}

	words := removeDuplicate(split(string(in)), minilength)
	c.Learn(words, class)

	if err = c.WriteClassToFile(class, xdb); err != nil {
		return fmt.Errorf("writing class %q to %q: %w", class, xdb, err)
	}
	return nil
}
// classify returns the class the classifier c finds most likely for the
// words in pattern. An empty pattern yields the default class d (with a
// warning on stderr in verbose mode). When the explain option is set, the
// score of every single word for every class is printed on stderr first.
func classify(c *bayesian.Classifier, pattern []string, d bayesian.Class) bayesian.Class {
	// Nothing to score: fall back to the caller's default class.
	if len(pattern) == 0 {
		if verbose {
			fmt.Fprint(os.Stderr, "Warning Empty pattern\n")
		}
		return d
	}

	if explain {
		// Score each word on its own, as a one-word document.
		for _, w := range pattern {
			perClass, best, _ := c.ProbScores([]string{w})
			fmt.Fprintf(os.Stderr, "[ %s = %s ] : ", w, c.Classes[best])
			for idx, cls := range c.Classes {
				fmt.Fprintf(os.Stderr, "[%v]{ %.4f } ", cls, perClass[idx])
			}
			fmt.Fprintln(os.Stderr)
		}
	}

	// ProbScores returns the scores, the index of the most likely class and
	// a "strict" flag; the flag is deliberately not used to fall back to d.
	_, winner, _ := c.ProbScores(pattern)
	return c.Classes[winner]
}
// showClassesCount prints, on stderr and only in verbose mode, one entry per
// class of c that has a non-zero word count.
func showClassesCount(c *bayesian.Classifier) {
	if !verbose {
		return
	}

	fmt.Fprint(os.Stderr, "INFO classifier corpus : ")
	counts := c.WordCount() // indexed like c.Classes
	for idx, class := range c.Classes {
		if n := counts[idx]; n > 0 {
			fmt.Fprintf(os.Stderr, " [ %v -> %d items ]", class, n)
		}
	}
	fmt.Fprintln(os.Stderr)
}
// errcheck performs a basic error check: a nil error is a no-op, any other
// error is printed on its own line on stderr and the process exits with
// status -1.
func errcheck(e error) {
	if e == nil {
		return
	}
	// Fprintln terminates the message with a newline so that it does not
	// run into whatever is printed next on the terminal.
	fmt.Fprintln(os.Stderr, e)
	os.Exit(-1)
}
// wordPattern matches a run of one or more Unicode letters (\p{L}).
// See http://www.unicode.org/reports/tr44/#General_Category_Values
// It is compiled once at package level rather than on every split call.
var wordPattern = regexp.MustCompile(`[\p{L}]+`)

// split returns the words found in s, in order of appearance. A word is a
// maximal run of Unicode letters, so digits, punctuation and whitespace all
// act as separators and never appear in the result. The result is nil when
// s contains no letter.
func split(s string) []string {
	return wordPattern.FindAllString(s, -1)
}
// numericPattern matches items made only of ASCII digits and dots.
// It is compiled once at package level rather than on every call.
var numericPattern = regexp.MustCompile(`^[0-9\.]+$`)

// removeDuplicate filters sliceList and removes duplicate entries, keeping
// the order of first appearance.
//
// An item is dropped when it is made only of digits and dots, or when it has
// fewer than length characters (items of exactly length characters are kept).
// When the package-level lowerCase flag is set, the remaining items are
// lower-cased before the duplicate check, so "Hello" and "hello" count as
// the same entry. The returned slice is never nil.
func removeDuplicate(sliceList []string, length int) []string {
	seen := make(map[string]struct{})
	list := []string{}
	for _, item := range sliceList {
		if numericPattern.MatchString(item) {
			continue
		}
		// Count runes, not bytes: with a byte count an accented word such as
		// "été" (3 letters, 5 bytes) would wrongly pass a minimum of 4.
		if len([]rune(item)) < length {
			continue
		}
		if lowerCase {
			item = strings.ToLower(item)
		}
		if _, dup := seen[item]; dup {
			continue
		}
		seen[item] = struct{}{}
		list = append(list, item)
	}
	return list
}