diff --git a/nengo_dl/op_builders.py b/nengo_dl/op_builders.py index d4b994040..53e014d04 100644 --- a/nengo_dl/op_builders.py +++ b/nengo_dl/op_builders.py @@ -56,8 +56,8 @@ def __init__(self, ops, signals, config): if np.issubdtype(dtype, np.floating): dtype = signals.dtype.as_numpy_dtype - # unlike other ops, Reset signals might be spread across multiple - # bases, which we need to handle + # Reset signals might be spread across multiple bases, so group them + # by the ones that do share a base scatters = defaultdict(list) for op in ops: scatters[signals[op.dst].key].append(op) @@ -73,7 +73,7 @@ def __init__(self, ops, signals, config): ], axis=1, ) - self.scatters += [(signals.combine([x.dst for x in group]), value)] + self.scatters.append((signals.combine([x.dst for x in group]), value)) logger.debug("scatters") logger.debug("\n".join([str(x) for x in self.scatters])) @@ -101,24 +101,16 @@ def __init__(self, ops, signals, config): logger.debug("dst %s", [op.dst for op in ops]) logger.debug("dst_slice %s", [getattr(op, "dst_slice", None) for op in ops]) - srcs = [] - dsts = [] - for op in ops: - srcs += [signals[op.src][op.src_slice]] - dsts += [signals[op.dst][op.dst_slice]] + self.src_data = signals.combine([signals[op.src][op.src_slice] for op in ops]) + self.dst_data = signals.combine([signals[op.dst][op.dst_slice] for op in ops]) self.mode = "inc" if ops[0].inc else "update" - self.src_data = signals.combine(srcs) - self.dst_data = signals.combine(dsts) - - if not self.src_data.minibatched and self.dst_data.minibatched: - # broadcast indices so that the un-minibatched src data gets - # copied to each minibatch dimension in dst - self.src_data = self.src_data.broadcast(signals.minibatch_size) - def build_step(self, signals): - signals.scatter(self.dst_data, signals.gather(self.src_data), mode=self.mode) + src = signals.gather(self.src_data) + if not self.src_data.minibatched and self.dst_data.minibatched: + src = tf.broadcast_to(src, self.dst_data.full_shape) + signals.scatter(self.dst_data, src, mode=self.mode) @staticmethod def mergeable(x, y): diff --git a/nengo_dl/signals.py b/nengo_dl/signals.py index 9e5141cd5..ad680cab6 100644 --- a/nengo_dl/signals.py +++ b/nengo_dl/signals.py @@ -171,34 +171,6 @@ def reshape(self, shape): label=self.label + ".reshape(%s)" % (shape,), ) - def broadcast(self, length): - """ - Add a new dimension by broadcasting this signal along the first axis - for the given length. - - Parameters - ---------- - length : int - The number of times to duplicate signal along the first dimension. - - Returns - ------- - sig : `.signals.TensorSignal` - TensorSignal with new broadcasted shape - """ - - # this only works on vectors - assert self.ndim == 1 and not self.minibatched - - return TensorSignal( - self.slices * length, - self.key, - self.dtype, - (length,) + self.shape, - self.minibatch_size, - label=self.label + ".broadcast(%d)" % length, - ) - @property def tf_shape(self): """ diff --git a/nengo_dl/tensor_graph.py b/nengo_dl/tensor_graph.py index 93c186cc7..453856a5a 100644 --- a/nengo_dl/tensor_graph.py +++ b/nengo_dl/tensor_graph.py @@ -8,7 +8,7 @@ import warnings from nengo import Connection, Process -from nengo.builder.operator import SimPyFunc, Reset +from nengo.builder.operator import Reset, SimPyFunc from nengo.builder.processes import SimProcess from nengo.config import ConfigError from nengo.exceptions import BuildError @@ -956,14 +956,18 @@ def create_signals(self, sigs): breaks = [] diff = defaultdict(int) for ops in self.plan: - # note: we don't include Resets, otherwise the big reset block - # overrides most of the partitioning - if not isinstance(ops[0], Reset): - for i in range(len(ops[0].all_signals)): - op_sigs = [op.all_signals[i].base for op in ops] - idxs = [sig_idxs[s] for s in op_sigs] - diff[op_sigs[np.argmin(idxs)]] += 1 - diff[op_sigs[np.argmax(idxs)]] -= 1 + if isinstance(ops[0], Reset): + # don't include Resets, otherwise the big reset block + # overrides most of the partitioning + partition_sigs = [] + else: + partition_sigs = range(len(ops[0].all_signals)) + + for i in partition_sigs: + op_sigs = [op.all_signals[i].base for op in ops] + idxs = [sig_idxs[s] for s in op_sigs] + diff[op_sigs[np.argmin(idxs)]] += 1 + diff[op_sigs[np.argmax(idxs)]] -= 1 # find the partition points in signal list open = 0 diff --git a/nengo_dl/tests/test_benchmarks.py b/nengo_dl/tests/test_benchmarks.py index 38d7cd143..70c86a31d 100644 --- a/nengo_dl/tests/test_benchmarks.py +++ b/nengo_dl/tests/test_benchmarks.py @@ -190,7 +190,7 @@ def test_lmu(Simulator, native_nengo, pytestconfig): @pytest.mark.parametrize( "net, train, minibatch_size, min, max", [ - (benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.7, 0.85), + (benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.65, 0.8), (benchmarks.cconv(128, 64, nengo.LIF()), False, 64, 1.45, 1.65), (benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.5, 0.75), (benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 0.9, 1.2), diff --git a/nengo_dl/tests/test_signals.py b/nengo_dl/tests/test_signals.py index 74ef9be53..f14a25893 100644 --- a/nengo_dl/tests/test_signals.py +++ b/nengo_dl/tests/test_signals.py @@ -72,19 +72,6 @@ def test_tensor_signal_reshape(): sig.reshape((4, 4)) -def test_tensor_signal_broadcast(): - sig = TensorSignal([(0, 4)], object(), None, (4,), None) - base = np.random.randn(4) - - sig_broad = sig.broadcast(2) - assert sig_broad.slices == ((0, 4), (0, 4)) - assert sig_broad.shape == (2, 4) - assert sig_broad.key == sig.key - assert np.all( - np.reshape(base[sig_broad.tf_indices.numpy()], sig_broad.shape) == base[None, :] - ) - - def test_tensor_signal_load_indices(): sig = TensorSignal([(2, 6)], object(), None, (4,), None) assert np.all(sig.tf_indices == np.arange(*sig.slices[0]))