-
Notifications
You must be signed in to change notification settings - Fork 2
/
svm.py
37 lines (28 loc) · 1.04 KB
/
svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# from: https://www.geeksforgeeks.org/multiclass-classification-using-scikit-learn/
# importing necessary libraries
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
def load_split(random_state=0):
    """Load the iris dataset and divide it into train and test partitions.

    Args:
        random_state: Seed passed to ``train_test_split`` so the shuffle (and
            therefore the split) is reproducible across runs.

    Returns:
        The 4-tuple ``(X_train, X_test, y_train, y_test)`` produced by
        ``train_test_split``, using its default split proportions.
    """
    iris = datasets.load_iris()
    # X -> features, y -> label
    return train_test_split(iris.data, iris.target, random_state=random_state)


def train_and_evaluate(X_train, X_test, y_train, y_test, kernel="linear", C=1):
    """Fit an SVM classifier on the training data and score it on the test data.

    Args:
        X_train, X_test: Feature matrices for training and evaluation.
        y_train, y_test: Label vectors matching ``X_train`` / ``X_test``.
        kernel: Kernel name forwarded to ``sklearn.svm.SVC``.
        C: Regularization parameter forwarded to ``sklearn.svm.SVC``.

    Returns:
        ``(accuracy, cm)`` where ``accuracy`` is the fraction of test samples
        predicted correctly and ``cm`` is the confusion matrix (rows are true
        labels, columns are predicted labels).
    """
    model = SVC(kernel=kernel, C=C).fit(X_train, y_train)
    predictions = model.predict(X_test)
    # Reuse the predictions computed above; model.score() would run predict()
    # a second time to arrive at the same mean accuracy.
    accuracy = accuracy_score(y_test, predictions)
    cm = confusion_matrix(y_test, predictions)
    return accuracy, cm


def main():
    """Run the linear-SVM iris demo and print split sizes, accuracy and confusion matrix."""
    X_train, X_test, y_train, y_test = load_split()
    print(y_train.shape)
    print(y_test.shape)
    accuracy, cm = train_and_evaluate(X_train, X_test, y_train, y_test)
    print(accuracy)
    print(cm)


if __name__ == "__main__":
    main()