-
Notifications
You must be signed in to change notification settings - Fork 601
/
Copy pathmeta.py
402 lines (325 loc) · 13.3 KB
/
meta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Learning to learn (meta) optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import os
import mock
import sonnet as snt
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.util import nest
import networks
def _nested_assign(ref, value):
"""Returns a nested collection of TensorFlow assign operations.
Args:
ref: Nested collection of TensorFlow variables.
value: Values to be assigned to the variables. Must have the same structure
as `ref`.
Returns:
Nested collection (same structure as `ref`) of TensorFlow assign operations.
Raises:
ValueError: If `ref` and `values` have different structures.
"""
if isinstance(ref, list) or isinstance(ref, tuple):
if len(ref) != len(value):
raise ValueError("ref and value have different lengths.")
result = [_nested_assign(r, v) for r, v in zip(ref, value)]
if isinstance(ref, tuple):
return tuple(result)
return result
else:
return tf.assign(ref, value)
def _nested_variable(init, name=None, trainable=False):
"""Returns a nested collection of TensorFlow variables.
Args:
init: Nested collection of TensorFlow initializers.
name: Variable name.
trainable: Make variables trainable (`False` by default).
Returns:
Nested collection (same structure as `init`) of TensorFlow variables.
"""
if isinstance(init, list) or isinstance(init, tuple):
result = [_nested_variable(i, name, trainable) for i in init]
if isinstance(init, tuple):
return tuple(result)
return result
else:
return tf.Variable(init, name=name, trainable=trainable)
def _wrap_variable_creation(func, custom_getter):
  """Calls `func` with `custom_getter` injected into every `tf.get_variable`.

  `tensorflow.get_variable` is patched (via `mock.patch`) for the duration of
  `func()`. Each call made during that time is forwarded to the original
  `tf.get_variable` with `custom_getter=custom_getter` added.

  Args:
    func: Zero-argument callable to run while the patch is active.
    custom_getter: Custom getter passed to every `tf.get_variable` call.

  Returns:
    Whatever `func()` returns.

  Raises:
    AttributeError: If code run by `func` passes its own `custom_getter`
      keyword to `tf.get_variable`; the two getters cannot be combined.
  """
  original_get_variable = tf.get_variable

  def custom_get_variable(*args, **kwargs):
    # `kwargs` is a dict, so test key membership. The previous
    # `hasattr(kwargs, ...)` looked for an attribute on the dict object and
    # never fired; the caller then got an opaque "multiple values for keyword
    # argument" TypeError instead of this message.
    if "custom_getter" in kwargs:
      raise AttributeError("Custom getters are not supported for optimizee "
                           "variables.")
    return original_get_variable(*args, custom_getter=custom_getter, **kwargs)

  # Mock the get_variable method.
  with mock.patch("tensorflow.get_variable", custom_get_variable):
    return func()
def _get_variables(func):
  """Runs `func` only to collect the variables it creates.

  `func` is run inside the "unused_graph" name scope and its return value is
  discarded. Every variable is created with `trainable=False`. The trainable
  flag the caller originally requested only decides which list the variable
  is reported in.

  Args:
    func: Zero-argument callable that creates variables via `tf.get_variable`.

  Returns:
    A tuple `(variables, constants)` of lists, each in creation order:
    - `variables`: variables that were requested as trainable.
    - `constants`: variables that were requested as non-trainable.
  """
  trainable_vars = []
  fixed_vars = []

  def collecting_getter(getter, name, **kwargs):
    requested_trainable = kwargs["trainable"]
    # Always build the variable as non-trainable.
    kwargs["trainable"] = False
    created = getter(name, **kwargs)
    bucket = trainable_vars if requested_trainable else fixed_vars
    bucket.append(created)
    return created

  with tf.name_scope("unused_graph"):
    _wrap_variable_creation(func, collecting_getter)

  return trainable_vars, fixed_vars
def _make_with_custom_variables(func, variables):
  """Calls `func`, substituting the given tensors for its trainable variables.

  While `func` runs, each `get_variable` request for a trainable variable is
  answered with the next tensor from `variables`, in order. Non-trainable
  requests are passed to the real getter with `reuse=True`, so they pick up
  the variables that already exist.

  Args:
    func: Function to be called.
    variables: Tensors replacing the trainable variables, in request order.

  Returns:
    The return value of `func`.
  """
  replacements = collections.deque(variables)

  def swapping_getter(getter, name, **kwargs):
    if not kwargs["trainable"]:
      kwargs["reuse"] = True
      return getter(name, **kwargs)
    return replacements.popleft()

  return _wrap_variable_creation(func, swapping_getter)
# Result of MetaOptimizer.meta_loss: the meta-loss tensor, the update and reset
# op lists, the final optimizee loss, and the final optimizee parameters.
MetaLoss = collections.namedtuple(
    "MetaLoss", ["loss", "update", "reset", "fx", "x"])
# Result of MetaOptimizer.meta_minimize: same as MetaLoss, with the meta-loss
# replaced by the op that takes one meta-training step.
MetaStep = collections.namedtuple(
    "MetaStep", ["step", "update", "reset", "fx", "x"])
