diff --git a/.travis.yml b/.travis.yml index 79881db88..3d7eb09bb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -82,14 +82,15 @@ jobs: distributions: "sdist " on: all_branches: true + tags: false condition: $TRAVIS_BRANCH =~ ^release-candidate-* - condition: $TRAVIS_TAG = "" - provider: pypi user: drasmuss password: $PYPI_TOKEN distributions: "sdist " on: all_branches: true + tags: true condition: $TRAVIS_TAG =~ ^v[0-9]* before_install: diff --git a/CHANGES.rst b/CHANGES.rst index 7d4762cd8..449d3dc4e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -21,6 +21,31 @@ Release history 3.1.1 (unreleased) ------------------ +**Added** + +- Compatible with TensorFlow 2.2.0. (`#140`_) + +**Changed** + +- Saved simulator state will no longer be included in ``Simulator.keras_model.weights``. + This means that ``Simulator.keras_model.save/load_weights`` will not include the + saved simulator state, making it easier to reuse weights between models (as long as + the models have the same weights, they do not need to have the same state variables). + ``Simulator.save/load_params(..., include_state=True)`` can be used to explicitly + save the simulator state, if desired. (`#140`_) +- Model parameters (e.g., connection weights) that are not trainable (because they've + been marked non-trainable by user or targeted by an online learning rule) will now + be treated separately from simulator state. For example, + ``Simulator.save_params(..., include_state=False)`` will still include those + parameters, and the results of any online learning will persist between calls even + with ``stateful=False``. (`#140`_) + +**Deprecated** + +- Renamed ``Simulator.save/load_params`` ``include_non_trainable`` parameter to + ``include_state``. (`#140`_) + +.. _#140: https://github.com/nengo/nengo-dl/pull/140 3.1.0 (March 4, 2020) --------------------- diff --git a/docs/examples/lmu.ipynb b/docs/examples/lmu.ipynb index 4ddcd573e..91f729394 100644 --- a/docs/examples/lmu.ipynb +++ b/docs/examples/lmu.ipynb @@ -253,7 +253,7 @@ " else:\n", " urlretrieve(\n", " \"https://drive.google.com/uc?export=download&\"\n", - " \"id=1qhKmpnipk8AK2bsilPlJBFQg1t7XAEZr\",\n", + " \"id=1epcfVDdUaHkwNo1kD4kjIF7qlXgJmb2i\",\n", " \"lmu_params.npz\")\n", " sim.load_params(\"./lmu_params\")\n", "\n", diff --git a/docs/examples/spa-memory.ipynb b/docs/examples/spa-memory.ipynb index c33f6091d..c6b8dae71 100644 --- a/docs/examples/spa-memory.ipynb +++ b/docs/examples/spa-memory.ipynb @@ -592,7 +592,7 @@ " # download pretrained parameters\n", " urlretrieve(\n", " \"https://drive.google.com/uc?export=download&\"\n", - " \"id=1aspi_dayS37Rx_IvDuRrXoybdIC_-tkn\",\n", + " \"id=1Ym44i2sBLbNUiNgJaP3l1Obhf_NFenU7\",\n", " \"mem_binding_params.npz\")" ] }, diff --git a/nengo_dl/compat.py b/nengo_dl/compat.py index 213214dcd..d7a53bf7d 100644 --- a/nengo_dl/compat.py +++ b/nengo_dl/compat.py @@ -10,6 +10,8 @@ import nengo from nengo._vendor.scipy.sparse import linalg_interface, linalg_onenormest import tensorflow as tf +from tensorflow.python.keras import backend +from tensorflow.python.keras.engine import network # TensorFlow compatibility @@ -74,6 +76,37 @@ def filter(self, record): tf.get_logger().addFilter(TFLogFilter(err_on_deprecation=False)) +if LooseVersion(tf.__version__) < "2.2.0": + + def global_learning_phase(): + """Returns the global (eager) Keras learning phase.""" + + return backend._GRAPH_LEARNING_PHASES.get(backend._DUMMY_EAGER_GRAPH, None) + + +else: + + def global_learning_phase(): + """Returns the global (eager) Keras learning phase.""" + + return backend._GRAPH_LEARNING_PHASES.get(backend._DUMMY_EAGER_GRAPH.key, None) + + # monkeypatch to fix bug in TF2.2, see + # https://github.com/tensorflow/tensorflow/issues/37548 + old_conform = network.Network._conform_to_reference_input + + def _conform_to_reference_input(self, tensor, ref_input): + keras_history = getattr(tensor, "_keras_history", None) + + tensor = old_conform(self, tensor, ref_input) + + if keras_history is not None: + tensor._keras_history = keras_history + + return tensor + + network.Network._conform_to_reference_input = _conform_to_reference_input + # Nengo compatibility # monkeypatch fix for https://github.com/nengo/nengo/pull/1587 diff --git a/nengo_dl/config.py b/nengo_dl/config.py index dca7dae60..a0f6062cb 100644 --- a/nengo_dl/config.py +++ b/nengo_dl/config.py @@ -22,7 +22,7 @@ def configure_settings(**kwargs): Parameters ---------- trainable : bool or None - Adds a parameter to Nengo Ensembles/Connections/Networks that controls + Adds a parameter to Nengo Ensembles/Connections that controls whether or not they will be optimized by `.Simulator.fit`. Passing ``None`` will use the default ``nengo_dl`` trainable settings, or True/False will override the default for all objects. In either diff --git a/nengo_dl/converter.py b/nengo_dl/converter.py index 0e659707d..77891b1ce 100644 --- a/nengo_dl/converter.py +++ b/nengo_dl/converter.py @@ -712,7 +712,7 @@ def get_history(tensor): """ Returns the Keras history (layer/node_idx/tensor_idx) that defined this tensor. - This function contains additional logic so that if ``tensor`` is the output of + This function contains additional logic so that if ``tensor`` is associated with a Model then the history will trace into the internal layers of that Model (rather than skipping to the input of that Model, which is the default Keras history). @@ -735,11 +735,9 @@ def get_history(tensor): layer, node_index, tensor_index = tensor._keras_history while isinstance(layer, tf.keras.Model): - # models have an output Identity transform that stores the history that - # "skips" the internals of the model; we want to traverse into the internals - # of the model, so we go back to the input of that identity op (which - # is the real output tensor from the model) - assert tensor.op.type == "Identity" + # models may have internal tensors representing some transforms on the + # input/output of the model. but we want to know the actual internal layer + # that generated this tensor, so we skip past these "model" tensors tensor = tensor.op.inputs[0] layer, node_index, tensor_index = tensor._keras_history @@ -877,7 +875,8 @@ def trace_tensors(self, tensors, results=None): ------- results : list of ``tf.Tensor`` The same as the ``results`` parameter (returned so that the top-level call, - which may not have a reference to the ``results`` list can get the results). + which may not have a reference to the ``results`` list, can get the + results). """ # brief intro to the keras functional graph structure: # - a node represents the application of some layer to an input tensor diff --git a/nengo_dl/simulator.py b/nengo_dl/simulator.py index 781c38c78..db3738660 100644 --- a/nengo_dl/simulator.py +++ b/nengo_dl/simulator.py @@ -28,7 +28,6 @@ from nengo.utils.magic import decorator import numpy as np import tensorflow as tf -from tensorflow.python.keras import backend from nengo_dl import callbacks, compat, config, utils from nengo_dl.builder import NengoBuilder, NengoModel @@ -550,9 +549,7 @@ def _build_keras(self, progress=None): inputs, stateful=self.stateful, # if the global learning phase is set, use that - training=backend._GRAPH_LEARNING_PHASES.get( - backend._DUMMY_EAGER_GRAPH, None - ), + training=compat.global_learning_phase(), progress=progress, ) @@ -614,8 +611,8 @@ def soft_reset(self, include_trainable=False, include_probes=False): Parameters ---------- include_trainable : bool - If True, also reset any training that has been performed on - simulator parameters (e.g., connection weights). + If True, also reset any online or offline training that has been performed + on simulator parameters (e.g., connection weights). include_probes : bool If True, also clear probe data. @@ -1160,7 +1157,7 @@ def loss(self, *args, **kwargs): @require_open @with_self - def save_params(self, path, include_non_trainable=False): + def save_params(self, path, include_state=False, include_non_trainable=None): """ Save network parameters to the given ``path``. @@ -1168,10 +1165,11 @@ def save_params(self, path, include_non_trainable=False): ---------- path : str Filepath of parameter output file. - include_non_trainable : bool - If True (default False) also save information representing - non-trainable parameters of the network (this includes the internal - simulation state). + include_state : bool + If True (default False) also save the internal simulation state. + + .. versionchanged:: 3.1.1 + Renamed from ``include_non_trainable`` to ``include_state``. Notes ----- @@ -1180,19 +1178,24 @@ def save_params(self, path, include_non_trainable=False): `.get_nengo_params`. """ - vars = ( - self.keras_model.weights - if include_non_trainable - else self.keras_model.trainable_weights - ) + if include_non_trainable is not None: + warnings.warn( + "include_non_trainable is deprecated, use include_state instead", + DeprecationWarning, + ) + include_state = include_non_trainable + + params = list(self.keras_model.weights) + if include_state: + params.extend(self.tensor_graph.saved_state.values()) - np.savez_compressed(path + ".npz", *tf.keras.backend.batch_get_value(vars)) + np.savez_compressed(path + ".npz", *tf.keras.backend.batch_get_value(params)) logger.info("Model parameters saved to %s.npz", path) @require_open @with_self - def load_params(self, path, include_non_trainable=False): + def load_params(self, path, include_state=False, include_non_trainable=None): """ Load network parameters from the given ``path``. @@ -1200,10 +1203,11 @@ def load_params(self, path, include_non_trainable=False): ---------- path : str Filepath of parameter input file. - include_non_trainable : bool - If True (default False) also load information representing - non-trainable parameters of the network (this includes the internal - simulation state). + include_state : bool + If True (default False) also save the internal simulation state. + + .. versionchanged:: 3.1.1 + Renamed from ``include_non_trainable`` to ``include_state``. Notes ----- @@ -1212,20 +1216,25 @@ def load_params(self, path, include_non_trainable=False): `.get_nengo_params`. """ - vars = ( - self.keras_model.weights - if include_non_trainable - else self.keras_model.trainable_weights - ) + if include_non_trainable is not None: + warnings.warn( + "include_non_trainable is deprecated, use include_state instead", + DeprecationWarning, + ) + include_state = include_non_trainable + + params = list(self.keras_model.weights) + if include_state: + params.extend(self.tensor_graph.saved_state.values()) with np.load(path + ".npz") as vals: - if len(vars) != len(vals.files): + if len(params) != len(vals.files): raise SimulationError( "Number of saved parameters in %s (%d) != number of variables in " - "the model (%d)" % (path, len(vals.files), len(vars)) + "the model (%d)" % (path, len(vals.files), len(params)) ) tf.keras.backend.batch_set_value( - zip(vars, (vals["arr_%d" % i] for i in range(len(vals.files)))) + zip(params, (vals["arr_%d" % i] for i in range(len(vals.files)))) ) logger.info("Model parameters loaded from %s.npz", path) diff --git a/nengo_dl/tensor_graph.py b/nengo_dl/tensor_graph.py index a0cc636db..b5fe119f4 100644 --- a/nengo_dl/tensor_graph.py +++ b/nengo_dl/tensor_graph.py @@ -173,8 +173,8 @@ def __init__( logger.info("Optimized plan length: %d", len(self.plan)) logger.info( - "Number of base arrays: %d, %d", - *tuple(len(x) for x in self.base_arrays_init), + "Number of base arrays: (%s, %d), (%s, %d), (%s, %d)", + *tuple((k, len(x)) for k, x in self.base_arrays_init.items()), ) def build_inputs(self): @@ -256,16 +256,18 @@ def get_initializer(init_vals): with trackable.no_automatic_dependency_tracking_scope(self): self.base_params = OrderedDict() assert len(self.base_params) == 0 - for k, v in self.base_arrays_init[True].items(): - initializer, shape, dtype = get_initializer(v) - assert initializer is not None # trainable params should never be set - self.base_params[k] = self.add_weight( - initializer=initializer, - shape=shape, - dtype=dtype, - trainable=True, - name="base_params/%s_%s" % (dtype, "_".join(str(x) for x in shape)), - ) + for sig_type in ("trainable", "non_trainable"): + for k, v in self.base_arrays_init[sig_type].items(): + initializer, shape, dtype = get_initializer(v) + assert initializer is not None # params should never be set + self.base_params[k] = self.add_weight( + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=sig_type == "trainable", + name="base_params/%s_%s_%s" + % (sig_type, dtype, "_".join(str(x) for x in shape)), + ) logger.debug("created base param variables") logger.debug([str(x) for x in self.base_params.values()]) @@ -273,13 +275,13 @@ def get_initializer(init_vals): # variables to save the internal state of simulation between runs with trackable.no_automatic_dependency_tracking_scope(self): self.saved_state = OrderedDict() - for k, v in self.base_arrays_init[False].items(): + for k, v in self.base_arrays_init["state"].items(): initializer, shape, dtype = get_initializer(v) if initializer is not None: # don't need to save the state for signals where the initial value # doesn't matter - self.saved_state[k] = self.add_weight( - initializer=initializer, + self.saved_state[k] = tf.Variable( + initial_value=lambda: initializer(shape=shape, dtype=dtype), shape=shape, dtype=dtype, trainable=False, @@ -367,6 +369,17 @@ def unbuild(layer): tf.keras.backend.batch_set_value(zip(weight_sets, weight_vals)) + # initialize state variables (need to do this manually because we're not + # adding them to self.weights) + # note: don't need to do this in eager mode, since variables are + # initialized on creation + # TODO: why does this cause problems if it is done before the tensornode + # weight get/sets above? + if not context.executing_eagerly(): + tf.keras.backend.batch_get_value( + [var.initializer for var in self.saved_state.values()] + ) + # @tf.function # TODO: get this working? does this help? @tf.autograph.experimental.do_not_convert def call(self, inputs, training=None, progress=None, stateful=False): @@ -446,7 +459,7 @@ def call(self, inputs, training=None, progress=None, stateful=False): # build stage with progress.sub("build stage", max_value=len(self.plan) * self.unroll) as sub: - steps_run, probe_arrays, final_internal_state = ( + steps_run, probe_arrays, final_internal_state, final_base_params = ( self._build_loop(sub) if self.use_loop else self._build_no_loop(sub) ) @@ -455,6 +468,7 @@ def call(self, inputs, training=None, progress=None, stateful=False): self.steps_run = steps_run self.probe_arrays = probe_arrays self.final_internal_state = final_internal_state + self.final_base_params = final_base_params # logging logger.info( @@ -472,13 +486,31 @@ def call(self, inputs, training=None, progress=None, stateful=False): # number of steps, even if there are no output probes outputs = list(probe_arrays.values()) + [steps_run] + updates = [] if stateful: # update saved state - state_updates = [ + updates.extend( var.assign(val) for var, val in zip(self.saved_state.values(), final_internal_state) - ] - with tf.control_dependencies(state_updates): + ) + + # if any of the base params have changed (due to online learning rules) then we + # also need to assign those back to the original variable (so that their + # values will persist). any parameters targeted by online learning rules + # will be minibatched, so we only need to update the minibatched params. + for (key, var), val in zip(self.base_params.items(), final_base_params): + try: + minibatched = self.base_arrays_init["non_trainable"][key][-1] + except KeyError: + minibatched = self.base_arrays_init["trainable"][key][-1] + + if minibatched: + updates.append(var.assign(val)) + + logger.info("Number of variable updates: %d", len(updates)) + + if len(updates) > 0: + with tf.control_dependencies(updates): outputs = [tf.identity(x) for x in outputs] return outputs @@ -501,7 +533,7 @@ def _fill_bases(self, saved_state, base_params): self.signals.bases[key] = tf.identity(val) for key, val in base_params.items(): self.signals.bases[key] = tf.identity(val) - for key, (_, shapes, _, minibatched) in self.base_arrays_init[False].items(): + for key, (_, shapes, _, minibatched) in self.base_arrays_init["state"].items(): if key not in self.signals.bases: # no saved state for this base, so we just temporarily insert # the shape information so that future scatters will know @@ -607,8 +639,9 @@ def update_probes(probe_tensors, loop_i): probe_arrays[p] = x final_internal_state = loop_vars[3] + final_base_params = loop_vars[4] - return steps_run, probe_arrays, final_internal_state + return steps_run, probe_arrays, final_internal_state, final_base_params def _build_no_loop(self, progress): """ @@ -668,8 +701,9 @@ def update_probes(probe_tensors, _): final_internal_state = tuple( self.signals.bases[key] for key in self.saved_state ) + final_base_params = tuple(self.signals.bases[key] for key in self.base_params) - return steps_run, probe_arrays, final_internal_state + return steps_run, probe_arrays, final_internal_state, final_base_params def _build_inner_loop(self, loop_i, update_probes, progress): """ @@ -836,7 +870,14 @@ def mark_signals(self): Users can manually specify whether signals are trainable or not using the config system (e.g., - ``net.config[nengo.Ensemble].trainable = False``) + ``net.config[nengo.Ensemble].trainable = False``). + + The trainable attribute will be set to one of three values: + + - ``True``: Signal is trainable + - ``False``: Signal could be trainable, but has been set to non-trainable + (e.g., because the user manually configured that object not to be trainable). + - ``None``: Signal is never trainable (e.g., simulator state) """ def get_trainable(parent_configs, obj): @@ -845,19 +886,22 @@ def get_trainable(parent_configs, obj): if self.inference_only: return False - trainable = None + # default to 1 (so that we can distinguish between an object being + # set to trainable vs defaulting to trainable) + trainable = 1 # we go from top down (so lower level settings will override) for cfg in parent_configs: try: - trainable = getattr(cfg[obj], "trainable", trainable) + cfg_trainable = getattr(cfg[obj], "trainable", None) except ConfigError: # object not configured in this network config - pass + cfg_trainable = None - # default to 1 (so that we can distinguish between an object being - # set to trainable vs defaulting to trainable) - return 1 if trainable is None else trainable + if cfg_trainable is not None: + trainable = cfg_trainable + + return trainable def mark_network(parent_configs, net): """Recursively marks the signals for objects within each subnetwork.""" @@ -941,13 +985,13 @@ def mark_network(parent_configs, net): for obj, seed in self.model.seeds.items(): if isinstance(obj, Connection) and seed in probe_seeds: if compat.conn_has_weights(obj): - self.model.sig[obj]["weights"].trainable = False + self.model.sig[obj]["weights"].trainable = None self.model.sig[obj]["weights"].minibatched = False # time/step are not minibatched and not trainable - self.model.step.trainable = False + self.model.step.trainable = None self.model.step.minibatched = False - self.model.time.trainable = False + self.model.time.trainable = None self.model.time.minibatched = False # fill in defaults for all other signals @@ -956,7 +1000,7 @@ def mark_network(parent_configs, net): for op in self.model.operators: for sig in op.all_signals: if not hasattr(sig.base, "trainable"): - sig.base.trainable = False + sig.base.trainable = None if not hasattr(sig.base, "minibatched"): sig.base.minibatched = not sig.base.trainable @@ -980,7 +1024,13 @@ def create_signals(self, sigs): memory (e.g., output from `.graph_optimizer.order_signals`) """ - base_arrays = [OrderedDict(), OrderedDict()] + base_arrays = OrderedDict( + [ + ("trainable", OrderedDict()), + ("non_trainable", OrderedDict()), + ("state", OrderedDict()), + ] + ) curr_keys = {} sig_idxs = {s: i for i, s in enumerate(sigs)} @@ -1092,18 +1142,25 @@ def special_set(s, op): if sig.minibatched: shape = (self.minibatch_size,) + shape - if key in base_arrays[sig.trainable]: - base_arrays[sig.trainable][key][0].append(initial_value) - base_arrays[sig.trainable][key][1].append(shape) + if sig.trainable is None: + sig_type = "state" + elif sig.trainable: + sig_type = "trainable" + else: + sig_type = "non_trainable" + + if key in base_arrays[sig_type]: + base_arrays[sig_type][key][0].append(initial_value) + base_arrays[sig_type][key][1].append(shape) else: - base_arrays[sig.trainable][key] = [ + base_arrays[sig_type][key] = [ [initial_value], [shape], dtype, sig.minibatched, ] - n = sum(x[sig.minibatched] for x in base_arrays[sig.trainable][key][1]) + n = sum(x[sig.minibatched] for x in base_arrays[sig_type][key][1]) slices = [(n - shape[sig.minibatched], n)] tensor_sig = self.signals.get_tensor_signal( diff --git a/nengo_dl/tests/test_keras.py b/nengo_dl/tests/test_keras.py index 4ba225774..9e87a4127 100644 --- a/nengo_dl/tests/test_keras.py +++ b/nengo_dl/tests/test_keras.py @@ -7,6 +7,7 @@ import tensorflow as tf from nengo_dl import config, dists +from nengo_dl.tests import dummies @pytest.mark.parametrize("minibatch_size", (None, 1, 3)) @@ -351,3 +352,21 @@ def test_learning_phase_warning(Simulator): with tf.keras.backend.learning_phase_scope(1): with Simulator(net): pass + + +def test_save_load_weights(Simulator, tmpdir): + net = dummies.linear_net()[0] + + net.connections[0].transform = 2 + + with Simulator(net, minibatch_size=1) as sim0: + sim0.keras_model.save_weights(str(tmpdir.join("tmp"))) + + net.connections[0].transform = 3 + + with Simulator(net, minibatch_size=2) as sim1: + assert np.allclose(sim1.data[net.connections[0]].weights, 3) + + sim1.keras_model.load_weights(str(tmpdir.join("tmp"))) + + assert np.allclose(sim1.data[net.connections[0]].weights, 2) diff --git a/nengo_dl/tests/test_learning_rules.py b/nengo_dl/tests/test_learning_rules.py index 1577a8aa9..2b8cca1da 100644 --- a/nengo_dl/tests/test_learning_rules.py +++ b/nengo_dl/tests/test_learning_rules.py @@ -77,3 +77,39 @@ def test_merged_learning(Simulator, rule, weights, seed): for i in range(sim.minibatch_size): assert np.allclose(sim.data[p0][i], canonical[0]) assert np.allclose(sim.data[p1][i], canonical[1]) + + +def test_online_learning_reset(Simulator, tmpdir): + with nengo.Network() as net: + inp = nengo.Ensemble(10, 1) + out = nengo.Node(size_in=1) + conn = nengo.Connection(inp, out, learning_rule_type=nengo.PES(1)) + nengo.Connection(nengo.Node([1]), conn.learning_rule) + + with Simulator(net) as sim: + w0 = np.array(sim.data[conn].weights) + + sim.run(0.1, stateful=False) + + w1 = np.array(sim.data[conn].weights) + + sim.save_params(str(tmpdir.join("tmp"))) + + # test that learning has changed weights + assert not np.allclose(w0, w1) + + # test that soft reset does NOT reset the online learning weights + sim.soft_reset() + assert np.allclose(w1, sim.data[conn].weights) + + # test that full reset DOES reset the online learning weights + sim.reset() + assert np.allclose(w0, sim.data[conn].weights) + + # test that weights load correctly + with Simulator(net) as sim: + assert not np.allclose(w1, sim.data[conn].weights) + + sim.load_params(str(tmpdir.join("tmp"))) + + assert np.allclose(w1, sim.data[conn].weights) diff --git a/nengo_dl/tests/test_objectives.py b/nengo_dl/tests/test_objectives.py index 8554a357b..e36c94eeb 100644 --- a/nengo_dl/tests/test_objectives.py +++ b/nengo_dl/tests/test_objectives.py @@ -9,7 +9,6 @@ @pytest.mark.parametrize("axis, order", [(None, 1), (None, "euclidean"), (1, 2)]) -@pytest.mark.training def test_regularize(axis, order, rng): x_init = rng.randn(2, 3, 4, 5) diff --git a/nengo_dl/tests/test_simulator.py b/nengo_dl/tests/test_simulator.py index c163563d4..61cb5a50a 100644 --- a/nengo_dl/tests/test_simulator.py +++ b/nengo_dl/tests/test_simulator.py @@ -515,9 +515,8 @@ def test_generate_inputs(Simulator, seed): sim._generate_inputs(data={p[0]: np.zeros((minibatch_size, n_steps, 1))}) -@pytest.mark.parametrize("include_non_trainable", (True, False)) -@pytest.mark.training -def test_save_load_params(Simulator, include_non_trainable, tmpdir): +@pytest.mark.parametrize("include_state", (True, False)) +def test_save_load_params(Simulator, include_state, tmpdir): def get_network(seed): with nengo.Network(seed=seed) as net: configure_settings(simplifications=[]) @@ -543,7 +542,7 @@ def get_network(seed): sim_save.run_steps(10) - sim_save.save_params(str(tmpdir), include_non_trainable=include_non_trainable) + sim_save.save_params(str(tmpdir), include_state=include_state) sim_save.run_steps(10) @@ -560,24 +559,20 @@ def get_network(seed): ) assert not np.allclose(weights0, weights1) assert not np.allclose(enc0, enc1) + assert not np.allclose(bias0, bias1) pre_model = sim_load.keras_model - sim_load.load_params(str(tmpdir), include_non_trainable=include_non_trainable) + sim_load.load_params(str(tmpdir), include_state=include_state) weights2, enc2, bias2 = sim_load.data.get_params( (conn1, "weights"), (ens1, "encoders"), (ens1, "bias") ) - # check if weights match + # check if params match assert np.allclose(weights0, weights2) - - if include_non_trainable: - assert np.allclose(enc0, enc2) - assert np.allclose(bias0, bias2) - else: - assert np.allclose(enc1, enc2) - assert np.allclose(bias1, bias2) + assert np.allclose(enc0, enc2) + assert np.allclose(bias0, bias2) # check if a new model was created or one was modified in-place assert sim_load.keras_model is pre_model @@ -586,16 +581,45 @@ def get_network(seed): sim_load.run_steps(10) # check if simulation state resumed correctly - if include_non_trainable: + if include_state: + # state saved, so we should match the point at which that state was saved assert np.allclose(sim_load.data[p1], sim_save.data[p0][10:]) else: - assert not np.allclose(sim_load.data[p1], sim_save.data[p0][10:]) + # state not saved, but other seeded params are, so we should match the first + # timesteps of `sim_save` (despite the networks not having the same seeds) + assert np.allclose(sim_load.data[p1], sim_save.data[p0][:10]) with Simulator(nengo.Network()) as sim: with pytest.raises(SimulationError, match="!= number of variables"): sim.load_params(str(tmpdir)) +def test_save_load_params_deprecation(Simulator, tmpdir): + with nengo.Network() as net: + a = nengo.Node([1]) + p = nengo.Probe(a, synapse=0.1) + + with Simulator(net) as sim0: + sim0.run_steps(5) + + with pytest.warns( + DeprecationWarning, match="include_non_trainable is deprecated" + ): + sim0.save_params(str(tmpdir.join("tmp")), include_non_trainable=True) + + sim0.run_steps(5) + + with Simulator(net) as sim1: + with pytest.warns( + DeprecationWarning, match="include_non_trainable is deprecated" + ): + sim1.load_params(str(tmpdir.join("tmp")), include_non_trainable=True) + + sim1.run_steps(5) + + assert np.allclose(sim0.data[p][-5:], sim1.data[p]) + + def test_model_passing(Simulator, seed): # make sure that passing a built model to the Simulator works properly @@ -785,16 +809,18 @@ def test_profile(Simulator, mode, tmpdir): net, a, p = dummies.linear_net() with Simulator(net) as sim: + # note: TensorFlow bug if using profile_batch=1, see + # https://github.com/tensorflow/tensorflow/issues/37543 callback = callbacks.TensorBoard( - log_dir=str(tmpdir.join("profile")), profile_batch=1 + log_dir=str(tmpdir.join("profile")), profile_batch=2 ) if mode == "predict": - sim.predict(n_steps=5, callbacks=[callback]) + sim.predict(np.zeros((2, 5, 1)), callbacks=[callback]) else: sim.compile(tf.optimizers.SGD(1), loss=tf.losses.mse) sim.fit( - {a: np.zeros((1, 5, 1))}, {p: np.zeros((1, 5, 1))}, callbacks=[callback] + {a: np.zeros((2, 5, 1))}, {p: np.zeros((2, 5, 1))}, callbacks=[callback] ) assert os.path.exists(str(tmpdir.join("profile", "train"))) @@ -1120,10 +1146,7 @@ def test_simulation_data(Simulator, seed): sig = sim.model.sig[a]["encoders"] tensor_sig = sim.tensor_graph.signals[sig] - if sim.tensor_graph.inference_only: - base = sim.tensor_graph.saved_state[tensor_sig.key] - else: - base = sim.tensor_graph.base_params[tensor_sig.key] + base = sim.tensor_graph.base_params[tensor_sig.key] tf.keras.backend.set_value( base, np.ones(base.shape, dtype=base.dtype.as_numpy_dtype()) ) diff --git a/nengo_dl/tests/test_tensor_graph.py b/nengo_dl/tests/test_tensor_graph.py index 9c5aa7572..7c833b06d 100644 --- a/nengo_dl/tests/test_tensor_graph.py +++ b/nengo_dl/tests/test_tensor_graph.py @@ -1,5 +1,7 @@ # pylint: disable=missing-docstring +import logging + import nengo from nengo.builder.operator import Reset import numpy as np @@ -224,7 +226,7 @@ def test_signal_order_deterministic(Simulator, seed): with Simulator(net, seed=seed) as sim1: with Simulator(net, seed=seed) as sim2: - for trainable in (True, False): + for trainable in ("trainable", "non_trainable", "state"): for v, v2 in zip( sim1.tensor_graph.base_arrays_init[trainable].values(), sim2.tensor_graph.base_arrays_init[trainable].values(), @@ -292,11 +294,11 @@ def test_create_signals(): plan = [tuple(dummies.Op(reads=[x]) for x in sigs)] graph = dummies.TensorGraph(plan, tf.float32, 10) graph.create_signals(sigs) - assert graph.base_arrays_init[False][graph.signals[sigs[0]].key][1] == [ + assert graph.base_arrays_init["non_trainable"][graph.signals[sigs[0]].key][1] == [ (10, 10), (10, 5), ] - assert graph.base_arrays_init[False][graph.signals[sigs[2]].key][1] == [ + assert graph.base_arrays_init["non_trainable"][graph.signals[sigs[2]].key][1] == [ (10, 10, 1), (10, 5, 1), ] @@ -314,8 +316,11 @@ def test_create_signals(): plan = [tuple(dummies.Op(reads=[x]) for x in sigs)] graph = dummies.TensorGraph(plan, tf.float32, 10) graph.create_signals(sigs) - assert graph.base_arrays_init[True][graph.signals[sigs[0]].key][1] == [(1,), (1,)] - assert graph.base_arrays_init[False][graph.signals[sigs[2]].key][1] == [ + assert graph.base_arrays_init["trainable"][graph.signals[sigs[0]].key][1] == [ + (1,), + (1,), + ] + assert graph.base_arrays_init["non_trainable"][graph.signals[sigs[2]].key][1] == [ (10, 1), (10, 1), ] @@ -328,14 +333,17 @@ def test_create_signals(): plan = [tuple(dummies.Op(reads=[x]) for x in sigs)] graph = dummies.TensorGraph(plan, tf.float32, 10) graph.create_signals(sigs) - assert list(graph.base_arrays_init[False].values())[0][1] == [(10, 1), (10, 4)] + assert list(graph.base_arrays_init["non_trainable"].values())[0][1] == [ + (10, 1), + (10, 4), + ] # check that boolean signals are handled correctly sigs = [dummies.Signal(dtype=np.bool, shape=())] plan = [(dummies.Op(reads=sigs),)] graph = dummies.TensorGraph(plan, tf.float32, 1) graph.create_signals(sigs) - assert list(graph.base_arrays_init[False].values())[0][2] == "bool" + assert list(graph.base_arrays_init["non_trainable"].values())[0][2] == "bool" def test_create_signals_views(): @@ -347,7 +355,10 @@ def test_create_signals_views(): plan = [tuple(dummies.Op(reads=[x]) for x in sigs)] graph = dummies.TensorGraph(plan, tf.float32, 10) graph.create_signals(sigs[2:]) - assert list(graph.base_arrays_init[False].values())[0][1] == [(10, 4), (10, 4)] + assert list(graph.base_arrays_init["non_trainable"].values())[0][1] == [ + (10, 4), + (10, 4), + ] assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key @@ -391,7 +402,7 @@ def test_create_signals_partition(): plan = [tuple(Reset(x) for x in sigs)] graph = dummies.TensorGraph(plan, tf.float32, 10) graph.create_signals(sigs) - assert len(graph.base_arrays_init[False]) == 4 + assert len(graph.base_arrays_init["non_trainable"]) == 4 @pytest.mark.parametrize("use_loop", (True, False)) @@ -465,7 +476,12 @@ def test_build(trainable, rng): graph.create_signals(sigs) graph.build() - assert len(graph.weights) == 3 + if trainable: + assert len(graph.trainable_weights) == 3 + assert len(graph.non_trainable_weights) == 0 + else: + assert len(graph.trainable_weights) == 0 + assert len(graph.non_trainable_weights) == 3 init0 = graph.weights[0].numpy() assert init0.shape == (5, 1) if trainable else (16, 5, 1) @@ -484,3 +500,37 @@ def test_build(trainable, rng): assert init2.shape == (16, 13, 1) assert np.allclose(init2[:, :6], sigs[4].initial_value) assert np.allclose(init2[:, 6:], sigs[5].initial_value) + + +@pytest.mark.parametrize("use_loop", (True, False)) +def test_conditional_update(Simulator, use_loop, caplog): + caplog.set_level(logging.INFO) + + with nengo.Network() as net: + config.configure_settings(stateful=False, use_loop=use_loop) + + a = nengo.Ensemble(10, 1) + b = nengo.Node(size_in=1) + conn = nengo.Connection(a, b) + + with Simulator(net): + pass + + assert "Number of variable updates: 0" in caplog.text + caplog.clear() + + conn.learning_rule_type = nengo.PES() + + with Simulator(net): + pass + + assert "Number of variable updates: 1" in caplog.text + caplog.clear() + + with net: + config.configure_settings(trainable=True) + + with Simulator(net): + pass + + assert "Number of variable updates: 1" in caplog.text diff --git a/nengo_dl/tests/test_tensor_node.py b/nengo_dl/tests/test_tensor_node.py index 5af8fb590..36baffb65 100644 --- a/nengo_dl/tests/test_tensor_node.py +++ b/nengo_dl/tests/test_tensor_node.py @@ -262,11 +262,13 @@ def call(self, x): # note: when inference-only=True the weights will be marked as non-trainable if sim.tensor_graph.inference_only: - assert len(sim.keras_model.non_trainable_variables) == 4 + assert len(sim.tensor_graph.saved_state) == 2 + assert len(sim.keras_model.non_trainable_variables) == 2 assert len(sim.keras_model.trainable_variables) == 0 - vars = sim.keras_model.non_trainable_variables[-2:] + vars = sim.keras_model.non_trainable_variables else: - assert len(sim.keras_model.non_trainable_variables) == 2 + assert len(sim.tensor_graph.saved_state) == 2 + assert len(sim.keras_model.non_trainable_variables) == 0 assert len(sim.keras_model.trainable_variables) == 2 vars = sim.keras_model.trainable_variables