diff --git a/CHANGES.rst b/CHANGES.rst
index 4d6666863..a697ba014 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -34,6 +34,7 @@ Release history
 - Added ``nengo_dl.LeakyReLU`` and ``nengo_dl.SpikingLeakyReLU`` neuron models.
   (`#126`_)
 - Added support for leaky ReLU Keras layers to ``nengo_dl.Converter``. (`#126`_)
+- Added a new ``remove_reset_incs`` graph simplification step. (`#129`_)

 **Changed**

@@ -52,6 +53,10 @@ Release history
   (see `Nengo#1591`_). Note that this may change the number of trainable
   parameters in a network as the scalar default ``transform=1`` weights on
   non-Ensemble connections will no longer be present. (`#128`_)
+- Re-enabled the ``remove_constant_copies`` graph simplification by default. (`#129`_)
+- Reduced the amount of state that needs to be stored in the simulation. (`#129`_)
+- Added more information to the error message when loading saved parameters that
+  don't match the current model. (`#129`_)

 **Fixed**

@@ -68,6 +73,7 @@ Release history
 .. _#119: https://github.com/nengo/nengo-dl/pull/119
 .. _#126: https://github.com/nengo/nengo-dl/pull/126
 .. _#128: https://github.com/nengo/nengo-dl/pull/128
+.. _#129: https://github.com/nengo/nengo-dl/pull/129
 .. _#136: https://github.com/nengo/nengo-dl/pull/136
 .. _Nengo#1591: https://github.com/nengo/nengo/pull/1591
diff --git a/docs/config.rst b/docs/config.rst
index 552d79ca1..ae78e2eb6 100644
--- a/docs/config.rst
+++ b/docs/config.rst
@@ -175,6 +175,8 @@ disable sorting via

     with nengo.Network() as net:
         nengo_dl.configure_settings(sorter=noop_order_signals)

+.. _config-simplifications:
+
 simplifications
 ---------------

diff --git a/docs/examples/spa-memory.ipynb b/docs/examples/spa-memory.ipynb
index 1bb5edfb9..c33f6091d 100644
--- a/docs/examples/spa-memory.ipynb
+++ b/docs/examples/spa-memory.ipynb
@@ -592,7 +592,7 @@
     "        # download pretrained parameters\n",
     "        urlretrieve(\n",
     "            \"https://drive.google.com/uc?export=download&\"\n",
-    "            \"id=182toR0aSWv1uExA7F5kn8t2lEU1SVrx1\",\n",
+    "            \"id=1aspi_dayS37Rx_IvDuRrXoybdIC_-tkn\",\n",
     "            \"mem_binding_params.npz\")"
    ]
   },
diff --git a/nengo_dl/benchmarks.py b/nengo_dl/benchmarks.py
index deff7ef82..401763670 100644
--- a/nengo_dl/benchmarks.py
+++ b/nengo_dl/benchmarks.py
@@ -649,7 +649,7 @@ def run_profile(
                 sim.minibatch_size * n_batches, n_steps, net.inp.size_out
             )
         }
-    else:
+    elif hasattr(net, "inp_a"):
         x = {
             net.inp_a: np.random.randn(
                 sim.minibatch_size * n_batches, n_steps, net.inp_a.size_out
@@ -658,6 +658,8 @@ def run_profile(
             ),
             net.inp_b: np.random.randn(
                 sim.minibatch_size * n_batches, n_steps, net.inp_b.size_out
            ),
        }
+    else:
+        x = None

     if train:
         y = {
@@ -670,14 +672,14 @@ def run_profile(

         # run once to eliminate startup overhead
         start = timeit.default_timer()
-        sim.fit(x, y, epochs=1)
+        sim.fit(x, y, epochs=1, n_steps=n_steps)
         print("Warmup time:", timeit.default_timer() - start)

         for _ in range(reps):
             if do_profile:
                 profiler.start()
             start = timeit.default_timer()
-            sim.fit(x, y, epochs=1)
+            sim.fit(x, y, epochs=1, n_steps=n_steps)
             exec_time = min(timeit.default_timer() - start, exec_time)
             if do_profile:
                 profiler.save("profile", profiler.stop())
diff --git a/nengo_dl/config.py b/nengo_dl/config.py
index 3ef7d9fb0..dca7dae60 100644
--- a/nengo_dl/config.py
+++ b/nengo_dl/config.py
@@ -41,7 +41,8 @@ def configure_settings(**kwargs):
     simplifications: list of graph simplification functions
         Pass a list of `graph simplification functions `_ to change
-        the default simplifications applied.
+        the default simplifications applied. The default list of simplifications
+        can be found in ``nengo_dl.graph_optimizer.default_simplifications``.
     inference_only : bool
         Set to True if the network will only be run in inference mode (i.e.,
         no calls to `.Simulator.fit`). This may result in a small
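For context on the two changes above, this is how the ``simplifications`` setting is typically used. A minimal sketch (the custom step here is hypothetical, not part of this PR), extending the defaults that ``graph_optimizer.default_simplifications`` now exposes:

import nengo
import nengo_dl
from nengo_dl import graph_optimizer

def skip_nothing(operators):
    # a do-nothing simplification: operator list in, operator list out
    return operators

with nengo.Network() as net:
    # append the hypothetical custom step to the new default tuple
    nengo_dl.configure_settings(
        simplifications=graph_optimizer.default_simplifications + (skip_nothing,)
    )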
-# """ -# -# dg = operator_dependency_graph(operators) -# -# for op in operators: -# if type(op) == Reset and np.all(op.value == 0): -# incers = [succ for succ in dg[op] if op.dst in succ.incs] -# if len(incers) > 0: -# del dg[op] -# incer = incers[0] -# incer.sets.extend(incer.incs) -# incer.incs = [] -# if isinstance(incer, ElementwiseInc): -# incer.__class__ = op_builders.ElementwiseSet -# elif isinstance(incer, DotInc): -# incer.__class__ = op_builders.DotSet -# else: -# incer.inc = False -# -# return list(dg.keys()) +def remove_reset_incs(operators): + """Replace ``y=Reset(0) + x`` with ``y=x``. + + If a signal is Reset and Inc'd, we can change that to a Set that combines + the two ops (note: any other incs of that signal can proceed as normal) + + Parameters + ---------- + operators : list of `~nengo.builder.Operator` + operators in the model + + Returns + ------- + new_operators : list of `~nengo.builder.Operator` + modified list of operators + + Notes + ----- + In practice, this modification can hurt more than it helps. Inc + operators are cheaper to compute the gradient for, and changing Incs to + Incs and Sets splits up the Inc merge groups. It tends to provide the + most value for models consisting of long linear chains of objects. + """ + + # note: not using signal_io_dict because we care about exact signal matches + # in this case, not bases + valid_inc_types = [ + ElementwiseInc, + SparseDotInc, + DotInc, + Copy, + SimProcess, + ConvInc, + op_builders.ResetInc, + ] + incs = defaultdict(list) + for op in operators: + for s in op.incs: + if type(op) not in valid_inc_types: + warnings.warn("Unknown incer type %s in remove_reset_incs" % type(op)) + elif getattr(op, "dst_slice", None) is None: + # don't include copy ops with dst_slice, as they aren't incrementing + # the whole signal + incs[s].append(op) + + new_operators = [] + ignore = [] + for op in operators: + if op in ignore: + # don't add this op to new_operators + ignore.remove(op) + continue + + if type(op) == Reset and np.all(op.value == 0) and len(incs[op.dst]) > 0: + # pick the first op that increments dst, and change it to a set + # (to take the place of the reset) + incer = incs[op.dst][0] + + if isinstance(incer, ElementwiseInc): + setter = op_builders.ElementwiseSet( + incer.A, incer.X, incer.Y, tag=incer.tag + ) + elif isinstance(incer, SparseDotInc): + # note: this needs to come before the DotInc condition, since + # SparseDotInc is a subclass of DotInc + setter = op_builders.SparseDotSet( + incer.A, incer.X, incer.Y, tag=incer.tag + ) + elif isinstance(incer, DotInc): + setter = op_builders.DotSet(incer.A, incer.X, incer.Y, tag=incer.tag) + elif isinstance(incer, Copy): + setter = Copy( + incer.src, + incer.dst, + src_slice=incer.src_slice, + dst_slice=incer.dst_slice, + inc=False, + tag=incer.tag, + ) + elif isinstance(incer, SimProcess): + setter = SimProcess( + incer.process, + incer.input, + incer.output, + incer.t, + mode="set", + state=incer.state, + tag=incer.tag, + ) + elif isinstance(incer, ConvInc): + setter = transform_builders.ConvSet( + incer.W, incer.X, incer.Y, incer.conv, tag=incer.tag + ) + elif isinstance(incer, op_builders.ResetInc): + setter = Reset(incer.dst, tag=incer.tag) + # setting the value separately to bypass float casting in Reset init + setter.value = incer.value + + # replace incer with setter + try: + new_operators.remove(incer) + except ValueError: + ignore.append(incer) + new_operators.append(setter) + else: + new_operators.append(op) + + return new_operators def 
remove_constant_copies(operators): @@ -1224,10 +1295,24 @@ def remove_constant_copies(operators): sets, incs, _, updates = signal_io_dicts(operators) new_operators = [] + ignore = [] for op in operators: + if op in ignore: + # don't add this op to new_operators + ignore.remove(op) + continue + if isinstance(op, Copy): src = op.src + try: + dst = op.dst if op.dst_slice is None else op.dst[op.dst_slice] + except SignalError: + # Copy is implementing advanced indexing, which cannot be applied + # directly to a signal + new_operators.append(op) + continue + # check if the input is the output of a Node (in which case the # value might change, so we should never get rid of this op). # checking the name of the signal seems a bit fragile, but I can't @@ -1238,8 +1323,8 @@ def remove_constant_copies(operators): pred = sets[src.base] + incs[src.base] if len(pred) == 0 and not op.src.trainable and len(updates[src.base]) == 0: - # no predecessors means that the src is constant. but we also - # need to keep the bias signal if it is trainable (since + # no predecessors means that the src is constant. but we still + # need to keep the signal if it is trainable (since # changing it to a reset op would make it not trainable). # we also need to check if anything is updating src (which # wouldn't be in the predecessors). @@ -1248,24 +1333,21 @@ def remove_constant_copies(operators): # if the only predecessor is a Reset, we can just use that # set value val = pred[0].value + + # remove the reset operator try: new_operators.remove(pred[0]) except ValueError: - operators.remove(pred[0]) + ignore.append(pred[0]) else: new_operators.append(op) continue - new_op = Reset(op.dst if op.dst_slice is None else op.dst[op.dst_slice]) + new_op = op_builders.ResetInc(dst) if op.inc else Reset(dst) # note: we need to set the value separately to bypass the float() # casting in Reset new_op.value = val - if op.inc: - new_op.incs.extend(new_op.sets) - new_op.sets = [] - new_op.__class__ = op_builders.ResetInc - new_operators.append(new_op) else: new_operators.append(op) @@ -1340,7 +1422,14 @@ def is_identity(x, sig): if identity_input: other_src = [x for x in op.reads if x is not src][0] - new_operators.append(Copy(other_src, op.Y, inc=len(op.incs) > 0)) + new_operators.append( + Copy( + other_src, + op.Y, + inc=len(op.incs) > 0, + tag="%s.identity_mul" % op.tag, + ) + ) break else: new_operators.append(op) @@ -1426,3 +1515,16 @@ def display_signal_blocks(operators, all_signals): output[n, sig_group] = str(i) return "\n".join("".join(line) for line in output) + + +# the default simplifications that will be applied. exposed as a variable here to make +# it easier for users to add something to the defaults (e.g. +# `simplifications=default_simplifications + (a_thing,)`), rather than having to +# manually specify and track changes to the defaults +default_simplifications = ( + remove_unmodified_resets, + remove_zero_incs, + remove_identity_muls, + remove_constant_copies, + remove_reset_incs, +) diff --git a/nengo_dl/learning_rule_builders.py b/nengo_dl/learning_rule_builders.py index 13d2817fd..2be8fd8b0 100644 --- a/nengo_dl/learning_rule_builders.py +++ b/nengo_dl/learning_rule_builders.py @@ -2,7 +2,6 @@ Build classes for Nengo learning rule operators. 
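The simplest way to see what the new ``remove_reset_incs`` step does is on a hand-built operator list, mirroring the unit tests added later in this diff. A sketch (the signals here are trivial placeholders):

import numpy as np
from nengo.builder.operator import ElementwiseInc, Reset
from nengo.builder.signal import Signal

from nengo_dl import op_builders
from nengo_dl.graph_optimizer import remove_reset_incs

A, X = Signal(np.ones(1)), Signal(np.ones(1))
Y = Signal(np.zeros(1))

# y = Reset(0); y += A * X  -->  y = A * X
ops = [Reset(Y), ElementwiseInc(A, X, Y)]
new_ops = remove_reset_incs(ops)

assert len(new_ops) == 1
assert isinstance(new_ops[0], op_builders.ElementwiseSet)
assert new_ops[0].sets == [Y] and new_ops[0].incs == []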
""" -from nengo import rc as nengo_rc from nengo.builder import Signal from nengo.builder.learning_rules import ( SimBCM, @@ -206,7 +205,7 @@ def build_pes(model, pes, rule): conn = rule.connection # Create input error signal - error = Signal(np.zeros(rule.size_in, dtype=nengo_rc.float_dtype), name="PES:error") + error = Signal(shape=(rule.size_in,), name="PES:error") model.add_op(Reset(error)) model.sig[rule]["in"] = error # error connection will attach here @@ -222,17 +221,13 @@ def build_pes(model, pes, rule): # in order to avoid slicing encoders along an axis > 0, we pad # `error` out to the full base dimensionality and then do the # dotinc with the full encoder matrix - padded_error = Signal( - np.zeros(encoders.shape[1], dtype=nengo_rc.float_dtype) - ) + padded_error = Signal(shape=(encoders.shape[1],)) model.add_op(Copy(error, padded_error, dst_slice=conn.post_slice)) else: padded_error = error # error = dot(encoders, error) - local_error = Signal( - np.zeros(post.n_neurons, dtype=nengo_rc.float_dtype), name="PES:encoded" - ) + local_error = Signal(shape=(post.n_neurons,)) model.add_op(Reset(local_error)) model.add_op(DotInc(encoders, padded_error, local_error, tag="PES:encode")) else: diff --git a/nengo_dl/op_builders.py b/nengo_dl/op_builders.py index 53e014d04..cebd69861 100644 --- a/nengo_dl/op_builders.py +++ b/nengo_dl/op_builders.py @@ -28,15 +28,72 @@ class ResetInc(Reset): """ - A version of Reset that increments the target value rather than setting it. + A version of `~nengo.builder.operator.Reset` that increments the target value + rather than overwriting. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.incs, self.sets = self.sets, self.incs + @property def dst(self): - """Overridden to return from incs rather than sets.""" + """dst is stored in ``incs`` rather than ``sets``.""" return self.incs[0] +class ElementwiseSet(ElementwiseInc): + """ + A version of `~nengo.builder.operator.ElementwiseInc` that overwrites the target + rather than incrementing. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.incs, self.sets = self.sets, self.incs + + @property + def Y(self): + """Y is stored in ``sets`` rather than ``incs``.""" + return self.sets[0] + + +class DotSet(DotInc): + """ + A version of `~nengo.builder.operator.DotInc` that overwrites the target rather + than incrementing. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.incs, self.sets = self.sets, self.incs + + @property + def Y(self): + """Y is stored in ``sets`` rather than ``incs``.""" + return self.sets[0] + + +class SparseDotSet(SparseDotInc): + """ + A version of `~nengo.builder.operator.SparseDotInc` that overwrites the target + rather than incrementing. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.incs, self.sets = self.sets, self.incs + + @property + def Y(self): + """Y is stored in ``sets`` rather than ``incs``.""" + return self.sets[0] + + @Builder.register(Reset) @Builder.register(ResetInc) class ResetBuilder(OpBuilder): @@ -117,14 +174,8 @@ def mergeable(x, y): return True -# class ElementwiseSet(ElementwiseInc): -# @property -# def Y(self): -# return self.sets[0] - - @Builder.register(ElementwiseInc) -# @Builder.register(ElementwiseSet) +@Builder.register(ElementwiseSet) class ElementwiseIncBuilder(OpBuilder): """ Build a group of `~nengo.builder.operator.ElementwiseInc` operators. 
@@ -232,14 +283,8 @@ def sparse_matmul(A_indices, A_data, A_shape, X, transpose_x=False): return dot -# class DotSet(DotInc): -# @property -# def Y(self): -# return self.sets[0] - - @Builder.register(DotInc) -# @Builder.register(DotSet) +@Builder.register(DotSet) class DotIncBuilder(OpBuilder): """ Build a group of `~nengo.builder.operator.DotInc` operators. @@ -396,6 +441,7 @@ def mergeable(x, y): @Builder.register(SparseDotInc) +@Builder.register(SparseDotSet) class SparseDotIncBuilder(OpBuilder): """ Build a group of `~nengo.builder.operator.SparseDotInc` operators. @@ -404,6 +450,8 @@ class SparseDotIncBuilder(OpBuilder): def __init__(self, ops, signals, config): super().__init__(ops, signals, config) + self.mode = "inc" if type(ops[0]) == SparseDotInc else "update" + self.Y_data = signals.combine([op.Y for op in ops]) # group all the A's and X's @@ -457,7 +505,7 @@ def build_step(self, signals): dot.set_shape((signals.minibatch_size,) + self.Y_data.shape) - signals.scatter(self.Y_data, dot, mode="inc") + signals.scatter(self.Y_data, dot, mode=self.mode) @staticmethod def mergeable(x, y): diff --git a/nengo_dl/signals.py b/nengo_dl/signals.py index ad680cab6..4fc03c73c 100644 --- a/nengo_dl/signals.py +++ b/nengo_dl/signals.py @@ -318,7 +318,9 @@ def scatter(self, dst, val, mode="update"): logger.debug("values %s", val) logger.debug("dst %s", dst) logger.debug("slices %s", dst.slices) - logger.debug("dst base %s", self.bases[dst.key]) + logger.debug( + "dst base %s", self.bases[dst.key] if dst.key in self.bases else None + ) if val.dtype.is_floating and val.dtype.base_dtype != self.dtype: raise BuildError( @@ -326,26 +328,29 @@ def scatter(self, dst, val, mode="update"): "be %s." % (val.dtype.base_dtype, self.dtype) ) + # should never be writing to a variable + if isinstance(self.bases[dst.key], tf.Variable): + raise BuildError("Scatter target should not be a Variable") + + if isinstance(self.bases[dst.key], tuple): + # this is the first set operation for this signal + assert mode == "update" + + base_shape = self.bases[dst.key] + var = None + else: + self.bases[dst.key].shape.assert_is_fully_defined() + base_shape = self.bases[dst.key].shape + var = self.bases[dst.key] + # align val shape with dst base shape - self.bases[dst.key].shape.assert_is_fully_defined() val.shape.assert_is_fully_defined() - dst_shape = self.bases[dst.key].shape.as_list() + dst_shape = list(base_shape) dst_shape[dst.minibatched] = dst.shape[0] if val.shape != dst_shape: val = tf.reshape(val, dst.tf_shape) - var = self.bases[dst.key] - - # should never be writing to a variable - if isinstance(var, tf.Variable): - raise BuildError("Scatter target should not be a Variable") - - if ( - len(dst.slices) == 1 - and var.shape.is_compatible_with(val.shape) - and dst.slices[0][0] == 0 - and dst.slices[0][1] == var.shape[dst.minibatched] - ): + if len(dst.slices) == 1 and val.shape == base_shape: if mode == "inc": result = var + val self.write_types["assign_add"] += 1 @@ -356,7 +361,10 @@ def scatter(self, dst, val, mode="update"): result = tf.tensor_scatter_nd_add(var, dst.tf_indices_nd, val) self.write_types["scatter_add"] += 1 else: - result = tf.tensor_scatter_nd_update(var, dst.tf_indices_nd, val) + if var is None: + result = tf.scatter_nd(dst.tf_indices_nd, val, shape=base_shape) + else: + result = tf.tensor_scatter_nd_update(var, dst.tf_indices_nd, val) self.write_types["scatter_update"] += 1 self.bases[dst.key] = result @@ -390,6 +398,8 @@ def gather(self, src, force_copy=False): var = self.bases[src.key] + assert 
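The restructured scatter above defers creating a base tensor until the first write: while the base is still just a shape tuple, an "update" can materialize the tensor directly, and only later writes need to modify an existing tensor. A standalone TensorFlow sketch of those two paths (the values here are illustrative):

import tensorflow as tf

base_shape = (4,)
indices = tf.constant([[1], [2]])

# first set: no base tensor exists yet, so build it directly from the updates
base = tf.scatter_nd(indices, tf.constant([10.0, 20.0]), shape=base_shape)
# base is now [0., 10., 20., 0.]

# later sets: update the existing base tensor
base = tf.tensor_scatter_nd_update(base, indices, tf.constant([30.0, 40.0]))
# base is now [0., 30., 40., 0.]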
diff --git a/nengo_dl/simulator.py b/nengo_dl/simulator.py
index d95b0b971..7cd14b1c8 100644
--- a/nengo_dl/simulator.py
+++ b/nengo_dl/simulator.py
@@ -1208,6 +1208,11 @@ def load_params(self, path, include_non_trainable=False):
         )

         with np.load(path + ".npz") as vals:
+            if len(vars) != len(vals.files):
+                raise SimulationError(
+                    "Number of saved parameters in %s (%d) != number of variables in "
+                    "the model (%d)" % (path, len(vals.files), len(vars))
+                )
             tf.keras.backend.batch_set_value(
                 zip(vars, (vals["arr_%d" % i] for i in range(len(vals.files))))
             )
@@ -1935,16 +1940,12 @@ def _check_data(self, data, batch_size=None, n_steps=None, nodes=True):

     @with_self
     def _update_steps(self):
-        if not hasattr(self, "_step_tensors"):
-            # cache these so we aren't adding new ops every time we call this function
-            self._step_tensors = [
-                self.tensor_graph.get_tensor(self.model.step),
-                self.tensor_graph.get_tensor(self.model.time),
-            ]
+        if not hasattr(self, "_step_tensor"):
+            # cache this so we aren't adding new ops every time we call this function
+            self._step_tensor = self.tensor_graph.get_tensor(self.model.step)

-        self._n_steps, self._time = [
-            x.item() for x in tf.keras.backend.batch_get_value(self._step_tensors)
-        ]
+        self._n_steps = tf.keras.backend.get_value(self._step_tensor).item()
+        self._time = self._n_steps * self.dt

     @property
     def dt(self):
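In practice the new check above turns what used to be a cryptic shape error into an immediate, descriptive failure. A usage sketch (the network and path are illustrative):

import nengo
import nengo_dl

with nengo.Network() as net:
    nengo.Ensemble(10, 1)

with nengo_dl.Simulator(net) as sim:
    sim.save_params("./params")

# a different model, with a different number of variables
with nengo_dl.Simulator(nengo.Network()) as sim2:
    # raises SimulationError: "Number of saved parameters ... != number of
    # variables in the model"
    sim2.load_params("./params")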
diff --git a/nengo_dl/tensor_graph.py b/nengo_dl/tensor_graph.py
index e604c087d..a0cc636db 100644
--- a/nengo_dl/tensor_graph.py
+++ b/nengo_dl/tensor_graph.py
@@ -8,11 +8,13 @@
 import warnings

 from nengo import Connection, Process
-from nengo.builder.operator import Reset, SimPyFunc
+from nengo.builder.neurons import SimNeurons
+from nengo.builder.operator import Reset, SimPyFunc, TimeUpdate
 from nengo.builder.processes import SimProcess
 from nengo.config import ConfigError
 from nengo.exceptions import BuildError
 from nengo.neurons import Direct
+from nengo.synapses import Lowpass
 from nengo.transforms import SparseMatrix
 import numpy as np
 import tensorflow as tf
@@ -115,13 +117,7 @@ def __init__(
         # apply graph simplification functions
         simplifications = config.get_setting(
-            model,
-            "simplifications",
-            [
-                graph_optimizer.remove_unmodified_resets,
-                graph_optimizer.remove_zero_incs,
-                graph_optimizer.remove_identity_muls,
-            ],
+            model, "simplifications", graph_optimizer.default_simplifications,
         )

         with progress.sub("operator simplification", max_value=None):
@@ -230,14 +226,20 @@ def get_initializer(init_vals):

         values, shapes, dtype, minibatched = init_vals

-        if all(np.all(v == 0) for v in values):
+        # initial value of None means that the initial value isn't used, so we
+        # can use anything for the initial value
+        if all(v is None for v in values):
+            initializer = None
+        elif all(v is None or np.all(v == 0) for v in values):
             initializer = tf.initializers.zeros()
-        elif all(np.all(v == 1) for v in values):
+        elif all(v is None or np.all(v == 1) for v in values):
             initializer = tf.initializers.ones()
         else:
             val = tf.concat(
                 [
-                    tf.cast(tf.broadcast_to(v, s), dtype)
+                    tf.zeros(s, dtype)
+                    if v is None
+                    else tf.cast(tf.broadcast_to(v, s), dtype)
                     for v, s in zip(values, shapes)
                 ],
                 axis=1 if minibatched else 0,
@@ -256,6 +258,7 @@ def get_initializer(init_vals):
         assert len(self.base_params) == 0
         for k, v in self.base_arrays_init[True].items():
             initializer, shape, dtype = get_initializer(v)
+            assert initializer is not None  # trainable params should never be set
             self.base_params[k] = self.add_weight(
                 initializer=initializer,
                 shape=shape,
@@ -272,13 +275,16 @@ def get_initializer(init_vals):
         self.saved_state = OrderedDict()
         for k, v in self.base_arrays_init[False].items():
             initializer, shape, dtype = get_initializer(v)
-            self.saved_state[k] = self.add_weight(
-                initializer=initializer,
-                shape=shape,
-                dtype=dtype,
-                trainable=False,
-                name="saved_state/%s_%s" % (dtype, "_".join(str(x) for x in shape)),
-            )
+            if initializer is not None:
+                # don't need to save the state for signals where the initial value
+                # doesn't matter
+                self.saved_state[k] = self.add_weight(
+                    initializer=initializer,
+                    shape=shape,
+                    dtype=dtype,
+                    trainable=False,
+                    name="saved_state/%s_%s" % (dtype, "_".join(str(x) for x in shape)),
+                )

         logger.debug("created saved state variables")
         logger.debug([str(x) for x in self.saved_state.values()])
@@ -477,6 +483,33 @@ def call(self, inputs, training=None, progress=None, stateful=False):

         return outputs

+    def _fill_bases(self, saved_state, base_params):
+        """
+        Initialize signals.bases from TensorGraph params.
+
+        Parameters
+        ----------
+        saved_state : dict
+            Mapping from base keys to the (non-trainable) saved state values
+        base_params : dict
+            Mapping from base keys to the (trainable) parameter values
+        """
+
+        for key, val in saved_state.items():
+            # we add the tf.identity so that when we write we're not updating
+            # the base variable
+            self.signals.bases[key] = tf.identity(val)
+        for key, val in base_params.items():
+            self.signals.bases[key] = tf.identity(val)
+        for key, (_, shapes, _, minibatched) in self.base_arrays_init[False].items():
+            if key not in self.signals.bases:
+                # no saved state for this base, so we just temporarily insert
+                # the shape information so that future scatters will know
+                # what the base shape is
+                shape = list(shapes[0])
+                shape[minibatched] = sum(x[minibatched] for x in shapes)
+                self.signals.bases[key] = tuple(shape)
+
     def _build_loop(self, progress):
         """
         Build simulation loop using symbolic while loop.
@@ -503,13 +536,10 @@ def loop_body(loop_i, n_steps, probe_arrays, saved_state, base_params):
             # fill in signals.bases
             # note: we need to do this here because we
             # need to use the tensors from inside the loop, not the source variables)
-            # note2: eager while loops pass in the variable directly,
-            # so we add the tf.identity so that when we write we're not updating
-            # the base variable
-            for key, val in zip(self.saved_state.keys(), saved_state):
-                self.signals.bases[key] = tf.identity(val)
-            for key, val in zip(self.base_params.keys(), base_params):
-                self.signals.bases[key] = tf.identity(val)
+            self._fill_bases(
+                dict(zip(self.saved_state, saved_state)),
+                dict(zip(self.base_params, base_params)),
+            )

             def update_probes(probe_tensors, loop_i):
                 for i, p in enumerate(probe_tensors):
@@ -599,10 +629,7 @@ def _build_no_loop(self, progress):
             Tensors representing the value of all internal state at the end of the run.
         """

-        for key, val in self.saved_state.items():
-            self.signals.bases[key] = tf.identity(val)
-        for key, val in self.base_params.items():
-            self.signals.bases[key] = tf.identity(val)
+        self._fill_bases(self.saved_state, self.base_params)

         loop_i = tf.constant(0)  # symbolic loop variable
         loop_iter = 0  # non-symbolic loop variable
@@ -792,11 +819,6 @@ def get_tensor(self, sig):
         except KeyError:
             base = self.saved_state[tensor_sig.key]

-        if "/while/" in tensor_sig.tf_indices.name:
-            # invalidate cached indices so they will be rebuilt outside the
-            # while loop
-            tensor_sig._tf_indices = None
-
         return tf.gather(
             base, tensor_sig.tf_indices, axis=1 if tensor_sig.minibatched else 0,
         )
@@ -993,6 +1015,31 @@ def create_signals(self, sigs):
             "\n%s", "".join("|" if i in breaks else " " for i in range(len(sigs)))
         )

+        # find all the signals that have a set operation associated with them
+
+        def special_set(s, op):
+            return (
+                # we don't include Lowpass ops, because for efficiency reasons in the
+                # nengo-dl Lowpass implementation we reuse the output signal (which is
+                # set) as the state signal (so we need to include that signal in the
+                # state)
+                (isinstance(op, SimProcess) and isinstance(op.process, Lowpass))
+                # nengo marks the time step as a set, but really it's an inc (since
+                # it's incrementing the simulation step)
+                or (isinstance(op, TimeUpdate) and s is op.step)
+                # nengo marks neuron state as a set, but really it's more like an
+                # inc/update (since the neuron calculation may depend on the state)
+                or (isinstance(op, SimNeurons) and s in op.states)
+            )
+
+        set_sigs = {
+            s.base
+            for ops in self.plan
+            for op in ops
+            for s in op.sets
+            if not special_set(s, op)
+        }
+
         # create all the base signals
         for i, sig in enumerate(sigs):
             assert sig not in self.signals
@@ -1029,12 +1076,18 @@ def create_signals(self, sigs):
                 curr_keys[array_params] = object()
             key = curr_keys[array_params]

-            initial_value = sig.initial_value
-            if sig.sparse:
-                if isinstance(initial_value, SparseMatrix):
-                    initial_value = initial_value.data
-                else:
-                    initial_value = initial_value.tocoo().data
+            if sig in set_sigs:
+                # signals with a set operation associated with them don't need an
+                # initial value (since the value will just be immediately overridden
+                # by the set operation)
+                initial_value = None
+            else:
+                initial_value = sig.initial_value
+                if sig.sparse:
+                    if isinstance(initial_value, SparseMatrix):
+                        initial_value = initial_value.data
+                    else:
+                        initial_value = initial_value.tocoo().data

             if sig.minibatched:
                 shape = (self.minibatch_size,) + shape
diff --git a/nengo_dl/tensor_node.py b/nengo_dl/tensor_node.py
index b94acc228..36d62a34e 100644
--- a/nengo_dl/tensor_node.py
+++ b/nengo_dl/tensor_node.py
@@ -9,7 +9,6 @@
 import warnings

 from nengo import Node, Connection, Ensemble, builder
-from nengo import rc as nengo_rc
 from nengo.base import NengoObject
 from nengo.builder.operator import Reset
 from nengo.config import Config
@@ -248,16 +247,12 @@ def build_tensor_node(model, node):

     # input signal
     if node.shape_in is not None:
-        sig_in = builder.Signal(
-            np.zeros(node.size_in, dtype=nengo_rc.float_dtype), name="%s.in" % node
-        )
+        sig_in = builder.Signal(shape=(node.size_in,), name="%s.in" % node)
         model.add_op(Reset(sig_in))
     else:
         sig_in = None

-    sig_out = builder.Signal(
-        np.zeros(node.size_out, dtype=nengo_rc.float_dtype), name="%s.out" % node
-    )
+    sig_out = builder.Signal(shape=(node.size_out,), name="%s.out" % node)

     model.sig[node]["in"] = sig_in
     model.sig[node]["out"] = sig_out
diff --git a/nengo_dl/tests/test_benchmarks.py b/nengo_dl/tests/test_benchmarks.py
index 98a28d26c..79e4c2c6b 100644
--- a/nengo_dl/tests/test_benchmarks.py
+++ b/nengo_dl/tests/test_benchmarks.py
@@ -101,7 +101,9 @@ def _test_random(
     assert all(net.inp in x for x in post_conns.values())


-@pytest.mark.parametrize("network, train", [("integrator", True), ("cconv", False)])
+@pytest.mark.parametrize(
+    "network, train", [("integrator", True), ("cconv", False), ("test", True)]
+)
 def test_run_profile(network, train, pytestconfig, monkeypatch, tmpdir):
     monkeypatch.chdir(tmpdir)

@@ -109,6 +111,10 @@ def test_run_profile(network, train, pytestconfig, monkeypatch, tmpdir):
         net = benchmarks.integrator(3, 2, nengo.SpikingRectifiedLinear())
     elif network == "cconv":
         net = benchmarks.cconv(3, 10, nengo.LIF())
+    elif network == "test":
+        with nengo.Network() as net:
+            ens = nengo.Ensemble(10, 1)
+            net.p = nengo.Probe(ens)

     benchmarks.run_profile(
         net,
@@ -190,10 +196,10 @@ def test_lmu(Simulator, native_nengo, pytestconfig):
 @pytest.mark.parametrize(
     "net, train, minibatch_size, min, max",
     [
-        (benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.65, 0.8),
-        (benchmarks.cconv(128, 64, nengo.LIF()), False, 64, 1.45, 1.65),
-        (benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.6, 1.0),
-        (benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 1.1, 1.4),
+        (benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.6, 0.75),
+        (benchmarks.cconv(128, 64, nengo.LIF()), False, 64, 1.4, 1.6),
+        (benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.55, 0.95),
+        (benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 1.0, 1.3),
         (
             benchmarks.random_network(
                 64,
diff --git a/nengo_dl/tests/test_graph_optimizer.py b/nengo_dl/tests/test_graph_optimizer.py
index e6c73f29c..bf48342f3 100644
--- a/nengo_dl/tests/test_graph_optimizer.py
+++ b/nengo_dl/tests/test_graph_optimizer.py
@@ -1,5 +1,8 @@
 # pylint: disable=missing-docstring

+from distutils.version import LooseVersion
+
+import nengo
 from nengo.exceptions import BuildError
 from nengo.neurons import LIF, LIFRate, Izhikevich, AdaptiveLIF
 from nengo.synapses import Lowpass, Triangle, Alpha, LinearFilter
@@ -15,22 +18,24 @@
 )
 from nengo.builder.processes import SimProcess
 from nengo.builder.signal import Signal
+from nengo.builder.transforms import ConvInc
 import numpy as np
 import pytest

-from nengo_dl import op_builders
+from nengo_dl import config, op_builders, transform_builders
 from nengo_dl.graph_optimizer import (
-    mergeable,
     greedy_planner,
-    tree_planner,
-    transitive_planner,
+    mergeable,
+    noop_order_signals,
     noop_planner,
     order_signals,
-    noop_order_signals,
-    remove_unmodified_resets,
-    remove_zero_incs,
     remove_constant_copies,
     remove_identity_muls,
+    remove_reset_incs,
+    remove_unmodified_resets,
+    remove_zero_incs,
+    transitive_planner,
+    tree_planner,
 )
 from nengo_dl.tensor_node import SimTensorNode
 from nengo_dl.tests import dummies
@@ -1017,3 +1022,209 @@ def test_remove_identity_muls(Op):
         operators = [Op(x, dummies.Signal(), dummies.Signal()), dummies.Op(sets=[x])]
         new_operators = remove_identity_muls(operators)
         assert new_operators == operators
+
+
+def test_remove_reset_incs():
+    # elementwiseinc converted to elementwiseset
+    x = dummies.Signal()
+    operators = [Reset(x), ElementwiseInc(dummies.Signal(), dummies.Signal(), x)]
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 1
+    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
+    assert new_operators[0].Y is x
+    assert new_operators[0].incs == []
+    assert new_operators[0].sets == [x]
+
+    # dotinc converted to dotset
+    x = dummies.Signal()
+    operators = [Reset(x), DotInc(dummies.Signal(), dummies.Signal(), x)]
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 1
+    assert isinstance(new_operators[0], op_builders.DotSet)
+    assert new_operators[0].Y is x
+
+    # copy inc converted to copy set
+    x = dummies.Signal()
+    operators = [Reset(x), Copy(dummies.Signal(), x, inc=True)]
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 1
+    assert not new_operators[0].inc
+    assert new_operators[0].dst is x
+
+    # simprocess inc converted to simprocess set
+    x = dummies.Signal()
+    operators = [
+        Reset(x),
+        SimProcess(None, dummies.Signal(), x, dummies.Signal(), mode="inc"),
+    ]
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 1
+    assert new_operators[0].mode == "set"
+    assert new_operators[0].output is x
+
+    # convinc converted to convset
+    x = dummies.Signal()
+    operators = [Reset(x), ConvInc(dummies.Signal(), dummies.Signal(), x, None)]
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 1
+    assert isinstance(new_operators[0], transform_builders.ConvSet)
+    assert new_operators[0].Y is x
+
+    # sparsedotinc converted to sparsedotset
+    x = dummies.Signal()
+    operators = [
+        Reset(x),
+        SparseDotInc(dummies.Signal(sparse=True), dummies.Signal(), x, None),
+    ]
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 1
+    assert isinstance(new_operators[0], op_builders.SparseDotSet)
+    assert new_operators[0].Y is x
+
+    # resetinc converted to reset
+    x = dummies.Signal()
+    operators = [Reset(x), op_builders.ResetInc(x)]
+    operators[1].value = np.ones((2, 3))
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 1
+    assert type(new_operators[0]) == Reset
+    assert np.allclose(new_operators[0].value, 1)
+    assert new_operators[0].dst is x
+
+    # multiple incs
+    x = dummies.Signal()
+    operators = [
+        Reset(x),
+        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
+        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
+    ]
+    new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 2
+    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
+    assert isinstance(new_operators[1], ElementwiseInc)
+
+    # nonzero reset doesn't get converted
+    x = dummies.Signal()
+    operators = [
+        Reset(x, value=1),
+        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
+    ]
+    new_operators = remove_reset_incs(operators)
+    assert operators == new_operators
+
+    # reset without inc
+    x = dummies.Signal()
+    operators = [
+        Reset(x),
+        Copy(dummies.Signal(), x, inc=False),
+    ]
+    new_operators = remove_reset_incs(operators)
+    assert operators == new_operators
+
+    # reset with partial inc
+    x = Signal(shape=(10,))
+    operators = [
+        Reset(x),
+        Copy(dummies.Signal(), x[:5], inc=True),
+    ]
+    new_operators = remove_reset_incs(operators)
+    assert operators == new_operators
+
+    # unknown inc type
+    class NewCopy(Copy):
+        pass
+
+    x = dummies.Signal()
+    operators = [
+        Reset(x),
+        NewCopy(dummies.Signal(), x, inc=True),
+        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
+    ]
+    with pytest.warns(UserWarning, match="Unknown incer type"):
+        new_operators = remove_reset_incs(operators)
+    assert len(new_operators) == 2
+    # uses the known op (ElementwiseInc) instead of unknown one
+    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
+    assert new_operators[1] is operators[1]
+
+    operators = [
+        Reset(x),
+        NewCopy(dummies.Signal(), x, inc=True),
+    ]
+    # no optimization if only unknown incers
+    with pytest.warns(UserWarning, match="Unknown incer type"):
+        new_operators = remove_reset_incs(operators)
+    assert new_operators == operators
+
+
+def test_remove_reset_inc_functional(Simulator, seed):
+    with nengo.Network(seed=seed) as net:
+        config.configure_settings(
+            simplifications=[remove_zero_incs, remove_unmodified_resets]
+        )
+
+        # reset+simprocess on the noise
+        ens = nengo.Ensemble(
+            1, 1, noise=nengo.processes.WhiteNoise(), neuron_type=nengo.Direct()
+        )
+
+        node0 = nengo.Node(size_in=1, label="node0")
+        # reset+elementwiseinc (weights)
+        # reset+copy (to node input)
+        nengo.Connection(ens, node0, transform=1, synapse=None)
+
+        node1 = nengo.Node(size_in=3, label="node1")
+        # reset+dotinc (weights)
+        # reset+copy (to node input)
+        nengo.Connection(node0, node1, transform=np.ones((3, 1)), synapse=None)
+
+        # reset+elementwiseinc (weights, in nengo<3.1)
+        # reset+copy (to probe input)
+        p = nengo.Probe(node1)
+
+    with Simulator(net) as sim:
+        extra_op = LooseVersion(nengo.__version__) < "3.1.0"
+
+        assert len(sim.tensor_graph.plan) == 8 + extra_op
+
+        # check that we have all the resets we expect
+        resets = sim.tensor_graph.plan[1]
+        assert isinstance(resets[0], Reset)
+        assert len(resets) == 6 + extra_op
+
+        # check that all the ops are incs like we expect
+        incs = sim.tensor_graph.plan[2:]
+        for ops in incs:
+            for op in ops:
+                assert len(op.incs) == 1
+                assert len(op.sets) == 0
+
+        sim.run_steps(100)
+
+    with net:
+        config.configure_settings(
+            simplifications=[
+                remove_zero_incs,
+                remove_unmodified_resets,
+                remove_reset_incs,
+            ]
+        )
+
+    with Simulator(net) as sim_remove:
+        # check that resets have been removed
+        assert len(sim_remove.tensor_graph.plan) == 7 + extra_op
+        assert (
+            len([x for x in sim_remove.tensor_graph.plan if isinstance(x[0], Reset)])
+            == 0
+        )
+
+        # check that all the ops are sets like we expect
+        incs = sim_remove.tensor_graph.plan[1:]
+        for ops in incs:
+            for op in ops:
+                assert len(op.incs) == 0
+                assert len(op.sets) == 1
+
+        sim_remove.run_steps(100)
+
+    assert np.allclose(sim.data[p], sim_remove.data[p])
diff --git a/nengo_dl/tests/test_simulator.py b/nengo_dl/tests/test_simulator.py
index eb2f9773e..72d026e9b 100644
--- a/nengo_dl/tests/test_simulator.py
+++ b/nengo_dl/tests/test_simulator.py
@@ -591,6 +591,10 @@ def get_network(seed):
         else:
             assert not np.allclose(sim_load.data[p1], sim_save.data[p0][10:])

+    with Simulator(nengo.Network()) as sim:
+        with pytest.raises(SimulationError, match="!= number of variables"):
+            sim.load_params(str(tmpdir))
+

 def test_model_passing(Simulator, seed):
     # make sure that passing a built model to the Simulator works properly
diff --git a/nengo_dl/tests/test_tensor_graph.py b/nengo_dl/tests/test_tensor_graph.py
index bea696e86..9c5aa7572 100644
--- a/nengo_dl/tests/test_tensor_graph.py
+++ b/nengo_dl/tests/test_tensor_graph.py
@@ -229,7 +229,10 @@ def test_signal_order_deterministic(Simulator, seed):
             sim1.tensor_graph.base_arrays_init[trainable].values(),
             sim2.tensor_graph.base_arrays_init[trainable].values(),
         ):
-            assert np.allclose(v[0], v2[0])
+            assert all(
+                (x is None and y is None) or np.allclose(x, y)
+                for x, y in zip(v[0], v2[0])
+            )


 def test_create_signals():
diff --git a/nengo_dl/tests/test_tensor_node.py b/nengo_dl/tests/test_tensor_node.py
index d97e4b0f8..5af8fb590 100644
--- a/nengo_dl/tests/test_tensor_node.py
+++ b/nengo_dl/tests/test_tensor_node.py
@@ -1,6 +1,5 @@
 # pylint: disable=missing-docstring

-from distutils.version import LooseVersion
 from functools import partial

 import nengo
@@ -262,19 +261,12 @@ def call(self, x):
         assert np.allclose(sim.data[p2], 3)

     # note: when inference-only=True the weights will be marked as non-trainable
-
-    default_conn_params = 2 if LooseVersion(nengo.__version__) < "3.1.0" else 0
-
     if sim.tensor_graph.inference_only:
-        assert (
-            len(sim.keras_model.non_trainable_variables) == 8 + default_conn_params
-        )
+        assert len(sim.keras_model.non_trainable_variables) == 4
         assert len(sim.keras_model.trainable_variables) == 0
         vars = sim.keras_model.non_trainable_variables[-2:]
     else:
-        assert (
-            len(sim.keras_model.non_trainable_variables) == 6 + default_conn_params
-        )
+        assert len(sim.keras_model.non_trainable_variables) == 2
         assert len(sim.keras_model.trainable_variables) == 2
         vars = sim.keras_model.trainable_variables
diff --git a/nengo_dl/transform_builders.py b/nengo_dl/transform_builders.py
index d47882d98..54ca8b332 100644
--- a/nengo_dl/transform_builders.py
+++ b/nengo_dl/transform_builders.py
@@ -11,7 +11,25 @@
 from nengo_dl.builder import Builder, OpBuilder


+class ConvSet(ConvInc):
+    """
+    A version of `~nengo.builder.transforms.ConvInc` that overwrites the target
+    rather than incrementing.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.incs, self.sets = self.sets, self.incs
+
+    @property
+    def Y(self):
+        """Y is stored in ``sets`` rather than ``incs``."""
+        return self.sets[0]
+
+
 @Builder.register(ConvInc)
+@Builder.register(ConvSet)
 class ConvIncBuilder(OpBuilder):
     """
     Build a group of `nengo.builder.transforms.ConvInc` operators.
@@ -22,6 +40,7 @@ def __init__(self, ops, signals, config):

         self.conv = ops[0].conv
         self.n_ops = len(ops)
+        self.mode = "inc" if type(ops[0]) == ConvInc else "update"

         if not self.conv.channels_last and config.cpu_only:
             # TensorFlow doesn't support channels first on CPU, so if
@@ -167,7 +186,7 @@ def build_step(self, signals):
             (signals.minibatch_size, self.n_ops) + self.conv.output_shape.shape
         )

-        signals.scatter(self.Y_data, Y, mode="inc")
+        signals.scatter(self.Y_data, Y, mode=self.mode)

     @staticmethod
     def mergeable(x, y):
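All of the ``*Set`` operators introduced in this diff (``ElementwiseSet``, ``DotSet``, ``SparseDotSet``, ``ConvSet``, and the reworked ``ResetInc``) follow one pattern: subclass the corresponding ``*Inc`` operator, swap ``incs`` and ``sets`` so the dependency graph and merge logic treat the write as an overwrite, and re-expose the target signal from its new list. A condensed sketch of that pattern, using a hypothetical class name:

import numpy as np
from nengo.builder.operator import ElementwiseInc
from nengo.builder.signal import Signal

class SketchSet(ElementwiseInc):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # the computation is unchanged; only the write semantics move
        # from "increment Y" to "overwrite Y"
        self.incs, self.sets = self.sets, self.incs

    @property
    def Y(self):
        # Y now lives in sets rather than incs
        return self.sets[0]

op = SketchSet(Signal(np.ones(1)), Signal(np.ones(1)), Signal(np.zeros(1)))
assert op.incs == [] and len(op.sets) == 1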