-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathhooks.py
92 lines (65 loc) · 2.52 KB
/
hooks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import numpy as np
import time
class EarlyStopping(session_run_hook.SessionRunHook):
    """Request a stop when a metric exceeds a target or a time budget runs out.

    Every ``check_every``-th ``session.run`` call the hook fetches ``metric`` and
    requests a stop if the fetched value is strictly greater than ``target``.
    The wall-clock budget is checked after *every* step, so training does not
    overshoot ``max_secs`` while waiting for the next metric-check step.

    Args:
        metric: Fetchable whose evaluated result is indexed with ``[1]``;
            presumably a ``tf.metrics`` ``(value, update_op)`` pair — TODO
            confirm against callers.
        start_time: Reference timestamp in seconds, compared against
            ``time.time()``.
        target: Threshold; a stop is requested when the reading is ``> target``.
        check_every: Fetch and compare the metric once every this many steps.
        max_secs: Wall-clock budget in seconds, measured from ``start_time``.
    """

    def __init__(self, metric, start_time, target=0.97, check_every=100, max_secs=10):
        self.metric = metric
        self.target = target
        self.counter = 0
        self.check_every = check_every
        self.max_secs = max_secs
        self.start_time = start_time
        # Recomputed in before_run(); initialized here so the attribute always exists.
        self.should_check = False

    def before_run(self, run_context):
        self.counter += 1
        self.should_check = (self.counter % self.check_every) == 0
        if self.should_check:
            return session_run_hook.SessionRunArgs([self.metric])
        # Nothing extra to fetch on non-check steps.
        return None

    def after_run(self, run_context, run_values):
        if self.should_check and run_values.results is not None:
            # results mirrors the list passed to SessionRunArgs: [0] is the
            # evaluated self.metric, and [1] selects its second element.
            t = run_values.results[0][1]
            if t > self.target:
                tf.logging.info(f"Early stopping as exceeded target {t} > {self.target}")
                run_context.request_stop()

        # Enforce the time budget on every step, not only on metric-check steps.
        # Compute the elapsed time once so the logged value is the compared one.
        elapsed = time.time() - self.start_time
        if elapsed > self.max_secs:
            tf.logging.info(f"EarlyStopping as time run out {elapsed} > {self.max_secs}")
            run_context.request_stop()
class CallbackHook(session_run_hook.SessionRunHook):
    """Forward session-run events to user-supplied callbacks.

    Args:
        metrics: Optional fetches requested on every ``session.run`` call; when
            ``None`` nothing extra is fetched.
        callback_after: Optional callable invoked as
            ``callback_after(run_context, run_values)`` after each run.
        callback_end: Optional callable invoked as ``callback_end(session)``
            when the session ends.
    """

    def __init__(self, metrics=None, callback_after=None, callback_end=None):
        self.metrics = metrics
        self.callback_after = callback_after
        self.callback_end = callback_end

    def before_run(self, run_context):
        # Guard clause: no extra fetches requested.
        if self.metrics is None:
            return None
        return session_run_hook.SessionRunArgs(self.metrics)

    def after_run(self, run_context, run_values):
        handler = self.callback_after
        if handler is None:
            return
        handler(run_context, run_values)

    def end(self, session):
        handler = self.callback_end
        if handler is None:
            return
        handler(session)
class LastMetricHook(session_run_hook.SessionRunHook):
    """Report the most recent reading of ``metric`` to ``cb`` when a session ends.

    Args:
        metric: Fetchable whose evaluated result is indexed with ``[1]``;
            presumably a ``tf.metrics`` ``(value, update_op)`` pair — TODO
            confirm against callers.
        cb: Callable invoked from ``end()`` with the last reading of the current
            session, or ``None`` if no step's ``after_run`` completed in it.
    """

    def __init__(self, metric, cb):
        self.metric = metric
        self.cb = cb
        self.reading = None

    def begin(self):
        # Reset per session so a reused hook never reports a reading left over
        # from an earlier session (e.g. when the next session completes no step).
        self.reading = None

    def before_run(self, run_context):
        return session_run_hook.SessionRunArgs([self.metric])

    def after_run(self, run_context, run_values):
        # Overwrite on every step so only the latest value is kept.
        self.reading = run_values.results[0][1]

    def end(self, session):
        self.cb(self.reading)
class MetricHook(session_run_hook.SessionRunHook):
    """Collect ``metric`` on every step and report the mean to ``cb`` at session end.

    Args:
        metric: Fetchable whose evaluated result is indexed with ``[1]``;
            presumably a ``tf.metrics`` ``(value, update_op)`` pair — TODO
            confirm against callers.
        cb: Callable invoked from ``end()`` with the ``np.average`` of the
            readings collected in the current session, or ``np.nan`` if none
            were collected.
    """

    def __init__(self, metric, cb):
        self.metric = metric
        self.cb = cb
        self.readings = []

    def begin(self):
        # Start every session with an empty buffer. end() (which also clears it)
        # is skipped by TF when session.run() raises, so readings from an
        # aborted session would otherwise be averaged into the next one.
        self.readings.clear()

    def before_run(self, run_context):
        return session_run_hook.SessionRunArgs([self.metric])

    def after_run(self, run_context, run_values):
        self.readings.append(run_values.results[0][1])

    def end(self, session):
        # np.average([]) warns ("Mean of empty slice") and yields nan; make the
        # no-readings case explicit instead.
        mean = np.average(self.readings) if self.readings else np.nan
        self.cb(mean)
        self.readings.clear()