From fc52d9a084699a549ec1aadc7aed839bde75ef7c Mon Sep 17 00:00:00 2001 From: Taehoon Lee Date: Fri, 4 Aug 2017 00:49:11 +0900 Subject: [PATCH] Update type checking (#7507) --- keras/backend/cntk_backend.py | 2 +- keras/layers/core.py | 2 +- tests/keras/backend/backend_test.py | 2 +- tests/keras/engine/test_topology.py | 14 +++++++------- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index c2807d2d813..d416c4e18c1 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -1718,7 +1718,7 @@ def _is_input_shape_compatible(input, placeholder): def __call__(self, inputs): global _LEARNING_PHASE - assert type(inputs) in {list, tuple} + assert isinstance(inputs, (list, tuple)) feed_dict = {} for tensor, value in zip(self.placeholders, inputs): # cntk only support calculate on float, do auto cast here diff --git a/keras/layers/core.py b/keras/layers/core.py index 54c8fa9df89..a89ed39f79a 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -639,7 +639,7 @@ def compute_output_shape(self, input_shape): if not isinstance(shape, (list, tuple)): raise ValueError('`output_shape` function must return a tuple or a list of tuples.') if isinstance(shape, list): - if type(shape[0]) == int or shape[0] is None: + if isinstance(shape[0], int) or shape[0] is None: shape = tuple(shape) return shape diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index 0a22844a24f..f54d7c54328 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -26,7 +26,7 @@ def cntk_func_single_tensor(function_name, x_shape, **kwargs): def cntk_func_two_tensor(function_name, x_shape, y, **kwargs): - if type(y).__name__ == 'ndarray': + if isinstance(y, (np.generic, np.ndarray)): xc = KC.placeholder(x_shape) output_cntk = getattr(KC, function_name)(xc, KC.variable(y), **kwargs) return KC.function([xc], [output_cntk]) diff --git a/tests/keras/engine/test_topology.py b/tests/keras/engine/test_topology.py index 78d990f5a4f..6e3229d2d8e 100644 --- a/tests/keras/engine/test_topology.py +++ b/tests/keras/engine/test_topology.py @@ -151,20 +151,20 @@ def test_node_construction(): node = a_layer.inbound_nodes[a_node_index] assert node.outbound_layer == a_layer - assert type(node.inbound_layers) is list + assert isinstance(node.inbound_layers, list) assert node.inbound_layers == [] - assert type(node.input_tensors) is list + assert isinstance(node.input_tensors, list) assert node.input_tensors == [a] - assert type(node.input_masks) is list + assert isinstance(node.input_masks, list) assert node.input_masks == [None] - assert type(node.input_shapes) is list + assert isinstance(node.input_shapes, list) assert node.input_shapes == [(None, 32)] - assert type(node.output_tensors) is list + assert isinstance(node.output_tensors, list) assert node.output_tensors == [a] - assert type(node.output_shapes) is list + assert isinstance(node.output_shapes, list) assert node.output_shapes == [(None, 32)] - assert type(node.output_masks) is list + assert isinstance(node.output_masks, list) assert node.output_masks == [None] dense = Dense(16, name='dense_1')