-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't include saved state in keras_model.weights #140
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,6 @@ | |
from nengo.utils.magic import decorator | ||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.python.keras import backend | ||
|
||
from nengo_dl import callbacks, compat, config, utils | ||
from nengo_dl.builder import NengoBuilder, NengoModel | ||
|
@@ -550,9 +549,7 @@ def _build_keras(self, progress=None): | |
inputs, | ||
stateful=self.stateful, | ||
# if the global learning phase is set, use that | ||
training=backend._GRAPH_LEARNING_PHASES.get( | ||
backend._DUMMY_EAGER_GRAPH, None | ||
), | ||
training=compat.global_learning_phase(), | ||
progress=progress, | ||
) | ||
|
||
|
@@ -614,8 +611,8 @@ def soft_reset(self, include_trainable=False, include_probes=False): | |
Parameters | ||
---------- | ||
include_trainable : bool | ||
If True, also reset any training that has been performed on | ||
simulator parameters (e.g., connection weights). | ||
If True, also reset any online or offline training that has been performed | ||
on simulator parameters (e.g., connection weights). | ||
include_probes : bool | ||
If True, also clear probe data. | ||
|
||
|
@@ -1160,18 +1157,19 @@ def loss(self, *args, **kwargs): | |
|
||
@require_open | ||
@with_self | ||
def save_params(self, path, include_non_trainable=False): | ||
def save_params(self, path, include_state=False, include_non_trainable=None): | ||
""" | ||
Save network parameters to the given ``path``. | ||
|
||
Parameters | ||
---------- | ||
path : str | ||
Filepath of parameter output file. | ||
include_non_trainable : bool | ||
If True (default False) also save information representing | ||
non-trainable parameters of the network (this includes the internal | ||
simulation state). | ||
include_state : bool | ||
If True (default False) also save the internal simulation state. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would opt for still having a line for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at |
||
|
||
.. versionchanged:: 3.1.1 | ||
Renamed from ``include_non_trainable`` to ``include_state``. | ||
|
||
Notes | ||
----- | ||
|
@@ -1180,30 +1178,36 @@ def save_params(self, path, include_non_trainable=False): | |
`.get_nengo_params`. | ||
""" | ||
|
||
vars = ( | ||
self.keras_model.weights | ||
if include_non_trainable | ||
else self.keras_model.trainable_weights | ||
) | ||
if include_non_trainable is not None: | ||
warnings.warn( | ||
"include_non_trainable is deprecated, use include_state instead", | ||
DeprecationWarning, | ||
) | ||
include_state = include_non_trainable | ||
|
||
params = list(self.keras_model.weights) | ||
if include_state: | ||
params.extend(self.tensor_graph.saved_state.values()) | ||
|
||
np.savez_compressed(path + ".npz", *tf.keras.backend.batch_get_value(vars)) | ||
np.savez_compressed(path + ".npz", *tf.keras.backend.batch_get_value(params)) | ||
|
||
logger.info("Model parameters saved to %s.npz", path) | ||
|
||
@require_open | ||
@with_self | ||
def load_params(self, path, include_non_trainable=False): | ||
def load_params(self, path, include_state=False, include_non_trainable=None): | ||
""" | ||
Load network parameters from the given ``path``. | ||
|
||
Parameters | ||
---------- | ||
path : str | ||
Filepath of parameter input file. | ||
include_non_trainable : bool | ||
If True (default False) also load information representing | ||
non-trainable parameters of the network (this includes the internal | ||
simulation state). | ||
include_state : bool | ||
If True (default False) also load the internal simulation state. | ||
hunse marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
.. versionchanged:: 3.1.1 | ||
Renamed from ``include_non_trainable`` to ``include_state``. | ||
|
||
Notes | ||
----- | ||
|
@@ -1212,20 +1216,25 @@ def load_params(self, path, include_non_trainable=False): | |
`.get_nengo_params`. | ||
""" | ||
|
||
vars = ( | ||
self.keras_model.weights | ||
if include_non_trainable | ||
else self.keras_model.trainable_weights | ||
) | ||
if include_non_trainable is not None: | ||
warnings.warn( | ||
"include_non_trainable is deprecated, use include_state instead", | ||
DeprecationWarning, | ||
) | ||
include_state = include_non_trainable | ||
|
||
params = list(self.keras_model.weights) | ||
if include_state: | ||
params.extend(self.tensor_graph.saved_state.values()) | ||
|
||
with np.load(path + ".npz") as vals: | ||
if len(vars) != len(vals.files): | ||
if len(params) != len(vals.files): | ||
raise SimulationError( | ||
"Number of saved parameters in %s (%d) != number of variables in " | ||
"the model (%d)" % (path, len(vals.files), len(vars)) | ||
"the model (%d)" % (path, len(vals.files), len(params)) | ||
) | ||
tf.keras.backend.batch_set_value( | ||
zip(vars, (vals["arr_%d" % i] for i in range(len(vals.files)))) | ||
zip(params, (vals["arr_%d" % i] for i in range(len(vals.files)))) | ||
) | ||
|
||
logger.info("Model parameters loaded from %s.npz", path) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a change in function that this can't be set on networks, or just a fix to the documentation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a fix for the documentation (setting it on networks was removed a while ago, I just missed this docstring).