-
Notifications
You must be signed in to change notification settings - Fork 5
/
metrics.py
74 lines (64 loc) · 2.69 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding: utf-8 -*-
from collections import defaultdict
import numpy as np
from keras.callbacks import Callback
from sklearn.metrics import f1_score, accuracy_score
from boltons.iterutils import chunked_iter
from keras.preprocessing.sequence import pad_sequences
class Metrics(Callback):
    """Keras callback that scores the model on ``eval_data`` after every epoch.

    At each epoch end the model is run over ``eval_data`` in chunks of
    ``batch_size``. Accuracy and macro-averaged F1 are printed and appended to
    ``self.history``. Whenever macro F1 exceeds the best value seen so far by
    more than ``min_delta``, the weights are saved to ``save_path``. After
    ``patience`` consecutive epochs without such an improvement,
    ``model.stop_training`` is set.

    Args:
        batch_size: Number of evaluation examples per ``predict`` call.
        max_len: Length that ``token_ids`` and ``segment_ids`` are padded /
            truncated to (both at the end of the sequence).
        eval_data: Iterable of dicts; each must provide the keys
            ``'token_ids'``, ``'segment_ids'``, ``'tcol_ids'`` and
            ``'label_id'``.
        save_path: Path passed to ``model.save_weights`` for the best model.
        min_delta: Minimum increase in macro F1 that counts as an improvement.
        patience: Epochs without improvement tolerated before stopping.
    """

    def __init__(self,
                 batch_size,
                 max_len,
                 eval_data,
                 save_path,
                 min_delta=1e-4,
                 patience=10):
        # Let the Keras base class set up its own state (e.g. the model
        # reference that ``set_model`` later fills in) before adding ours.
        super().__init__()
        self.patience = patience
        self.min_delta = min_delta
        self.monitor_op = np.greater
        self.save_path = save_path
        self.batch_size = batch_size
        self.max_len = max_len
        self.eval_data = eval_data
        # Per-epoch metric values under the keys 'val_acc' and 'val_f1'.
        self.history = defaultdict(list)

    def on_train_begin(self, logs=None):
        """Reset the early-stopping bookkeeping at the start of training."""
        self.step = 0
        self.wait = 0
        self.stopped_epoch = 0
        # NOTE(review): not read anywhere in this class; kept for any
        # external code that may inspect it.
        self.warmup_epochs = 2
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works everywhere.
        self.best = -np.inf

    def calc_metrics(self):
        """Predict on ``eval_data`` and return ``(macro_f1, accuracy)``."""
        y_true, y_pred = [], []
        for chunk in chunked_iter(self.eval_data, self.batch_size):
            token_ids = pad_sequences(
                [obj['token_ids'] for obj in chunk],
                maxlen=self.max_len, padding='post', truncating='post')
            segment_ids = pad_sequences(
                [obj['segment_ids'] for obj in chunk],
                maxlen=self.max_len, padding='post', truncating='post')
            # NOTE(review): unlike the two inputs above, tcol_ids are padded
            # with pad_sequences' defaults (no maxlen, 'pre' padding and
            # truncation). Confirm this matches how the training inputs
            # are built.
            tcol_ids = pad_sequences([obj['tcol_ids'] for obj in chunk])
            # verbose=0: avoid a progress bar for every evaluation chunk.
            probs = self.model.predict(
                [token_ids, segment_ids, tcol_ids], verbose=0)
            y_pred.extend(np.argmax(probs, 1))
            y_true.extend(obj['label_id'] for obj in chunk)
        acc = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average="macro")
        return f1, acc

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate, save weights on improvement, and apply early stopping."""
        val_f1, val_acc = self.calc_metrics()
        self.history['val_acc'].append(val_acc)
        self.history['val_f1'].append(val_f1)
        print(f"- val_acc {val_acc} - val_f1 {val_f1}")
        # Only an F1 gain larger than min_delta counts as an improvement. A
        # collapsed (near-zero) F1 must not overwrite the saved best weights
        # or reset the patience counter.
        if self.monitor_op(val_f1 - self.min_delta, self.best):
            self.best = val_f1
            self.wait = 0
            print(f'new best model, save model to {self.save_path}...')
            self.model.save_weights(self.save_path)
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        """Report the (1-based) epoch at which early stopping triggered."""
        if self.stopped_epoch > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))