Skip to content

Commit

Permalink
[FRONTEND][Keras] Add support for tf.Keras networks in Relay Keras fr…
Browse files Browse the repository at this point in the history
…ontend (apache#4630)

* Make Relay Keras frontend support networks created using
   Tensorflow (1.13) Keras implementation (tf.Keras)
 * Modify Keras frontend tests to run from a class rather than a
   function based script
 * Adjust Keras frontend tests to run with both 'Keras' and 'tf.Keras'
 * Change "TestKeras.test_forward_merge" to validate instances by
   class name rather than instance type
  • Loading branch information
leandron authored and zhiics committed Mar 2, 2020
1 parent 01335d5 commit ffdfb3e
Show file tree
Hide file tree
Showing 2 changed files with 348 additions and 296 deletions.
57 changes: 40 additions & 17 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,9 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
'Concatenate' : _convert_concat,
'BatchNormalization' : _convert_batchnorm,

# Specific tf.Keras terminology for batch normalization
'BatchNormalizationV1' : _convert_batchnorm,

'Add' : _convert_merge,
'Subtract' : _convert_merge,
'Multiply' : _convert_merge,
Expand Down Expand Up @@ -742,7 +745,7 @@ def from_keras(model, shape=None):
Parameters
----------
model : keras.engine.training.Model
model : keras.engine.training.Model or tensorflow.keras.models.Model
The keras model to be converted.
shape: dict of str to int list/tuple
Expand All @@ -756,25 +759,42 @@ def from_keras(model, shape=None):
params : dict of str to tvm.NDArray
The parameter dict to be used by Relay.
"""
try:
import keras
except ImportError:
raise ImportError('Keras must be installed')
assert isinstance(model, keras.engine.training.Model)
if keras.backend.backend() != 'tensorflow':
raise ValueError("Keras frontend currently supports tensorflow backend only.")
if keras.backend.image_data_format() != 'channels_last':
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model)
def _check_model_is_tf_keras():
return type(model).__module__.startswith("tensorflow.python.keras")

def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))

is_tf_keras = _check_model_is_tf_keras()

if not is_tf_keras:
# Importing from Keras
try:
import keras
except ImportError:
raise ImportError("Keras must be installed")
if keras.backend.backend() != 'tensorflow':
raise ValueError("Keras frontend currently supports tensorflow backend only.")
if keras.backend.image_data_format() != 'channels_last':
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
expected_model_class = keras.engine.training.Model
input_layer_class = keras.engine.InputLayer
else:
# Importing from Tensorflow Keras (tf.keras)
try:
from tensorflow import keras as tf_keras
except ImportError:
raise ImportError("Tensorflow must be installed")
expected_model_class = tf_keras.models.Model
input_layer_class = tf_keras.layers.InputLayer

assert isinstance(model, expected_model_class)

etab = ExprTable()
for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer):
if isinstance(keras_layer, input_layer_class):
_convert_input_layer(keras_layer)
else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
Expand All @@ -784,10 +804,13 @@ def _convert_input_layer(keras_layer):
raise TypeError("Unknown layer type or unsupported Keras version : {}"
.format(keras_layer))
for node_idx, node in enumerate(inbound_nodes):
# If some nodes in imported model is not relevant to the current model,
# skip such layers. model._network_nodes contains keys of all nodes relevant
# to the current model.
if not model._node_key(keras_layer, node_idx) in model._network_nodes:
# If some nodes in imported model are not relevant to the current model,
# skip such layers.
# - In Keras, model._network_nodes contains keys of all nodes relevant to the
# current model;
# - In tf.Keras, this is already done as part of tensorflow.keras.network.get_config
if not is_tf_keras and \
not model._node_key(keras_layer, node_idx) in model._network_nodes:
continue
inexpr = []
# Since Keras allows creating multiple layers from the same name instance,
Expand All @@ -797,7 +820,7 @@ def _convert_input_layer(keras_layer):
# they are named uniquely to input_1, input_2, input_3... by default.
zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, keras.engine.InputLayer):
if isinstance(inbound_layer, input_layer_class):
expr_name = inbound_layer.name
_convert_input_layer(inbound_layer)
else:
Expand Down
Loading

0 comments on commit ffdfb3e

Please sign in to comment.