Don't include saved state in keras_model.weights #140
Conversation
@@ -22,7 +22,7 @@ def configure_settings(**kwargs):
    Parameters
    ----------
    trainable : bool or None
-       Adds a parameter to Nengo Ensembles/Connections/Networks that controls
+       Adds a parameter to Nengo Ensembles/Connections that controls
Is this a change in function that this can't be set on networks, or just a fix to the documentation?
Just a fix for the documentation (setting it on networks was removed a while ago, I just missed this docstring).
        non-trainable parameters of the network (this includes the internal
        simulation state).

    include_state : bool
        If True (default False) also save the internal simulation state.
I would opt for still having a line for include_non_trainable, that basically says it's deprecated and equivalent to include_state.
Looking at nengo.Simulator.trange and the switch in parameters there, we don't have documentation for both the old and new names, but we do have a "version changed" tag, so I've opted for that instead.
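A "version changed" note in a NumPy-style docstring might look like the following sketch (the function name, signature, and version placeholder are illustrative, not the actual nengo-dl code):

```python
def save_params(path, include_state=False):
    """Save network parameters to the given path.

    Parameters
    ----------
    path : str
        Filepath of the parameter output file.
    include_state : bool
        If True (default False) also save the internal simulation state.

        .. versionchanged:: X.Y.Z
           Renamed from ``include_non_trainable`` to ``include_state``.
    """
```

The `versionchanged` directive renders a small changelog note under the parameter when Sphinx builds the docs, so the old name is still discoverable without documenting both names in full.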
nengo_dl/simulator.py
Outdated
vars = self.keras_model.weights
if include_state:
    vars.extend(self.tensor_graph.saved_state.values())
This will modify self.keras_model.weights, right? Is that a problem?
Yeah, good call; it's probably safer to do this on a copy.
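A minimal sketch of the aliasing issue and the fix, using a stand-in for the Keras model (FakeModel, saved_state, and the attribute names are illustrative, not the real nengo-dl/Keras objects):

```python
class FakeModel:
    """Stand-in for a Keras model whose ``weights`` is a plain list."""

    def __init__(self):
        self._weights = ["w0", "w1"]

    @property
    def weights(self):
        return self._weights


model = FakeModel()
saved_state = {"s0": "state0"}

# Buggy pattern: ``params`` would alias the model's own list, so
# extending it would also mutate ``model.weights``:
#     params = model.weights
#     params.extend(saved_state.values())

# Safer pattern discussed above: copy first, then extend the copy.
params = list(model.weights)
params.extend(saved_state.values())

assert model.weights == ["w0", "w1"]  # original list untouched
assert params == ["w0", "w1", "state0"]
```

(In practice Keras's `Model.weights` property may return a fresh list on each access, in which case extending it is harmless; copying makes the code robust either way.)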
nengo_dl/tensor_graph.py
Outdated
    "trainable": OrderedDict(),
    "non_trainable": OrderedDict(),
    "state": OrderedDict(),
}
I would make base_arrays an ordered dict just so it's deterministic when we call .items on it (right now, that's just in a logger message, but it's nice to have things consistent IMO).
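The suggestion amounts to something like this sketch (the base_arrays keys come from the quoted diff; everything else is illustrative):

```python
from collections import OrderedDict

# Using OrderedDict makes iteration over base_arrays.items() follow
# insertion order deterministically, so e.g. log messages built from it
# come out in a consistent order across runs.
base_arrays = OrderedDict(
    [
        ("trainable", OrderedDict()),
        ("non_trainable", OrderedDict()),
        ("state", OrderedDict()),
    ]
)

assert list(base_arrays.keys()) == ["trainable", "non_trainable", "state"]
```

On Python 3.7+ plain dicts also guarantee insertion order, but OrderedDict makes the intent explicit and holds on older interpreters as well.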
CHANGES.rst
Outdated
- Model parameters (e.g., connection weights) that are not trainable (because they've
  been marked non-trainable by the user or targeted by an online learning rule) will now
  be treated separately from simulator state. For example, resetting the simulator
  state will not reset those parameters, and the results of any online learning
Doesn't calling sim.reset reset online learning, though (as per test_online_learning_reset)? It's just sim.soft_reset with include_trainable=False that doesn't reset.
Yeah, I was intending "resetting the simulator state" to refer specifically to include_trainable=False, but that may not be clear.
sim.reset()

assert np.allclose(w0, sim.data[conn].weights)
I'm going to group the lines a bit and add some short comments just making it clear what we expect for each case.
    # with the state saved/loaded, the loaded simulator picks up where the
    # saved one left off, so its probe data matches the second half of the
    # saved run
    assert np.allclose(sim_load.data[p1], sim_save.data[p0][10:])
else:
    # without saved state the loaded simulator restarts from the initial
    # conditions, so it reproduces the first half of the saved run instead
    assert not np.allclose(sim_load.data[p1], sim_save.data[p0][10:])
    assert np.allclose(sim_load.data[p1], sim_save.data[p0][:10])
Note to self: add a comment about this (it took a minute of thinking to figure out).
Done with my review, and I've pushed a fixup commit making the changes suggested above. If it looks good to you @drasmuss, and passes CI, I'll merge. EDIT: Oh, I did have one question above about the changelog, basically wondering whether it's slightly unclear about when online learning weights are reset.
Fixups all look good to me. I added a clarification to the changelog; if that looks good to you, go ahead and merge!
This makes it easier to re-use saved weights between models, as models with different saved state variables can still use the same saved Keras parameters.
Fixes #133