-
Notifications
You must be signed in to change notification settings - Fork 20
/
plot.py
51 lines (35 loc) · 1.97 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
'''
Graph plotting functions.
'''
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# Module-level figure, created at import time. plot_loss_acc never creates a
# figure of its own: it clears and draws on pyplot's *current* figure (via
# plt.clf() / sns.lineplot). This 20x5-inch size therefore applies to the saved
# plots only as long as no other figure has since been made current.
fig = plt.figure(figsize=(20, 5))
def plot_loss_acc(path, num_epoch, train_accuracies_superclass, train_accuracies_subclass, train_losses,
                  test_accuracies_superclass, test_accuracies_subclass, test_losses):
    '''
    Plot line graphs of accuracy and loss at every epoch for both training and testing.

    Three PNG files are written: ``accuracy_superclass_epoch.png``,
    ``accuracy_subclass_epoch.png`` and ``loss_epoch.png``. Each file name is
    appended to ``path`` by plain string concatenation, so ``path`` should end
    with a directory separator (or be an intended filename prefix).

    Parameters
    ----------
    path : str
        Output directory (with trailing separator) or filename prefix.
    num_epoch : int
        Last epoch index. Every per-epoch sequence must hold ``num_epoch + 1``
        values, one for each epoch ``0 .. num_epoch``.
    train_accuracies_superclass, train_accuracies_subclass, train_losses : sequence
        Per-epoch training metrics.
    test_accuracies_superclass, test_accuracies_subclass, test_losses : sequence
        Per-epoch test metrics.

    Raises
    ------
    ValueError
        Raised by pandas when a metric sequence's length differs from ``num_epoch + 1``.

    Notes
    -----
    Drawing happens on pyplot's *current* figure, which is cleared before each
    plot. On return, that figure still holds the loss plot.
    '''
    epochs = list(range(num_epoch + 1))

    data_superclass = _train_test_frame(epochs, train_accuracies_superclass,
                                        test_accuracies_superclass, 'Accuracy')
    data_subclass = _train_test_frame(epochs, train_accuracies_subclass,
                                      test_accuracies_subclass, 'Accuracy')
    data_loss = _train_test_frame(epochs, train_losses, test_losses, 'Loss')

    _save_lineplot(data_superclass, 'Accuracy', 'Superclass Accuracy Graph',
                   path + 'accuracy_superclass_epoch.png')
    _save_lineplot(data_subclass, 'Accuracy', 'Subclass Accuracy Graph',
                   path + 'accuracy_subclass_epoch.png')
    _save_lineplot(data_loss, 'Loss', 'Loss Graph', path + 'loss_epoch.png')


def _train_test_frame(epochs, train_values, test_values, value_name):
    '''
    Stack per-epoch train and test values into one long-form DataFrame.

    Columns are ``Epochs``, ``value_name`` and ``Mode`` (``'train'`` or
    ``'test'``). This is the layout ``sns.lineplot(..., hue='Mode')`` consumes.
    '''
    frames = [
        pd.DataFrame({"Epochs": epochs, value_name: values, "Mode": [mode] * len(epochs)})
        for mode, values in (('train', train_values), ('test', test_values))
    ]
    # ignore_index gives the combined frame unique row labels instead of 0..n twice.
    return pd.concat(frames, ignore_index=True)


def _save_lineplot(data, y, title, filename):
    '''Clear the current figure, draw a train/test line plot of ``y`` per epoch, and save it.'''
    plt.clf()
    sns.lineplot(data=data, x='Epochs', y=y, hue='Mode')
    plt.title(title)
    plt.savefig(filename)