-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample.py
118 lines (93 loc) · 3.8 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from email.mime.text import MIMEText
from email.utils import formataddr
import smtplib
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import numpy as np
def retry_decorator(retries=3):
    """Build a decorator that retries the wrapped call when it raises.

    The wrapped function is called up to ``retries`` times. Each failure is
    reported on stdout. After the final failed attempt, the exception is
    re-raised unchanged. On success, the wrapped function's return value is
    passed back to the caller.

    Args:
        retries: Maximum number of attempts; must be at least 1.

    Returns:
        A decorator that applies the retry policy to a function.

    Raises:
        ValueError: If ``retries`` is less than 1, since the wrapped function
            would otherwise never be called.
    """
    import functools

    if retries < 1:
        raise ValueError(f"retries must be >= 1, got {retries}")

    def decorator(func):
        @functools.wraps(func)  # preserve the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            for attempt in range(retries):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"Failed to send notification. Error: {e}")
                    if attempt < retries - 1:  # attempts remain
                        print("Retrying...")
                    else:  # last attempt: propagate the original error
                        raise
        return wrapper
    return decorator
class NotificationCallback():
    """Training callback that prints per-epoch metrics and e-mails a final summary.

    The summary is sent over an SSL SMTP connection (``smtplib.SMTP_SSL``)
    from ``sender_email`` to ``receiver_email`` when training ends.

    Args:
        sender_email: Address used both as the SMTP login and the From address.
        sender_auth: Password / authorization code passed to ``SMTP.login``.
        receiver_email: Recipient address.
        server: SMTP host name.
        port: Port for the SSL SMTP connection.
    """

    def __init__(self, sender_email, sender_auth, receiver_email, server="smtp.qq.com", port=465):
        self.sender_email = sender_email
        self.sender_auth = sender_auth
        self.receiver_email = receiver_email
        self.server = server
        self.port = port
        self.epoch = 0  # index of the next epoch to be reported

    def _format_metrics(self, metrics):
        """Render ``metrics`` as one line; items 0 and 1 are validation loss and accuracy."""
        val_loss, accuracy = metrics[0], metrics[1]
        return f"epoch: {self.epoch} val loss: {val_loss:.7f} val acc: {accuracy:.7f}"

    def on_train_begin(self):
        """Reset the epoch counter at the start of a training run."""
        self.epoch = 0

    def on_epoch_end(self, metrics):
        """Print the finished epoch's metrics and advance the epoch counter.

        Per-epoch metrics are only printed; e-mail is sent once, from
        ``on_train_end``.
        """
        print(self._format_metrics(metrics))
        self.epoch += 1

    def on_train_end(self, metrics):
        """E-mail the final metrics.

        At this point ``self.epoch`` equals the number of epochs reported
        through ``on_epoch_end``.
        """
        self.send_notification(self._format_metrics(metrics))

    @retry_decorator()
    def send_notification(self, msg, subject="Training Update", from_email="ML Training", to_email="User"):
        """Send ``msg`` as a plain-text UTF-8 e-mail, retrying on failure.

        Args:
            msg: Message body text.
            subject: Subject line.
            from_email: Display name for the From header.
            to_email: Display name for the To header.
        """
        mime = MIMEText(msg, 'plain', 'utf-8')
        mime['From'] = formataddr((from_email, self.sender_email))
        mime['To'] = formataddr((to_email, self.receiver_email))
        mime['Subject'] = subject
        # The context manager sends QUIT and closes the socket even when
        # login/sendmail raises, so each retry doesn't leak a connection.
        with smtplib.SMTP_SSL(self.server, self.port) as server:
            server.login(self.sender_email, self.sender_auth)
            server.sendmail(self.sender_email, [self.receiver_email], mime.as_string())
class ModelTraining:
    """Toy training driver that reports progress to registered callbacks.

    Callbacks are duck-typed. Each must provide ``on_train_begin()``,
    ``on_epoch_end(metrics)`` and ``on_train_end(metrics)``, where ``metrics``
    is a ``(loss, score)`` tuple.
    """

    def __init__(self):
        self.callbacks = []

    def add_callback(self, callback):
        """Register ``callback``; callbacks are invoked in registration order."""
        self.callbacks.append(callback)

    def on_train_begin(self):
        """Notify every callback that training is starting."""
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_epoch_end(self, metrics):
        """Forward one epoch's ``metrics`` to every callback."""
        for callback in self.callbacks:
            callback.on_epoch_end(metrics)

    def on_train_end(self, metrics):
        """Forward the final ``metrics`` to every callback."""
        for callback in self.callbacks:
            callback.on_train_end(metrics)

    def train_model(self, epochs=100):
        """Fit logistic regression on Iris and report ``epochs`` simulated epochs.

        ``LogisticRegression.fit`` retrains from scratch on every call, and the
        data never changes. The model is therefore fitted and scored once. Each
        simulated epoch reports that held-out accuracy together with a random
        placeholder loss.

        Args:
            epochs: Number of simulated epochs to report; must be >= 1.

        Raises:
            ValueError: If ``epochs`` < 1, since there would be no final
                metrics to pass to ``on_train_end``.
        """
        if epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {epochs}")
        iris = load_iris()
        X, y = iris.data, iris.target
        # Stratified 50/50 split: fit on one half, evaluate on the other.
        X1, X2, y1, y2 = train_test_split(X, y, random_state=42, train_size=0.5, stratify=y)
        model = LogisticRegression(max_iter=200)
        self.on_train_begin()
        model.fit(X1, y1)
        score = model.score(X2, y2)
        for _ in range(epochs):
            loss = np.random.rand()  # simulated loss; no real per-epoch loss exists
            self.on_epoch_end((loss, score))
        self.on_train_end((loss, score))
if __name__ == '__main__':
    import os

    # Read credentials from the environment so secrets need not be committed
    # to source. The fallback literals are placeholders only; they appear to
    # be mangled addresses and must be replaced before e-mail will work.
    sender_email = os.environ.get("NOTIFY_SENDER_EMAIL", "[email protected]")
    sender_auth = os.environ.get("NOTIFY_SENDER_AUTH", "sender_auth")
    receiver_email = os.environ.get("NOTIFY_RECEIVER_EMAIL", "[email protected]")
    training = ModelTraining()
    training.add_callback(NotificationCallback(sender_email, sender_auth, receiver_email))
    training.train_model()