def _make_nets(variables, config, net_assignments):
"""Creates the optimizer networks.
Args:
variables: A list of variables to be optimized.
config: A dictionary of network configurations, each of which will be
passed to networks.Factory to construct a single optimizer net.
net_assignments: A list of tuples where each tuple is of the form (netid,
variable_names) and is used to assign variables to networks. netid must
be a key in config.
Returns:
A tuple (nets, keys, subsets) where nets is a dictionary of created
optimizer nets such that the net with key keys[i] should be applied to the
subset of variables listed in subsets[i].
Raises:
ValueError: If net_assignments is None and the configuration defines more
than one network.
"""
# create a dictionary which maps a variable name to its index within the
# list of variables.
name_to_index = dict((v.name.split(":")[0], i)
for i, v in enumerate(variables))
if net_assignments is None:
if len(config) != 1:
raise ValueError("Default net_assignments can only be used if there is "
"a single net config.")
with tf.variable_scope("vars_optimizer"):
key = next(iter(config))
kwargs = config[key]
net = networks.factory(**kwargs)
nets = {key: net}
keys = [key]
subsets = [range(len(variables))]
else:
nets = {}
keys = []
subsets = []
with tf.variable_scope("vars_optimizer"):
for key, names in net_assignments:
if key in nets:
raise ValueError("Repeated netid in net_assigments.")
nets[key] = networks.factory(**config[key])
subset = [name_to_index[name] for name in names]
keys.append(key)
subsets.append(subset)
print("Net: {}, Subset: {}".format(key, subset))
# subsets should be a list of disjoint subsets (as lists!) of the variables
# and nets should be a list of networks to apply to each subset.
return nets, keys, subsets
class MetaOptimizer(object):
  """Learning to learn (meta) optimizer.

  Optimizer which has an internal RNN which takes as input, at each iteration,
  the gradient of the function being minimized and returns a step direction.
  This optimizer can then itself be optimized to learn optimization on a set of
  tasks.
  """

  def __init__(self, **kwargs):
    """Creates a MetaOptimizer.

    Args:
      **kwargs: A set of keyword arguments mapping network identifiers (the
          keys) to parameters that will be passed to networks.factory (see
          docs for more info). These can be used to assign different
          optimizee parameters to different optimizers (see net_assignments
          in the meta_loss method). If no kwargs are given, a single
          "coordinatewise" CoordinateWiseDeepLSTM configuration is used.
    """
    # The networks are only built by meta_loss(); until then there is nothing
    # to save.
    self._nets = None

    if not kwargs:
      # Use a default coordinatewise network if nothing is given. This allows
      # for no network spec and no assignments.
      self._config = {
          "coordinatewise": {
              "net": "CoordinateWiseDeepLSTM",
              "net_options": {
                  "layers": (20, 20),
                  "preprocess_name": "LogAndSign",
                  "preprocess_options": {"k": 5},
                  "scale": 0.01,
              }}}
    else:
      self._config = kwargs

  def save(self, sess, path=None):
    """Saves the optimizer networks built by meta_loss.

    Args:
      sess: TensorFlow session, forwarded to `networks.save`.
      path: Optional directory. If given, net `k` is saved with the filename
          "<path>/<k>.l2l" (passed to `networks.save` as `filename`);
          otherwise `networks.save` receives `filename=None`.

    Returns:
      A dict mapping the net id (when `path` is None) or the filename
      (otherwise) to the value returned by `networks.save` for that net.

    Raises:
      RuntimeError: If called before meta_loss()/meta_minimize() has built
          the networks.
    """
    if self._nets is None:
      # Previously this crashed with "'NoneType' object has no attribute
      # 'items'".
      raise RuntimeError("No optimizer networks to save; call meta_loss() or "
                         "meta_minimize() before save().")

    result = {}
    for k, net in self._nets.items():
      if path is None:
        filename = None
        key = k
      else:
        filename = os.path.join(path, "{}.l2l".format(k))
        key = filename
      net_vars = networks.save(net, sess, filename=filename)
      result[key] = net_vars

    return result

  def meta_loss(self,
                make_loss,
                len_unroll,
                net_assignments=None,
                second_derivatives=False):
    """Returns an operator computing the meta-loss.

    The optimizee is unrolled for `len_unroll` steps inside a `tf.while_loop`.
    At each step the optimizer nets map gradients of the optimizee loss to
    parameter deltas. The meta-loss is the sum of the optimizee loss at the
    initial parameters and after each of the `len_unroll` steps.

    Args:
      make_loss: Callable which returns the optimizee loss; note that this
          should create its ops in the default graph.
      len_unroll: Number of steps to unroll.
      net_assignments: variable to optimizer mapping. If not None, it should
          be a list of (k, names) tuples, where k is a valid key in the kwargs
          passed at construction time and names is a list of variable names.
      second_derivatives: Use second derivatives (default is false). When
          False, the gradients fed to the nets are wrapped in
          tf.stop_gradient.

    Returns:
      MetaLoss namedtuple (loss, update, reset, fx, x):
      - loss: the summed meta-loss.
      - update: a list of assign ops writing the final parameters and RNN
        states back into their variables.
      - reset: a list holding an initializer for the state, optimizee and
        constant variables, plus an op closing the loss TensorArray.
      - fx: the optimizee loss at the final parameters.
      - x: the list of final parameter tensors.
    """
    # Construct an instance of the problem only to grab the variables. This
    # loss will never be evaluated.
    x, constants = _get_variables(make_loss)

    print("Optimizee variables")
    print([op.name for op in x])
    print("Problem variables")
    print([op.name for op in constants])

    # Create the optimizer networks and find the subsets of variables to assign
    # to each optimizer.
    nets, net_keys, subsets = _make_nets(x, self._config, net_assignments)

    # Store the networks so we can save them later.
    self._nets = nets

    # Create hidden state for each subset of variables.
    state = []
    with tf.name_scope("states"):
      for i, (subset, key) in enumerate(zip(subsets, net_keys)):
        net = nets[key]
        with tf.name_scope("state_{}".format(i)):
          state.append(_nested_variable(
              [net.initial_state_for_inputs(x[j], dtype=tf.float32)
               for j in subset],
              name="state", trainable=False))

    # Named `update_subset` (not `update`) so the closure used by `time_step`
    # is not shadowed by the `update` op list built further down.
    def update_subset(net, fx, x, state):
      """Parameter and RNN state update."""
      with tf.name_scope("gradients"):
        gradients = tf.gradients(fx, x)

        # Stopping the gradient here corresponds to what was done in the
        # original L2L NIPS submission. However it looks like things like
        # BatchNorm, etc. don't support second-derivatives so we still need
        # this term.
        if not second_derivatives:
          gradients = [tf.stop_gradient(g) for g in gradients]

      with tf.name_scope("deltas"):
        deltas, state_next = zip(*[net(g, s) for g, s in zip(gradients, state)])
        state_next = list(state_next)

      return deltas, state_next

    def time_step(t, fx_array, x, state):
      """While loop body."""
      x_next = list(x)
      state_next = []

      with tf.name_scope("fx"):
        fx = _make_with_custom_variables(make_loss, x)
        fx_array = fx_array.write(t, fx)

      with tf.name_scope("dx"):
        for subset, key, s_i in zip(subsets, net_keys, state):
          x_i = [x[j] for j in subset]
          deltas, s_i_next = update_subset(nets[key], fx, x_i, s_i)

          for idx, j in enumerate(subset):
            x_next[j] += deltas[idx]
          state_next.append(s_i_next)

      with tf.name_scope("t_next"):
        t_next = t + 1

      return t_next, fx_array, x_next, state_next

    # Define the while loop. The array has one slot per unrolled step plus
    # one for the loss at the final parameters.
    fx_array = tf.TensorArray(tf.float32, size=len_unroll + 1,
                              clear_after_read=False)
    _, fx_array, x_final, s_final = tf.while_loop(
        cond=lambda t, *_: t < len_unroll,
        body=time_step,
        loop_vars=(0, fx_array, x, state),
        parallel_iterations=1,
        swap_memory=True,
        name="unroll")

    with tf.name_scope("fx"):
      fx_final = _make_with_custom_variables(make_loss, x_final)
      fx_array = fx_array.write(len_unroll, fx_final)

    loss = tf.reduce_sum(fx_array.stack(), name="loss")

    # Reset the state; should be called at the beginning of an epoch.
    with tf.name_scope("reset"):
      variables = (nest.flatten(state) +
                   x + constants)
      # Empty array as part of the reset process.
      reset = [tf.variables_initializer(variables), fx_array.close()]

    # Operator to update the parameters and the RNN state after our loop, but
    # during an epoch.
    with tf.name_scope("update"):
      update = (nest.flatten(_nested_assign(x, x_final)) +
                nest.flatten(_nested_assign(state, s_final)))

    # Log internal variables.
    for k, net in nets.items():
      print("Optimizer '{}' variables".format(k))
      print([op.name for op in snt.get_variables_in_module(net)])

    return MetaLoss(loss, update, reset, fx_final, x_final)

  def meta_minimize(self, make_loss, len_unroll, learning_rate=0.01, **kwargs):
    """Returns an operator minimizing the meta-loss.

    Builds the meta-loss with meta_loss() and minimizes it with Adam. Adam's
    `minimize` uses its default variable list. That is presumably just the
    optimizer nets' variables, since the optimizee and state variables are
    created with trainable=False.

    Args:
      make_loss: Callable which returns the optimizee loss; note that this
          should create its ops in the default graph.
      len_unroll: Number of steps to unroll.
      learning_rate: Learning rate for the Adam optimizer.
      **kwargs: keyword arguments forwarded to meta_loss.

    Returns:
      MetaStep namedtuple (step, update, reset, fx, x). `step` is the Adam
      training op; the remaining fields are those returned by meta_loss.
    """
    info = self.meta_loss(make_loss, len_unroll, **kwargs)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    step = optimizer.minimize(info.loss)
    return MetaStep(step, *info[1:])