Support for transform=None
drasmuss authored and tbekolay committed Mar 4, 2020
1 parent 67d6f7f commit 895409e
Showing 13 changed files with 102 additions and 40 deletions.
12 changes: 9 additions & 3 deletions CHANGES.rst
@@ -18,7 +18,7 @@ Release history
- Deprecated
- Removed
-3.0.1 (unreleased)
+3.1.0 (unreleased)
------------------

**Added**
@@ -41,6 +41,10 @@ Release history
``nengo_dl.configure_settings(dtype=...)`` config option. Note that this will
override the default precision set in ``nengo.rc``. (`#119`_)
- Minimum Numpy version is now 1.16.0 (required by TensorFlow). (`#119`_)
+- Added support for the new ``transform=None`` default in Nengo connections
+  (see `#1591`_). Note that this may change the number of trainable parameters in a
+  network as the scalar default ``transform=1`` weights on non-Ensemble connections
+  will no longer be present. (`#128`_)

**Fixed**

@@ -54,6 +58,8 @@ Release history
correctly (previously it was always set to the default value). (`#119`_)

.. _#119: https://github.com/nengo/nengo-dl/pull/119
+.. _#1591: https://github.com/nengo/nengo/pull/1591
+.. _#128: https://github.com/nengo/nengo-dl/pull/128

3.0.0 (December 17, 2019)
-------------------------
@@ -417,7 +423,7 @@ details.
``sim.loss``/``sim.train`` data argument, if no input/target data is
required.
- The ``objective`` dict in ``sim.train``/``sim.loss`` can now contain
  tuples of probes as the keys, in which case the objective function will be
called with a corresponding tuple of probe/target values as each argument.
- Added the ``sim.run_batch`` function. This exposes all the functionality
that the ``sim.run``/``sim.train``/``sim.loss`` functions are based on,
@@ -744,7 +750,7 @@ details.
- Fixed a bug where input nodes that were only read as a view were not
feedable
- Updated ``tensorflow-gpu`` installation check
- Improved numerical stability of ``LIFRate`` gradients (`#26
  <https://github.com/nengo/nengo-dl/issues/26>`_)
- Added more informative error message when data is provided with fewer items
  than ``sim.minibatch_size`` (`#30 <https://github.com/nengo/nengo-dl/issues/30>`_)
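The changelog entry above is the user-facing summary of this commit: on Nengo 3.1.0 and later, connections no longer carry an implicit scalar ``transform=1``, so a plain Node-to-Node connection contributes no trainable weights. A minimal sketch of the difference (hypothetical network, not part of this commit):

```python
import nengo

with nengo.Network() as net:
    a = nengo.Node([0.0])
    b = nengo.Node(size_in=1)

    # Nengo >= 3.1.0: transform defaults to None, so this connection has
    # no weight signal and contributes no trainable parameters in NengoDL.
    nengo.Connection(a, b, synapse=None)

    # Passing transform=1 explicitly restores the old behaviour: a scalar
    # weight that NengoDL treats as a trainable parameter.
    nengo.Connection(a, b, synapse=None, transform=1)
```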
7 changes: 6 additions & 1 deletion nengo_dl/callbacks.py
@@ -17,7 +17,7 @@
import tensorflow as tf
from tensorflow.python.eager import context

-from nengo_dl import utils
+from nengo_dl import compat, utils


class NengoSummaries(tf.keras.callbacks.Callback):
@@ -61,6 +61,11 @@ def __init__(self, log_dir, sim, objects):
param = "bias"
name = "Ensemble.neurons_%s" % obj.ensemble.label
elif isinstance(obj, nengo.Connection):
+if not compat.conn_has_weights(obj):
+    raise ValidationError(
+        "Connection '%s' does not have any weights to log" % obj,
+        "objects",
+    )
param = "weights"
name = "Connection_%s" % obj.label

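With the guard added above, ``NengoSummaries`` now rejects connections that have no weight signal rather than failing later in the build. A sketch of the error path under Nengo >= 3.1.0, mirroring the test added to ``test_simulator.py`` below:

```python
import nengo
import nengo_dl
from nengo_dl import callbacks

with nengo.Network() as net:
    a = nengo.Node([0])
    b = nengo.Ensemble(10, 1, neuron_type=nengo.LIFRate())
    c0 = nengo.Connection(a, b)  # transform=None, so no weights

with nengo_dl.Simulator(net) as sim:
    # Raises ValidationError: "... does not have any weights to log"
    callbacks.NengoSummaries(log_dir="logs/nengo", sim=sim, objects=[c0])
```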
18 changes: 18 additions & 0 deletions nengo_dl/compat.py
@@ -5,6 +5,9 @@
dependencies.
"""

+from distutils.version import LooseVersion
+
+import nengo
from nengo._vendor.scipy.sparse import linalg_interface, linalg_onenormest
import tensorflow as tf

@@ -75,3 +78,18 @@ def filter(self, record):

# monkeypatch fix for https://github.com/nengo/nengo/pull/1587
linalg_onenormest.aslinearoperator = linalg_interface.aslinearoperator

+if LooseVersion(nengo.__version__) < "3.1.0":
+    default_transform = 1
+
+    def conn_has_weights(conn):
+        """All connections have weights."""
+        return True
+
+
+else:
+    default_transform = None
+
+    def conn_has_weights(conn):
+        """Equivalent to conn.has_weights."""
+        return conn.has_weights
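The shim gives the rest of the code base a single, version-independent way to ask whether a connection has weights, and what the default transform should be. A usage sketch:

```python
import nengo
from nengo_dl import compat

with nengo.Network():
    a = nengo.Node([0.0])
    b = nengo.Node(size_in=1)
    conn = nengo.Connection(a, b)

# Always True on Nengo < 3.1.0 (every connection has weights there);
# delegates to conn.has_weights on newer versions.
print(compat.conn_has_weights(conn))

# 1 on Nengo < 3.1.0, None on newer versions -- used wherever the old
# transform=1 default needs to be reproduced.
print(compat.default_transform)
```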
2 changes: 1 addition & 1 deletion nengo_dl/losses.py
@@ -41,7 +41,7 @@ class Regularize(tf.losses.Loss):
with nengo.Network() as net:
a = nengo.Node([0])
b = nengo.Node(size_in=1)
-c = nengo.Connection(a, b)
+c = nengo.Connection(a, b, transform=1)
p = nengo.Probe(c, "weights")
...
17 changes: 13 additions & 4 deletions nengo_dl/simulator.py
@@ -30,7 +30,7 @@
import tensorflow as tf
from tensorflow.python.keras import backend

-from nengo_dl import utils, config, callbacks
+from nengo_dl import callbacks, compat, config, utils
from nengo_dl.builder import NengoBuilder, NengoModel
from nengo_dl.tensor_graph import TensorGraph

@@ -1346,7 +1346,8 @@ def get_nengo_params(self, nengo_objs, as_dict=False):
fetches = []
for obj in nengo_objs:
if isinstance(obj, Connection):
-fetches.append((obj, "weights"))
+if compat.conn_has_weights(obj):
+    fetches.append((obj, "weights"))
elif isinstance(obj, Ensemble):
if isinstance(obj.neuron_type, Direct):
# we cannot transfer direct ensemble parameters, because
@@ -1372,6 +1373,10 @@ def get_nengo_params(self, nengo_objs, as_dict=False):
idx = 0
for obj in nengo_objs:
if isinstance(obj, Connection):
+if not compat.conn_has_weights(obj):
+    params.append({"transform": None})
+    continue
+
weights = data[idx]
idx += 1
if isinstance(obj.pre_obj, Ensemble):
@@ -1381,7 +1386,7 @@ def get_nengo_params(self, nengo_objs, as_dict=False):
"function": lambda x, weights=weights: np.zeros(
weights.shape[0]
),
-"transform": 1,
+"transform": compat.default_transform,
}
)
elif isinstance(obj.transform, Convolution):
@@ -2104,7 +2109,11 @@ def __getitem__(self, obj):
)
elif isinstance(obj, Connection):
# get the live simulation values
-weights = self.get_params((obj, "weights"))[0]
+weights = (
+    self.get_params((obj, "weights"))[0]
+    if compat.conn_has_weights(obj)
+    else None
+)

# impossible to recover transform
transform = None
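With these guards in place, ``get_nengo_params`` skips the weight fetch for a weightless connection and returns ``{"transform": None}`` instead, and ``SimulationData.__getitem__`` reports ``weights=None``. A sketch, assuming Nengo >= 3.1.0:

```python
import nengo
import nengo_dl

with nengo.Network() as net:
    a = nengo.Node([0.0])
    b = nengo.Node(size_in=1)
    conn = nengo.Connection(a, b)  # transform=None: no weights to fetch

with nengo_dl.Simulator(net) as sim:
    # Single objects are returned as a single dict; a weightless
    # connection yields only the transform entry.
    params = sim.get_nengo_params(conn)
    assert params == {"transform": None}
```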
24 changes: 17 additions & 7 deletions nengo_dl/tensor_graph.py
@@ -19,7 +19,7 @@
from tensorflow.python.eager import context
from tensorflow.python.training.tracking import base as trackable

-from nengo_dl import builder, graph_optimizer, signals, utils, tensor_node, config
+from nengo_dl import (
+    builder,
+    config,
+    compat,
+    graph_optimizer,
+    tensor_node,
+    signals,
+    utils,
+)

logger = logging.getLogger(__name__)

@@ -856,10 +864,11 @@ def mark_network(parent_configs, net):
for conn in net.connections:
# note: this doesn't include probe connections, since they
# aren't added to the network
-self.model.sig[conn]["weights"].trainable = get_trainable(
-    parent_configs, conn
-)
-self.model.sig[conn]["weights"].minibatched = False
+if compat.conn_has_weights(conn):
+    self.model.sig[conn]["weights"].trainable = get_trainable(
+        parent_configs, conn
+    )
+    self.model.sig[conn]["weights"].minibatched = False

# parameters can't be modified by an online Nengo learning rule
# and offline training at the same time. (it is possible in
@@ -909,8 +918,9 @@ def mark_network(parent_configs, net):
probe_seeds = [self.model.seeds[p] for p in self.model.probes]
for obj, seed in self.model.seeds.items():
if isinstance(obj, Connection) and seed in probe_seeds:
-self.model.sig[obj]["weights"].trainable = False
-self.model.sig[obj]["weights"].minibatched = False
+if compat.conn_has_weights(obj):
+    self.model.sig[obj]["weights"].trainable = False
+    self.model.sig[obj]["weights"].minibatched = False

# time/step are not minibatched and not trainable
self.model.step.trainable = False
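The observable effect of these checks: only connections that actually have a weight signal are marked trainable, so a connection left at the new default never shows up among the Keras trainable variables. A sketch (assuming Nengo >= 3.1.0):

```python
import nengo
import nengo_dl

with nengo.Network() as net:
    a = nengo.Node([0.0])
    b = nengo.Node(size_in=1)
    nengo.Connection(a, b, synapse=None)               # no weights
    nengo.Connection(a, b, synapse=None, transform=1)  # one scalar weight
    nengo.Probe(b)

with nengo_dl.Simulator(net) as sim:
    # Only the transform=1 connection contributes a trainable variable.
    print([v.shape for v in sim.keras_model.trainable_variables])
```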
3 changes: 2 additions & 1 deletion nengo_dl/tensor_node.py
@@ -20,6 +20,7 @@
import tensorflow as tf

from nengo_dl.builder import Builder, OpBuilder, NengoBuilder
+from nengo_dl.compat import default_transform
from nengo_dl.config import configure_settings


@@ -403,7 +404,7 @@ def __init__(self, layer_func):
def __call__(
self,
input,
-transform=1,
+transform=default_transform,
shape_in=None,
synapse=None,
return_conn=False,
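``Layer.__call__`` now forwards the version-appropriate default, so connections created through the ``Layer`` API behave exactly like hand-built ones. A usage sketch (the ``Dense`` layer is illustrative):

```python
import nengo
import nengo_dl
import tensorflow as tf

with nengo.Network():
    inp = nengo.Node([0.0] * 4)
    # On Nengo >= 3.1.0 the connection created here defaults to
    # transform=None; pass transform=1 to recover the old scalar weight.
    x = nengo_dl.Layer(tf.keras.layers.Dense(units=8))(inp)
```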
2 changes: 1 addition & 1 deletion nengo_dl/tests/dummies.py
@@ -157,7 +157,7 @@ def linear_net():
with nengo.Network() as net:
a = nengo.Node([1])
b = nengo.Node(size_in=1)
-nengo.Connection(a, b, synapse=None)
+nengo.Connection(a, b, synapse=None, transform=1)
p = nengo.Probe(b)

return net, a, p
11 changes: 5 additions & 6 deletions nengo_dl/tests/test_benchmarks.py
@@ -192,8 +192,8 @@ def test_lmu(Simulator, native_nengo, pytestconfig):
[
(benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.65, 0.8),
(benchmarks.cconv(128, 64, nengo.LIF()), False, 64, 1.45, 1.65),
-(benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.5, 0.75),
-(benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 0.9, 1.2),
+(benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.6, 1.0),
+(benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 1.1, 1.4),
(
benchmarks.random_network(
64,
Expand All @@ -208,8 +208,7 @@ def test_lmu(Simulator, native_nengo, pytestconfig):
0.35,
0.55,
),
-(benchmarks.lmu(1000, 1, native_nengo=True), True, 100, 0.75, 1.05),
-# (benchmarks.spaun(1), False, None, 8.02, 9.52),
+(benchmarks.lmu(1000, 1, native_nengo=True), True, 100, 0.85, 1.15),
],
)
def test_performance(net, train, minibatch_size, min, max):
Expand All @@ -218,8 +217,8 @@ def test_performance(net, train, minibatch_size, min, max):
# GPU: GeForce GTX Titan X
# Python version: 3.6.8
# TensorFlow GPU version: 2.0.0
-# Nengo version: 3.0.0
-# NengoDL version: 3.0.0
+# Nengo version: 3.1.0
+# NengoDL version: 3.1.0

time = benchmarks.run_profile(
net,
4 changes: 2 additions & 2 deletions nengo_dl/tests/test_keras.py
@@ -236,8 +236,8 @@ def test_fit(Simulator, seed):
inp_a = nengo.Node([0])
inp_b = nengo.Node([0])
inp = nengo.Node(size_in=2)
-nengo.Connection(inp_a, inp[0])
-nengo.Connection(inp_b, inp[1])
+nengo.Connection(inp_a, inp[0], transform=1)
+nengo.Connection(inp_b, inp[1], transform=1)

ens = nengo.Ensemble(
n_hidden + 1, n_hidden, neuron_type=nengo.Sigmoid(tau_ref=1)
26 changes: 16 additions & 10 deletions nengo_dl/tests/test_simulator.py
@@ -1,6 +1,7 @@
# pylint: disable=missing-docstring

from collections import OrderedDict
+from distutils.version import LooseVersion
import logging
import os
import pickle
@@ -20,7 +21,7 @@
from tensorflow.core.util import event_pb2

from nengo_dl import Layer, TensorNode, callbacks, configure_settings, dists, utils
-from nengo_dl.compat import TFLogFilter
+from nengo_dl.compat import TFLogFilter, default_transform
from nengo_dl.simulator import SimulationData
from nengo_dl.tests import dummies

@@ -167,8 +168,8 @@ def test_train_ff(Simulator, neurons, use_loop, seed):
inp_a = nengo.Node([0])
inp_b = nengo.Node([0])
inp = nengo.Node(size_in=2)
-nengo.Connection(inp_a, inp[0])
-nengo.Connection(inp_b, inp[1])
+nengo.Connection(inp_a, inp[0], transform=1)
+nengo.Connection(inp_b, inp[1], transform=1)

ens = nengo.Ensemble(
n_hidden + 1, n_hidden, neuron_type=nengo.Sigmoid(tau_ref=1)
@@ -300,11 +301,11 @@ def test_train_objective(Simulator, unroll, seed):
inp = nengo.Node([1])

ens = nengo.Ensemble(n_hidden, 1, neuron_type=nengo.RectifiedLinear())
-nengo.Connection(inp, ens, synapse=0.01)
+nengo.Connection(inp, ens, synapse=0.01, transform=1)
p = nengo.Probe(ens)

ens2 = nengo.Ensemble(n_hidden, 1, neuron_type=nengo.RectifiedLinear())
-nengo.Connection(inp, ens2, synapse=0.01)
+nengo.Connection(inp, ens2, synapse=0.01, transform=1)
p2 = nengo.Probe(ens2)

with Simulator(
@@ -688,7 +689,8 @@ def test_tensorboard(Simulator, tmpdir):
with nengo.Network() as net:
a = nengo.Node([0])
b = nengo.Ensemble(10, 1, neuron_type=nengo.LIFRate())
-c = nengo.Connection(a, b)
+c = nengo.Connection(a, b, transform=1)
+c0 = nengo.Connection(a, b)
p = nengo.Probe(b)
p2 = nengo.Probe(c)

@@ -768,6 +770,10 @@ def test_tensorboard(Simulator, tmpdir):
with pytest.raises(ValidationError, match="Unknown summary object"):
callbacks.NengoSummaries(log_dir=log_dir + "/nengo", sim=sim, objects=[a])

+if LooseVersion(nengo.__version__) >= "3.1.0":
+    with pytest.raises(ValidationError, match="does not have any weights"):
+        callbacks.NengoSummaries(log_dir=log_dir + "/nengo", sim=sim, objects=[c0])


@pytest.mark.parametrize("mode", ("predict", "train"))
@pytest.mark.training
@@ -1259,7 +1265,7 @@ def test_get_nengo_params(Simulator, seed):

# check that single objects are returned as single dicts
params = sim.get_nengo_params(d)
-assert params["transform"] == 1
+assert params["transform"] is default_transform

fetches = [a.neurons, b, c, d, e, h]

Expand Down Expand Up @@ -1290,7 +1296,7 @@ def test_direct_grads(Simulator, mixed):
if mixed:
with net:
c = nengo.Node(size_in=1)
-nengo.Connection(a, c, synapse=None)
+nengo.Connection(a, c, synapse=None, transform=1)
p2 = nengo.Probe(c)

with Simulator(net, minibatch_size=1) as sim:
@@ -1328,7 +1334,7 @@ def test_non_differentiable(Simulator):
with nengo.Network() as net:
a = nengo.Node([0])
b = nengo.Node(lambda t, x: x, size_in=1)
-c = nengo.Connection(a, b)
+c = nengo.Connection(a, b, transform=1)
p = nengo.Probe(b)

with Simulator(net) as sim:
@@ -1429,7 +1435,7 @@ def test_inference_only(Simulator, neuron_type, seed):

a = nengo.Node([0])
b = nengo.Ensemble(10, 1, neuron_type=neuron_type)
-c = nengo.Connection(a, b, synapse=None)
+c = nengo.Connection(a, b, synapse=None, transform=1)
p = nengo.Probe(b)

with Simulator(net) as sim:
12 changes: 10 additions & 2 deletions nengo_dl/tests/test_tensor_node.py
@@ -1,5 +1,6 @@
# pylint: disable=missing-docstring

+from distutils.version import LooseVersion
from functools import partial

import nengo
@@ -261,12 +262,19 @@ def call(self, x):
assert np.allclose(sim.data[p2], 3)

# note: when inference-only=True the weights will be marked as non-trainable

+default_conn_params = 2 if LooseVersion(nengo.__version__) < "3.1.0" else 0

if sim.tensor_graph.inference_only:
-assert len(sim.keras_model.non_trainable_variables) == 10
+assert (
+    len(sim.keras_model.non_trainable_variables) == 8 + default_conn_params
+)
assert len(sim.keras_model.trainable_variables) == 0
vars = sim.keras_model.non_trainable_variables[-2:]
else:
-assert len(sim.keras_model.non_trainable_variables) == 8
+assert (
+    len(sim.keras_model.non_trainable_variables) == 6 + default_conn_params
+)
assert len(sim.keras_model.trainable_variables) == 2
vars = sim.keras_model.trainable_variables

