Skip to content

Commit

Permalink
Support for transform=None
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss authored and tbekolay committed Mar 4, 2020
1 parent 08d3b95 commit 780e664
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 38 deletions.
8 changes: 7 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Release history
- Deprecated
- Removed
3.0.1 (unreleased)
3.1.0 (unreleased)
------------------

**Added**
Expand All @@ -41,6 +41,10 @@ Release history
``nengo_dl.configure_settings(dtype=...)`` config option. Note that this will
override the default precision set in ``nengo.rc``. (`#119`_)
- Minimum Numpy version is now 1.16.0 (required by TensorFlow). (`#119`_)
- Added support for the new ``transform=None`` default in Nengo connections
(see `Nengo#1591`_). Note that this may change the number of trainable
parameters in a network as the scalar default ``transform=1`` weights on
non-Ensemble connections will no longer be present. (`#128`_)

**Fixed**

Expand All @@ -55,7 +59,9 @@ Release history
- Fixed compatibility with ``progressbar2`` version 3.50.0. (`#136`_)

.. _#119: https://github.com/nengo/nengo-dl/pull/119
.. _#128: https://github.com/nengo/nengo-dl/pull/128
.. _#136: https://github.com/nengo/nengo-dl/pull/136
.. _Nengo#1591: https://github.com/nengo/nengo/pull/1591

3.0.0 (December 17, 2019)
-------------------------
Expand Down
7 changes: 6 additions & 1 deletion nengo_dl/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf
from tensorflow.python.eager import context

from nengo_dl import utils
from nengo_dl import compat, utils


class NengoSummaries(tf.keras.callbacks.Callback):
Expand Down Expand Up @@ -61,6 +61,11 @@ def __init__(self, log_dir, sim, objects):
param = "bias"
name = "Ensemble.neurons_%s" % obj.ensemble.label
elif isinstance(obj, nengo.Connection):
if not compat.conn_has_weights(obj):
raise ValidationError(
"Connection '%s' does not have any weights to log" % obj,
"objects",
)
param = "weights"
name = "Connection_%s" % obj.label

