-
Notifications
You must be signed in to change notification settings - Fork 3
/
class_saver.py
67 lines (47 loc) · 1.96 KB
/
class_saver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pandas as pd
import numpy as np
from sklearn.metrics import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network.multilayer_perceptron import MLPClassifier
from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
import pickle
import os, sys, re, time
proj_path = "/home/jwales/eclipse-workspace/crowdnews/"
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "crowdnews.settings")
sys.path.append(proj_path)
os.chdir(proj_path)
from django.core.wsgi import get_wsgi_application
application = get_wsgi_application()
from django.contrib.gis.views import feed
from newsbot.models import *
from newsbot.util import *
print("Setting up..")
cDict = loadCanonDict()
qs_Examples = ArticleExample.objects.filter(quality_class__lt = 5)
print("Processing examples")
(Y_vector, examplesMatrix) = processExamples(qs_Examples, cDict)
X_train, X_test, y_train, y_test = train_test_split(examplesMatrix, Y_vector, test_size=0.2)
chosen_models = {}
chosen_models['newsbot/MLPC_model.sav'] = MLPClassifier(hidden_layer_sizes=(128,64,32,16,8), max_iter=2500)
chosen_models['newsbot/svc_model.sav'] = SVC(gamma='scale', probability = True)
chosen_models['newsbot/log_model.sav'] = LogisticRegression()
for fname, model in chosen_models.items():
print("Working on " + fname)
model.fit(X_train, y_train)
predictions = model.predict(X_test)
print("Classification report: ")
print(classification_report(predictions, y_test))
print("***************")
dosave = input("Save " + fname + "? ")
if(dosave == 'y' or dosave == 'Y'):
print("Saving...")
pickle.dump(model, open(fname, 'wb'))
print("Saved!")
else:
print("Not saved!")