import functools
import itertools
import operator
import os
import signal

import tensorflow as tf


def create_reset_metric(metric, scope, **metric_args):
    """Creates ops to handle streaming metrics.

    This is a wrapper function to create a streaming metric (usually
    tf.contrib.metrics.streaming_*) together with a reset operation.

    Args:
      metric: The metric function.
      scope: The variable scope name (should be unique, as the variables of
        this scope will be reset every time the reset op is evaluated).
      metric_args: The arguments to be passed on to the metric.

    Returns:
      Three ops: the metric read-out op, the update op and the reset op:
        metric_op, update_op, reset_op
    """
    with tf.variable_scope(scope) as scope:
        metric_op, update_op = metric(**metric_args)
        vars = tf.contrib.framework.get_variables(
            scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        reset_op = tf.variables_initializer(vars)
    return metric_op, update_op, reset_op
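
# Illustrative usage sketch (not part of the original helpers); assumes
# `labels` and `predictions` tensors already exist in the graph:
#
#     acc, acc_update, acc_reset = create_reset_metric(
#         tf.contrib.metrics.streaming_accuracy, 'stream_acc',
#         labels=labels, predictions=predictions)
#
# Run acc_update on every evaluation batch, read acc at the end of the
# pass, then run acc_reset before starting the next evaluation pass.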


def make_template(scope=None, create_scope_now_=False, unique_name_=None,
                  custom_getter_=None, **kwargs):
    """A decorator to wrap a function as a TensorFlow template via tf.make_template.

    This enables variable sharing between multiple instances of that function.

    Args:
      scope: The scope for this template. Defaults to the function name.
      create_scope_now_: Passed to the tf.make_template function.
      unique_name_: Passed to the tf.make_template function.
      custom_getter_: Passed to the tf.make_template function.
      kwargs: Passed to the tf.make_template function.

    Returns:
      The function wrapped inside a tf.make_template.
    """
    def make_tf_template(function):
        template = tf.make_template(function.__name__
                                    if scope is None or callable(scope)
                                    else scope,
                                    function,
                                    create_scope_now_=create_scope_now_,
                                    unique_name_=unique_name_,
                                    custom_getter_=custom_getter_,
                                    **kwargs)

        @functools.wraps(function)
        def wrapper(*caller_args, **caller_kwargs):
            return template(*caller_args, **caller_kwargs)
        return wrapper

    # Supports both bare @make_template and @make_template(scope=..., ...).
    if callable(scope):
        return make_tf_template(scope)
    return make_tf_template
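
# Illustrative sketch (assumed example, not from the original file): both
# calls below share the dense layer's variables because the decorated
# function lives inside a single tf.make_template; `batch_a` and `batch_b`
# are assumed input tensors.
#
#     @make_template
#     def shared_encoder(x):
#         return tf.layers.dense(x, 128, name='enc')
#
#     h1 = shared_encoder(batch_a)   # creates the variables
#     h2 = shared_encoder(batch_b)   # reuses the same variables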


def name_scope(scope, *scopeargs, **scopekwargs):
    """A decorator to wrap a function into a tf.name_scope.

    Args:
      scope: The scope name.

    Returns:
      The wrapped function.
    """
    def add_scope(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            with tf.name_scope(scope, *scopeargs, **scopekwargs):
                return function(*args, **kwargs)
        return wrapper
    return add_scope
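
# Illustrative sketch (assumed example): group the ops a function creates
# under one name scope in the TensorBoard graph; `logits` and `labels` are
# assumed tensors.
#
#     @name_scope('losses')
#     def compute_loss(logits, labels):
#         return tf.reduce_mean(
#             tf.nn.sparse_softmax_cross_entropy_with_logits(
#                 labels=labels, logits=logits))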


def variable_scope(scope, *scopeargs, **scopekwargs):
    """A decorator to wrap a function into a tf.variable_scope.

    Args:
      scope: The scope name.

    Returns:
      The wrapped function.
    """
    def add_scope(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            with tf.variable_scope(scope, *scopeargs, **scopekwargs):
                return function(*args, **kwargs)
        return wrapper
    return add_scope
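
# Illustrative sketch (assumed example): create a function's variables under
# a fixed variable scope, here with automatic reuse across repeated calls.
#
#     @variable_scope('discriminator', reuse=tf.AUTO_REUSE)
#     def discriminator(x):
#         return tf.layers.dense(x, 1)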


def with_device(device):
    """A decorator to specify a device for a function.

    Args:
      device: The device name.

    Returns:
      The wrapped function.
    """
    def set_device(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            with tf.device(device):
                return function(*args, **kwargs)
        return wrapper
    return set_device
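
# Illustrative sketch (assumed example): pin all ops created by a function
# to one device.
#
#     @with_device('/gpu:0')
#     def build_model(inputs):
#         return tf.layers.dense(inputs, 10)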


def estimate_size_of(graphkey):
    """Estimates the size of all tensors in a collection.

    Assumes 4 bytes (float32) per element.

    Args:
      graphkey: The GraphKey key.

    Returns:
      The estimated size in MB.
    """
    return sum([functools.reduce(operator.mul,
                                 [int(s) for s in v.shape], 1)
                for v in tf.get_collection(graphkey)]) * 4 / 1024 / 1024
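
# Illustrative sketch (assumed example): report the estimated size of all
# trainable variables. The 4-byte factor assumes float32 tensors with fully
# defined shapes.
#
#     print('trainable variables: %.2f MB'
#           % estimate_size_of(tf.GraphKeys.TRAINABLE_VARIABLES))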


def create_summary_hook(graphkey, ckptdir, steps=150):
    """Creates a SummarySaverHook with scalar summaries for every tensor
    inside the collection identified by graphkey.

    Args:
      graphkey: The key of the collection whose tensors should be summarized.
      ckptdir: The checkpoint directory.
      steps: The summaries will be stored every N steps.

    Returns:
      A SummarySaverHook which saves the requested summaries.
    """
    tensors = tf.get_collection(graphkey)
    summaries = []
    for tensor in tensors:
        # Use the first two scope components (without the ':0' output suffix)
        # as the summary name.
        name = '/'.join(tensor.name.split('/')[0:2]).split(':')[0]
        summaries.append(tf.summary.scalar(name, tensor, collections=[]))
    summary_op = tf.summary.merge(summaries)
    return tf.train.SummarySaverHook(save_steps=steps,
                                     output_dir=ckptdir,
                                     summary_op=summary_op)
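
# Illustrative sketch (assumed example): summarize every tensor that was
# added to a custom collection and attach the hook to a monitored session;
# `loss` and the checkpoint directory are assumed.
#
#     tf.add_to_collection('monitored_scalars', loss)
#     hook = create_summary_hook('monitored_scalars', ckptdir='/tmp/ckpt')
#     with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#         ...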


class StopAtSignalHook(tf.train.SessionRunHook):
    """Hook that requests stop when a signal is received."""

    def __init__(self, signals=None):
        """Initializes a `StopAtSignalHook`.

        The hook requests stop if one of the specified signals is received.
        If signals is None, these signals are handled by default:
        SIGUSR1, SIGUSR2, SIGALRM, SIGINT, SIGTERM.
        The default list can be overridden by passing signals explicitly.

        Args:
          signals: List of signals to handle.
        """
        self.signal_received = 0
        if signals is None:
            signals = [signal.SIGUSR1, signal.SIGUSR2,
                       signal.SIGALRM, signal.SIGINT, signal.SIGTERM]
        for s in signals:
            signal.signal(s, self.__signal_handler)

    def __signal_handler(self, signum, frame):
        """Sets self.signal_received to signum."""
        self.signal_received = signum

    def after_run(self, run_context, run_values):
        """Requests a stop if a signal was received."""
        if self.signal_received:
            run_context.request_stop()
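
# Illustrative sketch (assumed example): with this hook attached, e.g.
# `kill -USR1 <pid>` ends training cleanly after the current step;
# `train_op` is assumed.
#
#     with tf.train.MonitoredTrainingSession(
#             hooks=[StopAtSignalHook()]) as sess:
#         while not sess.should_stop():
#             sess.run(train_op)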


class TraceHook(tf.train.SessionRunHook):
    """Hook to perform full traces every N steps."""

    def __init__(self, ckptdir, every_step=50,
                 trace_level=tf.RunOptions.FULL_TRACE):
        """Initializes the TraceHook.

        Traces the first step after every (re)start and every N-th global
        step thereafter.

        Args:
          ckptdir: The checkpoint directory.
          every_step: A trace will be performed every N-th step.
          trace_level: The trace level to be passed to tf.RunOptions.
        """
        self._trace = True
        self.writer = tf.summary.FileWriter(ckptdir)
        self.trace_level = trace_level
        self.every_step = every_step

    def begin(self):
        """Checks that the global step is available inside the graph."""
        self._global_step_tensor = tf.train.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use TraceHook.")

    def before_run(self, run_context):
        """If a trace is requested, adds tf.RunOptions to the session args.

        Always requests the global step.

        Args:
          run_context: The run context.

        Returns:
          SessionRunArgs as described above.
        """
        if self._trace:
            options = tf.RunOptions(trace_level=self.trace_level)
        else:
            options = None
        return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
                                       options=options)

    def after_run(self, run_context, run_values):
        """If a trace was requested for this run, stores the results.

        Otherwise checks whether the next step should request a trace.

        Args:
          run_context: The original run context.
          run_values: The resulting run values.
        """
        global_step = run_values.results
        if self._trace:
            self._trace = False
            self.writer.add_run_metadata(run_values.run_metadata,
                                         f'{global_step}', global_step)
        # Request a trace for the upcoming step if it is a multiple of
        # every_step.
        if not (global_step + 1) % self.every_step:
            self._trace = True
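
# Illustrative sketch (assumed example): trace the first step after every
# (re)start and roughly every 50th step thereafter, viewable in TensorBoard;
# a global step tensor is required and `train_op` is assumed.
#
#     tf.train.create_global_step()
#     hooks = [TraceHook(ckptdir='/tmp/ckpt', every_step=50)]
#     with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
#         while not sess.should_stop():
#             sess.run(train_op)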