From 9ecd7e83ffc2ba2756c657d379487bc71e823492 Mon Sep 17 00:00:00 2001 From: Daniel Rasmussen Date: Fri, 16 Apr 2021 13:28:05 -0300 Subject: [PATCH] Add support for TensorFlow 2.5 --- CHANGES.rst | 7 ++++++- nengo_dl/compat.py | 13 +++++++++++++ nengo_dl/tensor_graph.py | 6 +++--- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 7b08a40d1..e0c8d21b3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -23,7 +23,11 @@ Release history *Compatible with Nengo 3.0.0 - 3.2.0* -*Compatible with TensorFlow 2.2.0 - 2.4.0* +*Compatible with TensorFlow 2.2.0 - 2.5.0* + +**Added** + +- Added support for TensorFlow 2.5.0. (`#212`_) **Fixed** @@ -37,6 +41,7 @@ Release history .. _#184: https://github.com/nengo/nengo-dl/pull/184 .. _#199: https://github.com/nengo/nengo-dl/pull/199 +.. _#212: https://github.com/nengo/nengo-dl/pull/212 3.4.0 (November 26, 2020) ------------------------- diff --git a/nengo_dl/compat.py b/nengo_dl/compat.py index 1ec222b9f..4a7140a6e 100644 --- a/nengo_dl/compat.py +++ b/nengo_dl/compat.py @@ -170,6 +170,19 @@ def _conform_to_reference_input(self, tensor, ref_input): network.Network._conform_to_reference_input = _conform_to_reference_input +if version.parse(tf.__version__) < version.parse("2.5.0rc0"): + + def sub_layers(layer): + """Get layers contained in ``layer``.""" + return layer._layers + + +else: + + def sub_layers(layer): + """Get layers contained in ``layer``.""" + return layer._self_tracked_trackables + # Nengo compatibility diff --git a/nengo_dl/tensor_graph.py b/nengo_dl/tensor_graph.py index 98feff96a..b698cf807 100644 --- a/nengo_dl/tensor_graph.py +++ b/nengo_dl/tensor_graph.py @@ -329,7 +329,7 @@ def unbuild(layer): layer.built = False - for sub in layer._layers: + for sub in compat.sub_layers(layer): if isinstance(sub, tf.keras.layers.Layer): unbuild(sub) @@ -343,7 +343,7 @@ def unbuild(layer): weight_gets = [] weight_sets = [] for op in layer_ops: - if op.func in self._layers: + if op.func in compat.sub_layers(self): # already built this layer continue @@ -378,7 +378,7 @@ def unbuild(layer): weight_sets.extend(op.func.weights) # add op func to _layers so that any weights are collected - self._layers.append(op.func) + compat.sub_layers(self).append(op.func) if len(weight_gets) > 0: # do all the weight getting/setting in one go, for efficiency reasons