
Commit

Simplify CopyBuilder implementation
Remove TensorSignal.broadcast, as CopyBuilder was the only place it was used.
drasmuss authored and tbekolay committed Mar 4, 2020
1 parent 83ab3ed commit 08d3b95
Showing 5 changed files with 23 additions and 68 deletions.
26 changes: 9 additions & 17 deletions nengo_dl/op_builders.py
@@ -56,8 +56,8 @@ def __init__(self, ops, signals, config):
         if np.issubdtype(dtype, np.floating):
             dtype = signals.dtype.as_numpy_dtype
 
-        # unlike other ops, Reset signals might be spread across multiple
-        # bases, which we need to handle
+        # Reset signals might be spread across multiple bases, so group them
+        # by the ones that do share a base
         scatters = defaultdict(list)
         for op in ops:
             scatters[signals[op.dst].key].append(op)
@@ -73,7 +73,7 @@ def __init__(self, ops, signals, config):
                 ],
                 axis=1,
             )
-            self.scatters += [(signals.combine([x.dst for x in group]), value)]
+            self.scatters.append((signals.combine([x.dst for x in group]), value))
 
         logger.debug("scatters")
         logger.debug("\n".join([str(x) for x in self.scatters]))
@@ -101,24 +101,16 @@ def __init__(self, ops, signals, config):
         logger.debug("dst %s", [op.dst for op in ops])
         logger.debug("dst_slice %s", [getattr(op, "dst_slice", None) for op in ops])
 
-        srcs = []
-        dsts = []
-        for op in ops:
-            srcs += [signals[op.src][op.src_slice]]
-            dsts += [signals[op.dst][op.dst_slice]]
+        self.src_data = signals.combine([signals[op.src][op.src_slice] for op in ops])
+        self.dst_data = signals.combine([signals[op.dst][op.dst_slice] for op in ops])
 
         self.mode = "inc" if ops[0].inc else "update"
 
-        self.src_data = signals.combine(srcs)
-        self.dst_data = signals.combine(dsts)
-
-        if not self.src_data.minibatched and self.dst_data.minibatched:
-            # broadcast indices so that the un-minibatched src data gets
-            # copied to each minibatch dimension in dst
-            self.src_data = self.src_data.broadcast(signals.minibatch_size)
-
     def build_step(self, signals):
-        signals.scatter(self.dst_data, signals.gather(self.src_data), mode=self.mode)
+        src = signals.gather(self.src_data)
+        if not self.src_data.minibatched and self.dst_data.minibatched:
+            src = tf.broadcast_to(src, self.dst_data.full_shape)
+        signals.scatter(self.dst_data, src, mode=self.mode)
 
     @staticmethod
     def mergeable(x, y):
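The key change is in build_step: rather than pre-duplicating the source indices with TensorSignal.broadcast, the un-minibatched source is gathered once and then expanded across the minibatch dimension with tf.broadcast_to. A rough standalone illustration of why the two are equivalent (the shapes and minibatch_size here are assumptions for the example, not values taken from the real signals):

import numpy as np
import tensorflow as tf

minibatch_size = 3
src = tf.constant(np.arange(4.0))     # un-minibatched source data, shape (4,)
dst_full_shape = (minibatch_size, 4)  # minibatched destination shape

# old approach (roughly): duplicate the gather indices once per minibatch item
old_style = tf.gather(src, np.tile(np.arange(4), (minibatch_size, 1)))

# new approach: gather once, then broadcast to the destination's full shape
new_style = tf.broadcast_to(src, dst_full_shape)

assert np.array_equal(old_style.numpy(), new_style.numpy())

Handling the minibatch dimension inside the TensorFlow graph keeps the index bookkeeping in TensorSignal simpler, which is what allows the broadcast method to be deleted below.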
28 changes: 0 additions & 28 deletions nengo_dl/signals.py
@@ -171,34 +171,6 @@ def reshape(self, shape):
             label=self.label + ".reshape(%s)" % (shape,),
         )
 
-    def broadcast(self, length):
-        """
-        Add a new dimension by broadcasting this signal along the first axis
-        for the given length.
-
-        Parameters
-        ----------
-        length : int
-            The number of times to duplicate signal along the first dimension.
-
-        Returns
-        -------
-        sig : `.signals.TensorSignal`
-            TensorSignal with new broadcasted shape
-        """
-
-        # this only works on vectors
-        assert self.ndim == 1 and not self.minibatched
-
-        return TensorSignal(
-            self.slices * length,
-            self.key,
-            self.dtype,
-            (length,) + self.shape,
-            self.minibatch_size,
-            label=self.label + ".broadcast(%d)" % length,
-        )
-
     @property
     def tf_shape(self):
         """
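For reference, the deleted method worked purely at the index level: repeating the signal's slices tuple length times makes the eventual gather read the same block once per row of the new leading dimension. A toy illustration of that tuple arithmetic (plain tuples standing in for the real slice data):

slices = ((0, 4),)  # one contiguous block of four elements
length = 2

# Python tuple repetition, as used by the removed broadcast() method
assert slices * length == ((0, 4), (0, 4))

# and the broadcast shape gained a new leading dimension of size `length`
shape = (4,)
assert (length,) + shape == (2, 4)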
22 changes: 13 additions & 9 deletions nengo_dl/tensor_graph.py
@@ -8,7 +8,7 @@
 import warnings
 
 from nengo import Connection, Process
-from nengo.builder.operator import SimPyFunc, Reset
+from nengo.builder.operator import Reset, SimPyFunc
 from nengo.builder.processes import SimProcess
 from nengo.config import ConfigError
 from nengo.exceptions import BuildError
@@ -956,14 +956,18 @@ def create_signals(self, sigs):
         breaks = []
         diff = defaultdict(int)
         for ops in self.plan:
-            # note: we don't include Resets, otherwise the big reset block
-            # overrides most of the partitioning
-            if not isinstance(ops[0], Reset):
-                for i in range(len(ops[0].all_signals)):
-                    op_sigs = [op.all_signals[i].base for op in ops]
-                    idxs = [sig_idxs[s] for s in op_sigs]
-                    diff[op_sigs[np.argmin(idxs)]] += 1
-                    diff[op_sigs[np.argmax(idxs)]] -= 1
+            if isinstance(ops[0], Reset):
+                # don't include Resets, otherwise the big reset block
+                # overrides most of the partitioning
+                partition_sigs = []
+            else:
+                partition_sigs = range(len(ops[0].all_signals))
+
+            for i in partition_sigs:
+                op_sigs = [op.all_signals[i].base for op in ops]
+                idxs = [sig_idxs[s] for s in op_sigs]
+                diff[op_sigs[np.argmin(idxs)]] += 1
+                diff[op_sigs[np.argmax(idxs)]] -= 1
 
         # find the partition points in signal list
         open = 0
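The surrounding create_signals logic is an interval sweep over the ordered signal list: each operator group contributes +1 at its lowest-indexed base signal and -1 at its highest, and a base array can be split wherever the running count returns to zero. The refactor above only changes how Resets are excluded; the sweep itself is unchanged. A self-contained sketch of that counting scheme (signal names and groups are invented for illustration):

from collections import defaultdict

import numpy as np

# made-up ordered signals and operator groups referencing them
sig_idxs = {"a": 0, "b": 1, "c": 2, "d": 3}
groups = [["a", "b"], ["b", "c"], ["d"]]

diff = defaultdict(int)
for group in groups:
    idxs = [sig_idxs[s] for s in group]
    diff[group[np.argmin(idxs)]] += 1  # group "opens" at its first signal
    diff[group[np.argmax(idxs)]] -= 1  # and "closes" at its last signal

# sweep through the signals in order; wherever no group spans the boundary,
# the base array can be partitioned
open_count = 0
breaks = []
for sig in sorted(sig_idxs, key=sig_idxs.get):
    open_count += diff[sig]
    if open_count == 0:
        breaks.append(sig_idxs[sig])

assert breaks == [2, 3]  # safe to split after "c" and after "d"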
2 changes: 1 addition & 1 deletion nengo_dl/tests/test_benchmarks.py
@@ -190,7 +190,7 @@ def test_lmu(Simulator, native_nengo, pytestconfig):
 @pytest.mark.parametrize(
     "net, train, minibatch_size, min, max",
     [
-        (benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.7, 0.85),
+        (benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.65, 0.8),
         (benchmarks.cconv(128, 64, nengo.LIF()), False, 64, 1.45, 1.65),
         (benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.5, 0.75),
         (benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 0.9, 1.2),
13 changes: 0 additions & 13 deletions nengo_dl/tests/test_signals.py
@@ -72,19 +72,6 @@ def test_tensor_signal_reshape():
         sig.reshape((4, 4))
 
 
-def test_tensor_signal_broadcast():
-    sig = TensorSignal([(0, 4)], object(), None, (4,), None)
-    base = np.random.randn(4)
-
-    sig_broad = sig.broadcast(2)
-    assert sig_broad.slices == ((0, 4), (0, 4))
-    assert sig_broad.shape == (2, 4)
-    assert sig_broad.key == sig.key
-    assert np.all(
-        np.reshape(base[sig_broad.tf_indices.numpy()], sig_broad.shape) == base[None, :]
-    )
-
-
 def test_tensor_signal_load_indices():
     sig = TensorSignal([(2, 6)], object(), None, (4,), None)
     assert np.all(sig.tf_indices == np.arange(*sig.slices[0]))
