Skip to content

Commit

Permalink
Update type checking (keras-team#7507)
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee authored and fchollet committed Aug 3, 2017
1 parent c83ca43 commit fc52d9a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,7 +1718,7 @@ def _is_input_shape_compatible(input, placeholder):

def __call__(self, inputs):
global _LEARNING_PHASE
assert type(inputs) in {list, tuple}
assert isinstance(inputs, (list, tuple))
feed_dict = {}
for tensor, value in zip(self.placeholders, inputs):
# cntk only support calculate on float, do auto cast here
Expand Down
2 changes: 1 addition & 1 deletion keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def compute_output_shape(self, input_shape):
if not isinstance(shape, (list, tuple)):
raise ValueError('`output_shape` function must return a tuple or a list of tuples.')
if isinstance(shape, list):
if type(shape[0]) == int or shape[0] is None:
if isinstance(shape[0], int) or shape[0] is None:
shape = tuple(shape)
return shape

Expand Down
2 changes: 1 addition & 1 deletion tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def cntk_func_single_tensor(function_name, x_shape, **kwargs):


def cntk_func_two_tensor(function_name, x_shape, y, **kwargs):
if type(y).__name__ == 'ndarray':
if isinstance(y, (np.generic, np.ndarray)):
xc = KC.placeholder(x_shape)
output_cntk = getattr(KC, function_name)(xc, KC.variable(y), **kwargs)
return KC.function([xc], [output_cntk])
Expand Down
14 changes: 7 additions & 7 deletions tests/keras/engine/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,20 @@ def test_node_construction():
node = a_layer.inbound_nodes[a_node_index]
assert node.outbound_layer == a_layer

assert type(node.inbound_layers) is list
assert isinstance(node.inbound_layers, list)
assert node.inbound_layers == []
assert type(node.input_tensors) is list
assert isinstance(node.input_tensors, list)
assert node.input_tensors == [a]
assert type(node.input_masks) is list
assert isinstance(node.input_masks, list)
assert node.input_masks == [None]
assert type(node.input_shapes) is list
assert isinstance(node.input_shapes, list)
assert node.input_shapes == [(None, 32)]

assert type(node.output_tensors) is list
assert isinstance(node.output_tensors, list)
assert node.output_tensors == [a]
assert type(node.output_shapes) is list
assert isinstance(node.output_shapes, list)
assert node.output_shapes == [(None, 32)]
assert type(node.output_masks) is list
assert isinstance(node.output_masks, list)
assert node.output_masks == [None]

dense = Dense(16, name='dense_1')
Expand Down

0 comments on commit fc52d9a

Please sign in to comment.