Don't include saved state in keras_model.weights #140

Merged: 3 commits, Mar 17, 2020
3 changes: 2 additions & 1 deletion .travis.yml
@@ -82,14 +82,15 @@ jobs:
distributions: "sdist "
on:
all_branches: true
tags: false
condition: $TRAVIS_BRANCH =~ ^release-candidate-*
condition: $TRAVIS_TAG = ""
- provider: pypi
user: drasmuss
password: $PYPI_TOKEN
distributions: "sdist "
on:
all_branches: true
tags: true
condition: $TRAVIS_TAG =~ ^v[0-9]*

before_install:
25 changes: 25 additions & 0 deletions CHANGES.rst
@@ -21,6 +21,31 @@ Release history
3.1.1 (unreleased)
------------------

**Added**

- Compatible with TensorFlow 2.2.0. (`#140`_)

**Changed**

- Saved simulator state will no longer be included in ``Simulator.keras_model.weights``.
This means that ``Simulator.keras_model.save/load_weights`` will not include the
saved simulator state, making it easier to reuse weights between models (as long as
the models have the same weights, they do not need to have the same state variables).
``Simulator.save/load_params(..., include_state=True)`` can be used to explicitly
save the simulator state, if desired. (`#140`_)
- Model parameters (e.g., connection weights) that are not trainable (because they've
been marked non-trainable by the user or targeted by an online learning rule) will now
be treated separately from simulator state. For example,
``Simulator.save_params(..., include_state=False)`` will still include those
parameters, and the results of any online learning will persist between calls even
with ``stateful=False``. (`#140`_)

**Deprecated**

- Renamed the ``Simulator.save/load_params`` parameter ``include_non_trainable`` to
  ``include_state``. (`#140`_)

.. _#140: https://github.com/nengo/nengo-dl/pull/140
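
To illustrate the new behaviour described above, a minimal sketch (the network and file paths are hypothetical):

```python
import nengo
import nengo_dl

with nengo.Network() as net:
    nengo.Ensemble(10, 1)

with nengo_dl.Simulator(net) as sim:
    # saves trainable *and* non-trainable model parameters, but no state
    sim.save_params("./params")

    # explicitly include the internal simulation state as well
    sim.save_params("./params_with_state", include_state=True)

with nengo_dl.Simulator(net) as sim2:
    # weights can be loaded into a fresh simulator; its state is untouched
    sim2.load_params("./params")
```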

3.1.0 (March 4, 2020)
---------------------
2 changes: 1 addition & 1 deletion docs/examples/lmu.ipynb
@@ -253,7 +253,7 @@
" else:\n",
" urlretrieve(\n",
" \"https://drive.google.com/uc?export=download&\"\n",
" \"id=1qhKmpnipk8AK2bsilPlJBFQg1t7XAEZr\",\n",
" \"id=1epcfVDdUaHkwNo1kD4kjIF7qlXgJmb2i\",\n",
" \"lmu_params.npz\")\n",
" sim.load_params(\"./lmu_params\")\n",
"\n",
2 changes: 1 addition & 1 deletion docs/examples/spa-memory.ipynb
@@ -592,7 +592,7 @@
" # download pretrained parameters\n",
" urlretrieve(\n",
" \"https://drive.google.com/uc?export=download&\"\n",
" \"id=1aspi_dayS37Rx_IvDuRrXoybdIC_-tkn\",\n",
" \"id=1Ym44i2sBLbNUiNgJaP3l1Obhf_NFenU7\",\n",
" \"mem_binding_params.npz\")"
]
},
33 changes: 33 additions & 0 deletions nengo_dl/compat.py
@@ -10,6 +10,8 @@
import nengo
from nengo._vendor.scipy.sparse import linalg_interface, linalg_onenormest
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import network

# TensorFlow compatibility

@@ -74,6 +76,37 @@ def filter(self, record):

tf.get_logger().addFilter(TFLogFilter(err_on_deprecation=False))

if LooseVersion(tf.__version__) < "2.2.0":

def global_learning_phase():
"""Returns the global (eager) Keras learning phase."""

return backend._GRAPH_LEARNING_PHASES.get(backend._DUMMY_EAGER_GRAPH, None)


else:

def global_learning_phase():
"""Returns the global (eager) Keras learning phase."""

return backend._GRAPH_LEARNING_PHASES.get(backend._DUMMY_EAGER_GRAPH.key, None)

# monkeypatch to fix bug in TF2.2, see
# https://github.com/tensorflow/tensorflow/issues/37548
old_conform = network.Network._conform_to_reference_input

def _conform_to_reference_input(self, tensor, ref_input):
keras_history = getattr(tensor, "_keras_history", None)

tensor = old_conform(self, tensor, ref_input)

if keras_history is not None:
tensor._keras_history = keras_history

return tensor

network.Network._conform_to_reference_input = _conform_to_reference_input
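
As a quick sketch of what the ``global_learning_phase`` shim above returns (assuming eager-mode TF 2.x and that no learning phase has been set yet; ``set_learning_phase`` is the standard Keras backend setter):

```python
import tensorflow as tf

from nengo_dl import compat

# if no global learning phase has been set, the shim returns None
assert compat.global_learning_phase() is None

# set the global (eager) learning phase to "training"
tf.keras.backend.set_learning_phase(1)
assert compat.global_learning_phase() == 1
```

``Simulator._build_keras`` (below) uses this value so that a globally set learning phase carries through to the simulation.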

# Nengo compatibility

# monkeypatch fix for https://github.com/nengo/nengo/pull/1587
2 changes: 1 addition & 1 deletion nengo_dl/config.py
@@ -22,7 +22,7 @@ def configure_settings(**kwargs):
Parameters
----------
trainable : bool or None
Adds a parameter to Nengo Ensembles/Connections/Networks that controls
Adds a parameter to Nengo Ensembles/Connections that controls
Collaborator: Is this a change in function that this can't be set on networks, or just a fix to the documentation?

Member Author: Just a fix for the documentation (setting it on networks was removed a while ago, I just missed this docstring).

whether or not they will be optimized by `.Simulator.fit`.
Passing ``None`` will use the default ``nengo_dl`` trainable settings,
or True/False will override the default for all objects. In either
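
As context for the conversation above, a minimal sketch of how ``trainable`` is set on individual objects once ``configure_settings`` has added the attribute (object names are illustrative):

```python
import nengo
import nengo_dl

with nengo.Network() as net:
    # add the `trainable` config attribute, keeping nengo_dl defaults
    nengo_dl.configure_settings(trainable=None)

    ens = nengo.Ensemble(10, 1)
    conn = nengo.Connection(ens, ens)

    # exclude this connection's weights from optimization
    net.config[conn].trainable = False
```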
13 changes: 6 additions & 7 deletions nengo_dl/converter.py
@@ -712,7 +712,7 @@ def get_history(tensor):
"""
Returns the Keras history (layer/node_idx/tensor_idx) that defined this tensor.

This function contains additional logic so that if ``tensor`` is the output of
This function contains additional logic so that if ``tensor`` is associated with
a Model then the history will trace into the internal layers of that Model
(rather than skipping to the input of that Model, which is the default Keras
history).
@@ -735,11 +735,9 @@
layer, node_index, tensor_index = tensor._keras_history

while isinstance(layer, tf.keras.Model):
# models have an output Identity transform that stores the history that
# "skips" the internals of the model; we want to traverse into the internals
# of the model, so we go back to the input of that identity op (which
# is the real output tensor from the model)
assert tensor.op.type == "Identity"
# models may have internal tensors representing some transforms on the
# input/output of the model, but we want to know the actual internal layer
# that generated this tensor, so we skip past these "model" tensors
tensor = tensor.op.inputs[0]
layer, node_index, tensor_index = tensor._keras_history
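
To illustrate the ``(layer, node_index, tensor_index)`` triple being unpacked here, a standalone sketch (not nengo-dl code):

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
out = tf.keras.layers.Dense(3)(inp)

# every symbolic Keras tensor records the layer call that produced it
layer, node_index, tensor_index = out._keras_history
print(layer.name)    # the Dense layer that created `out`
print(node_index)    # which call of that layer produced the tensor
print(tensor_index)  # index into that call's list of outputs
```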

@@ -877,7 +875,8 @@ def trace_tensors(self, tensors, results=None):
-------
results : list of ``tf.Tensor``
The same as the ``results`` parameter (returned so that the top-level call,
which may not have a reference to the ``results`` list can get the results).
which may not have a reference to the ``results`` list, can get the
results).
"""
# brief intro to the keras functional graph structure:
# - a node represents the application of some layer to an input tensor
69 changes: 39 additions & 30 deletions nengo_dl/simulator.py
@@ -28,7 +28,6 @@
from nengo.utils.magic import decorator
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend

from nengo_dl import callbacks, compat, config, utils
from nengo_dl.builder import NengoBuilder, NengoModel
@@ -550,9 +549,7 @@ def _build_keras(self, progress=None):
inputs,
stateful=self.stateful,
# if the global learning phase is set, use that
training=backend._GRAPH_LEARNING_PHASES.get(
backend._DUMMY_EAGER_GRAPH, None
),
training=compat.global_learning_phase(),
progress=progress,
)

@@ -614,8 +611,8 @@ def soft_reset(self, include_trainable=False, include_probes=False):
Parameters
----------
include_trainable : bool
If True, also reset any training that has been performed on
simulator parameters (e.g., connection weights).
If True, also reset any online or offline training that has been performed
on simulator parameters (e.g., connection weights).
include_probes : bool
If True, also clear probe data.
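
A short usage sketch (assuming an open ``Simulator`` named ``sim``):

```python
# reset internal state (e.g. neuron voltages), keep training and probe data
sim.soft_reset()

# additionally undo any training and clear probe data
sim.soft_reset(include_trainable=True, include_probes=True)
```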

@@ -1160,18 +1157,19 @@ def loss(self, *args, **kwargs):

@require_open
@with_self
def save_params(self, path, include_non_trainable=False):
def save_params(self, path, include_state=False, include_non_trainable=None):
"""
Save network parameters to the given ``path``.

Parameters
----------
path : str
Filepath of parameter output file.
include_non_trainable : bool
If True (default False) also save information representing
non-trainable parameters of the network (this includes the internal
simulation state).
include_state : bool
If True (default False) also save the internal simulation state.
Collaborator: I would opt for still having a line for include_non_trainable, that basically says it's deprecated and equivalent to include_state.

Collaborator: Looking at nengo.Simulator.trange and the switch in parameters there, we don't have documentation for both the old and new names, but we do have a "version changed" tag, so I've opted for that instead.


.. versionchanged:: 3.1.1
Renamed from ``include_non_trainable`` to ``include_state``.

Notes
-----
@@ -1180,30 +1178,36 @@ def save_params(self, path, include_non_trainable=False):
`.get_nengo_params`.
"""

vars = (
self.keras_model.weights
if include_non_trainable
else self.keras_model.trainable_weights
)
if include_non_trainable is not None:
warnings.warn(
"include_non_trainable is deprecated, use include_state instead",
DeprecationWarning,
)
include_state = include_non_trainable

params = list(self.keras_model.weights)
if include_state:
params.extend(self.tensor_graph.saved_state.values())

np.savez_compressed(path + ".npz", *tf.keras.backend.batch_get_value(vars))
np.savez_compressed(path + ".npz", *tf.keras.backend.batch_get_value(params))

logger.info("Model parameters saved to %s.npz", path)

@require_open
@with_self
def load_params(self, path, include_non_trainable=False):
def load_params(self, path, include_state=False, include_non_trainable=None):
"""
Load network parameters from the given ``path``.

Parameters
----------
path : str
Filepath of parameter input file.
include_non_trainable : bool
If True (default False) also load information representing
non-trainable parameters of the network (this includes the internal
simulation state).
include_state : bool
If True (default False) also load the internal simulation state.

.. versionchanged:: 3.1.1
Renamed from ``include_non_trainable`` to ``include_state``.

Notes
-----
@@ -1212,20 +1216,25 @@ def load_params(self, path, include_non_trainable=False):
`.get_nengo_params`.
"""

vars = (
self.keras_model.weights
if include_non_trainable
else self.keras_model.trainable_weights
)
if include_non_trainable is not None:
warnings.warn(
"include_non_trainable is deprecated, use include_state instead",
DeprecationWarning,
)
include_state = include_non_trainable

params = list(self.keras_model.weights)
if include_state:
params.extend(self.tensor_graph.saved_state.values())

with np.load(path + ".npz") as vals:
if len(vars) != len(vals.files):
if len(params) != len(vals.files):
raise SimulationError(
"Number of saved parameters in %s (%d) != number of variables in "
"the model (%d)" % (path, len(vals.files), len(vars))
"the model (%d)" % (path, len(vals.files), len(params))
)
tf.keras.backend.batch_set_value(
zip(vars, (vals["arr_%d" % i] for i in range(len(vals.files))))
zip(params, (vals["arr_%d" % i] for i in range(len(vals.files))))
)

logger.info("Model parameters loaded from %s.npz", path)