Skip to content

Commit

Permalink
Add support for TensorFlow 2.5
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Apr 16, 2021
1 parent 0e1fb4a commit 9ecd7e8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
7 changes: 6 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ Release history

*Compatible with Nengo 3.0.0 - 3.2.0*

*Compatible with TensorFlow 2.2.0 - 2.4.0*
*Compatible with TensorFlow 2.2.0 - 2.5.0*

**Added**

- Added support for TensorFlow 2.5.0. (`#212`_)

**Fixed**

Expand All @@ -37,6 +41,7 @@ Release history

.. _#184: https://github.com/nengo/nengo-dl/pull/184
.. _#199: https://github.com/nengo/nengo-dl/pull/199
.. _#212: https://github.com/nengo/nengo-dl/pull/212

3.4.0 (November 26, 2020)
-------------------------
Expand Down
13 changes: 13 additions & 0 deletions nengo_dl/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ def _conform_to_reference_input(self, tensor, ref_input):

network.Network._conform_to_reference_input = _conform_to_reference_input

if version.parse(tf.__version__) < version.parse("2.5.0rc0"):

def sub_layers(layer):
"""Get layers contained in ``layer``."""
return layer._layers


else:

def sub_layers(layer):
"""Get layers contained in ``layer``."""
return layer._self_tracked_trackables


# Nengo compatibility

Expand Down
6 changes: 3 additions & 3 deletions nengo_dl/tensor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def unbuild(layer):

layer.built = False

for sub in layer._layers:
for sub in compat.sub_layers(layer):
if isinstance(sub, tf.keras.layers.Layer):
unbuild(sub)

Expand All @@ -343,7 +343,7 @@ def unbuild(layer):
weight_gets = []
weight_sets = []
for op in layer_ops:
if op.func in self._layers:
if op.func in compat.sub_layers(self):
# already built this layer
continue

Expand Down Expand Up @@ -378,7 +378,7 @@ def unbuild(layer):
weight_sets.extend(op.func.weights)

# add op func to _layers so that any weights are collected
self._layers.append(op.func)
compat.sub_layers(self).append(op.func)

if len(weight_gets) > 0:
# do all the weight getting/setting in one go, for efficiency reasons
Expand Down

0 comments on commit 9ecd7e8

Please sign in to comment.