Expand Down
18 changes: 18 additions & 0 deletions nengo_dl/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
dependencies.
"""

from distutils.version import LooseVersion

import nengo
from nengo._vendor.scipy.sparse import linalg_interface, linalg_onenormest
import tensorflow as tf

Expand Down Expand Up @@ -75,3 +78,18 @@ def filter(self, record):

# monkeypatch fix for https://github.com/nengo/nengo/pull/1587
linalg_onenormest.aslinearoperator = linalg_interface.aslinearoperator

if LooseVersion(nengo.__version__) < "3.1.0":
default_transform = 1

def conn_has_weights(conn):
"""All connections have weights."""
return True


else:
default_transform = None

def conn_has_weights(conn):
"""Equivalent to conn.has_weights."""
return conn.has_weights
2 changes: 1 addition & 1 deletion nengo_dl/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Regularize(tf.losses.Loss):
with nengo.Network() as net:
a = nengo.Node([0])
b = nengo.Node(size_in=1)
c = nengo.Connection(a, b)
c = nengo.Connection(a, b, transform=1)
p = nengo.Probe(c, "weights")
...
Expand Down
17 changes: 13 additions & 4 deletions nengo_dl/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import tensorflow as tf
from tensorflow.python.keras import backend

from nengo_dl import utils, config, callbacks
from nengo_dl import callbacks, compat, config, utils
from nengo_dl.builder import NengoBuilder, NengoModel
from nengo_dl.tensor_graph import TensorGraph

Expand Down Expand Up @@ -1346,7 +1346,8 @@ def get_nengo_params(self, nengo_objs, as_dict=False):
fetches = []
for obj in nengo_objs:
if isinstance(obj, Connection):
fetches.append((obj, "weights"))
if compat.conn_has_weights(obj):
fetches.append((obj, "weights"))
elif isinstance(obj, Ensemble):
if isinstance(obj.neuron_type, Direct):
# we cannot transfer direct ensemble parameters, because
Expand All @@ -1372,6 +1373,10 @@ def get_nengo_params(self, nengo_objs, as_dict=False):
idx = 0
for obj in nengo_objs:
if isinstance(obj, Connection):
if not compat.conn_has_weights(obj):
params.append({"transform": None})
continue

weights = data[idx]
idx += 1
if isinstance(obj.pre_obj, Ensemble):
Expand All @@ -1381,7 +1386,7 @@ def get_nengo_params(self, nengo_objs, as_dict=False):
"function": lambda x, weights=weights: np.zeros(
weights.shape[0]
),
"transform": 1,
"transform": compat.default_transform,
}
)
elif isinstance(obj.transform, Convolution):
Expand Down Expand Up @@ -2104,7 +2109,11 @@ def __getitem__(self, obj):
)
elif isinstance(obj, Connection):
# get the live simulation values
weights = self.get_params((obj, "weights"))[0]
weights = (
self.get_params((obj, "weights"))[0]
if compat.conn_has_weights(obj)
else None
)

# impossible to recover transform
transform = None
Expand Down
24 changes: 17 additions & 7 deletions nengo_dl/tensor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@
from tensorflow.python.eager import context
from tensorflow.python.training.tracking import base as trackable

from nengo_dl import builder, graph_optimizer, signals, utils, tensor_node, config
from nengo_dl import (
builder,
config,
compat,
graph_optimizer,
tensor_node,
signals,
utils,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -856,10 +864,11 @@ def mark_network(parent_configs, net):
for conn in net.connections:
# note: this doesn't include probe connections, since they
# aren't added to the network
self.model.sig[conn]["weights"].trainable = get_trainable(
parent_configs, conn
)
self.model.sig[conn]["weights"].minibatched = False
if compat.conn_has_weights(conn):
self.model.sig[conn]["weights"].trainable = get_trainable(
parent_configs, conn
)
self.model.sig[conn]["weights"].minibatched = False

# parameters can't be modified by an online Nengo learning rule
# and offline training at the same time. (it is possible in
Expand Down Expand Up @@ -909,8 +918,9 @@ def mark_network(parent_configs, net):
probe_seeds = [self.model.seeds[p] for p in self.model.probes]
for obj, seed in self.model.seeds.items():
if isinstance(obj, Connection) and seed in probe_seeds:
self.model.sig[obj]["weights"].trainable = False
self.model.sig[obj]["weights"].minibatched = False
if compat.conn_has_weights(obj):
self.model.sig[obj]["weights"].trainable = False
self.model.sig[obj]["weights"].minibatched = False

# time/step are not minibatched and not trainable
self.model.step.trainable = False
Expand Down
3 changes: 2 additions & 1 deletion nengo_dl/tensor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tensorflow as tf

from nengo_dl.builder import Builder, OpBuilder, NengoBuilder
from nengo_dl.compat import default_transform
from nengo_dl.config import configure_settings


Expand Down Expand Up @@ -403,7 +404,7 @@ def __init__(self, layer_func):
def __call__(
self,
input,
transform=1,
transform=default_transform,
shape_in=None,
synapse=None,
return_conn=False,
Expand Down
2 changes: 1 addition & 1 deletion nengo_dl/tests/dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def linear_net():
with nengo.Network() as net:
a = nengo.Node([1])
b = nengo.Node(size_in=1)
nengo.Connection(a, b, synapse=None)
nengo.Connection(a, b, synapse=None, transform=1)
p = nengo.Probe(b)

return net, a, p
11 changes: 5 additions & 6 deletions nengo_dl/tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def test_lmu(Simulator, native_nengo, pytestconfig):
[
(benchmarks.cconv(128, 64, nengo.RectifiedLinear()), False, 64, 0.65, 0.8),
(benchmarks.cconv(128, 64, nengo.LIF()), False, 64, 1.45, 1.65),
(benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.5, 0.75),
(benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 0.9, 1.2),
(benchmarks.integrator(128, 32, nengo.RectifiedLinear()), True, 64, 0.6, 1.0),
(benchmarks.integrator(128, 32, nengo.LIF()), True, 64, 1.1, 1.4),
(
benchmarks.random_network(
64,
Expand All @@ -208,8 +208,7 @@ def test_lmu(Simulator, native_nengo, pytestconfig):
0.35,
0.55,
),
(benchmarks.lmu(1000, 1, native_nengo=True), True, 100, 0.75, 1.05),
# (benchmarks.spaun(1), False, None, 8.02, 9.52),
(benchmarks.lmu(1000, 1, native_nengo=True), True, 100, 0.85, 1.15),
],
)
def test_performance(net, train, minibatch_size, min, max):
Expand All @@ -218,8 +217,8 @@ def test_performance(net, train, minibatch_size, min, max):
# GPU: GeForce GTX Titan X
# Python version: 3.6.8
# TensorFlow GPU version: 2.0.0
# Nengo version: 3.0.0
# NengoDL version: 3.0.0
# Nengo version: 3.1.0
# NengoDL version: 3.1.0

time = benchmarks.run_profile(
net,
Expand Down
4 changes: 2 additions & 2 deletions nengo_dl/tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ def test_fit(Simulator, seed):
inp_a = nengo.Node([0])
inp_b = nengo.Node([0])
inp = nengo.Node(size_in=2)
nengo.Connection(inp_a, inp[0])
nengo.Connection(inp_b, inp[1])
nengo.Connection(inp_a, inp[0], transform=1)
nengo.Connection(inp_b, inp[1], transform=1)

ens = nengo.Ensemble(
n_hidden + 1, n_hidden, neuron_type=nengo.Sigmoid(tau_ref=1)
Expand Down
26 changes: 16 additions & 10 deletions nengo_dl/tests/test_simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=missing-docstring

from collections import OrderedDict
from distutils.version import LooseVersion
import logging
import os
import pickle
Expand All @@ -20,7 +21,7 @@
from tensorflow.core.util import event_pb2

from nengo_dl import Layer, TensorNode, callbacks, configure_settings, dists, utils
from nengo_dl.compat import TFLogFilter
from nengo_dl.compat import TFLogFilter, default_transform
from nengo_dl.simulator import SimulationData
from nengo_dl.tests import dummies

Expand Down Expand Up @@ -167,8 +168,8 @@ def test_train_ff(Simulator, neurons, use_loop, seed):
inp_a = nengo.Node([0])
inp_b = nengo.Node([0])
inp = nengo.Node(size_in=2)
nengo.Connection(inp_a, inp[0])
nengo.Connection(inp_b, inp[1])
nengo.Connection(inp_a, inp[0], transform=1)
nengo.Connection(inp_b, inp[1], transform=1)

ens = nengo.Ensemble(
n_hidden + 1, n_hidden, neuron_type=nengo.Sigmoid(tau_ref=1)
Expand Down Expand Up @@ -300,11 +301,11 @@ def test_train_objective(Simulator, unroll, seed):
inp = nengo.Node([1])

ens = nengo.Ensemble(n_hidden, 1, neuron_type=nengo.RectifiedLinear())
nengo.Connection(inp, ens, synapse=0.01)
nengo.Connection(inp, ens, synapse=0.01, transform=1)
p = nengo.Probe(ens)

ens2 = nengo.Ensemble(n_hidden, 1, neuron_type=nengo.RectifiedLinear())
nengo.Connection(inp, ens2, synapse=0.01)
nengo.Connection(inp, ens2, synapse=0.01, transform=1)
p2 = nengo.Probe(ens2)

with Simulator(
Expand Down Expand Up @@ -688,7 +689,8 @@ def test_tensorboard(Simulator, tmpdir):
with nengo.Network() as net:
a = nengo.Node([0])
b = nengo.Ensemble(10, 1, neuron_type=nengo.LIFRate())
c = nengo.Connection(a, b)
c = nengo.Connection(a, b, transform=1)
c0 = nengo.Connection(a, b)
p = nengo.Probe(b)
p2 = nengo.Probe(c)

Expand Down Expand Up @@ -768,6 +770,10 @@ def test_tensorboard(Simulator, tmpdir):
with pytest.raises(ValidationError, match="Unknown summary object"):
callbacks.NengoSummaries(log_dir=log_dir + "/nengo", sim=sim, objects=[a])

if LooseVersion(nengo.__version__) >= "3.1.0":
with pytest.raises(ValidationError, match="does not have any weights"):
callbacks.NengoSummaries(log_dir=log_dir + "/nengo", sim=sim, objects=[c0])


@pytest.mark.parametrize("mode", ("predict", "train"))
@pytest.mark.training
Expand Down Expand Up @@ -1259,7 +1265,7 @@ def test_get_nengo_params(Simulator, seed):

# check that single objects are returned as single dicts
params = sim.get_nengo_params(d)
assert params["transform"] == 1
assert params["transform"] is default_transform

fetches = [a.neurons, b, c, d, e, h]

Expand Down Expand Up @@ -1290,7 +1296,7 @@ def test_direct_grads(Simulator, mixed):
if mixed:
with net:
c = nengo.Node(size_in=1)
nengo.Connection(a, c, synapse=None)
nengo.Connection(a, c, synapse=None, transform=1)
p2 = nengo.Probe(c)

with Simulator(net, minibatch_size=1) as sim:
Expand Down Expand Up @@ -1328,7 +1334,7 @@ def test_non_differentiable(Simulator):
with nengo.Network() as net:
a = nengo.Node([0])
b = nengo.Node(lambda t, x: x, size_in=1)
c = nengo.Connection(a, b)
c = nengo.Connection(a, b, transform=1)
p = nengo.Probe(b)

with Simulator(net) as sim:
Expand Down Expand Up @@ -1429,7 +1435,7 @@ def test_inference_only(Simulator, neuron_type, seed):

a = nengo.Node([0])
b = nengo.Ensemble(10, 1, neuron_type=neuron_type)
c = nengo.Connection(a, b, synapse=None)
c = nengo.Connection(a, b, synapse=None, transform=1)
p = nengo.Probe(b)

with Simulator(net) as sim:
Expand Down
12 changes: 10 additions & 2 deletions nengo_dl/tests/test_tensor_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=missing-docstring

from distutils.version import LooseVersion
from functools import partial

import nengo
Expand Down Expand Up @@ -261,12 +262,19 @@ def call(self, x):
assert np.allclose(sim.data[p2], 3)

# note: when inference-only=True the weights will be marked as non-trainable

default_conn_params = 2 if LooseVersion(nengo.__version__) < "3.1.0" else 0

if sim.tensor_graph.inference_only:
assert len(sim.keras_model.non_trainable_variables) == 10
assert (
len(sim.keras_model.non_trainable_variables) == 8 + default_conn_params
)
assert len(sim.keras_model.trainable_variables) == 0
vars = sim.keras_model.non_trainable_variables[-2:]
else:
assert len(sim.keras_model.non_trainable_variables) == 8
assert (
len(sim.keras_model.non_trainable_variables) == 6 + default_conn_params
)
assert len(sim.keras_model.trainable_variables) == 2
vars = sim.keras_model.trainable_variables

Expand Down
Loading

0 comments on commit 780e664

Please sign in to comment.