diff --git a/examples/variational_autoencoder.py b/examples/variational_autoencoder.py index fd395b17f45..ba313f1c69b 100644 --- a/examples/variational_autoencoder.py +++ b/examples/variational_autoencoder.py @@ -45,7 +45,7 @@ def sampling(args): # Custom loss layer class CustomVariationalLayer(Layer): def __init__(self, **kwargs): - self.is_placeholder = True + self._is_placeholder = True super(CustomVariationalLayer, self).__init__(**kwargs) def vae_loss(self, x, x_decoded_mean): diff --git a/examples/variational_autoencoder_deconv.py b/examples/variational_autoencoder_deconv.py index f5aa1224d47..3612754c494 100644 --- a/examples/variational_autoencoder_deconv.py +++ b/examples/variational_autoencoder_deconv.py @@ -109,7 +109,7 @@ def sampling(args): # Custom loss layer class CustomVariationalLayer(Layer): def __init__(self, **kwargs): - self.is_placeholder = True + self._is_placeholder = True super(CustomVariationalLayer, self).__init__(**kwargs) def vae_loss(self, x, x_decoded_mean_squash): diff --git a/keras/backend/__init__.py b/keras/backend/__init__.py index 2e7208cf5af..00e6d52fbe8 100644 --- a/keras/backend/__init__.py +++ b/keras/backend/__init__.py @@ -10,6 +10,7 @@ from .common import cast_to_floatx from .common import image_data_format from .common import set_image_data_format +from .common import is_placeholder # Obtain Keras base dir path: either ~/.keras or /tmp. _keras_base_dir = os.path.expanduser('~') diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index b4161ced8e7..cfe056896a0 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -2,6 +2,7 @@ import cntk as C import numpy as np from .common import _FLOATX, _EPSILON, image_dim_ordering, image_data_format +from .common import is_placeholder from collections import defaultdict from contextlib import contextmanager import warnings @@ -256,6 +257,7 @@ def placeholder( name=name) x._keras_shape = shape x._uses_learning_phase = False + x._is_placeholder = True return x diff --git a/keras/backend/common.py b/keras/backend/common.py index 97646067c9b..a147c7f8e1d 100644 --- a/keras/backend/common.py +++ b/keras/backend/common.py @@ -108,6 +108,29 @@ def cast_to_floatx(x): return np.asarray(x, dtype=_FLOATX) +def is_placeholder(tensor): + """Returns whether a tensor is a placeholder. + + # Arguments + tensor: A tensor instance. + + # Returns + A boolean. + + # Example + ```python + >>> from keras import backend as K + >>> a = K.placeholder((2, 2), sparse=False) + >>> print(K.is_placeholder(a)) + True + ``` + """ + try: + return tensor._is_placeholder + except AttributeError: + return False + + def image_data_format(): """Returns the default image data format convention ('channels_first' or 'channels_last'). diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index ec6b637207d..4f109f67ea0 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -10,10 +10,12 @@ import inspect import numpy as np import os +from six.moves import zip_longest from .common import floatx from .common import _EPSILON from .common import image_data_format +from .common import is_placeholder # Legacy functions from .common import set_image_dim_ordering @@ -373,7 +375,7 @@ def is_keras_tensor(x): ```python >>> from keras import backend as K >>> np_var = numpy.array([1, 2]) - >>> K.is_keras_tensor(np_var) # A numpy array is not a symbolic yensor. + >>> K.is_keras_tensor(np_var) # A numpy array is not a symbolic tensor. ValueError >>> k_var = tf.placeholder('float32', shape=(1,1)) >>> K.is_keras_tensor(k_var) # A variable created directly from tensorflow/theano is not a Keras tensor. @@ -431,6 +433,7 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None): x = tf.placeholder(dtype, shape=shape, name=name) x._keras_shape = shape x._uses_learning_phase = False + x._is_placeholder = True return x @@ -2222,9 +2225,12 @@ class Function(object): outputs: Output tensors to fetch. updates: Additional update ops to be run at function call. name: a name to help users identify what this function does. + fetches: Parameters forwarded to `tf.session.run(fetches)`. + feed_dict: Parameters forwarded to `tf.session.run(feed_dict)`. """ - def __init__(self, inputs, outputs, updates=None, name=None, **session_kwargs): + def __init__(self, inputs, outputs, updates=None, name=None, + fetches=None, feed_dict=None, **session_kwargs): updates = updates or [] if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` to a TensorFlow backend function ' @@ -2235,8 +2241,11 @@ def __init__(self, inputs, outputs, updates=None, name=None, **session_kwargs): if not isinstance(updates, (list, tuple)): raise TypeError('`updates` in a TensorFlow backend function ' 'should be a list or tuple.') + # self.inputs holds tf Tensor objects self.inputs = list(inputs) self.outputs = list(outputs) + self.fetches = fetches + self.feed_dict = feed_dict with tf.control_dependencies(self.outputs): updates_ops = [] for update in updates: @@ -2251,19 +2260,46 @@ def __init__(self, inputs, outputs, updates=None, name=None, **session_kwargs): self.session_kwargs = session_kwargs def __call__(self, inputs): + """Run the TensorFlow session + + # Arguments + inputs: Data and values that will go to the feed_dict of Session.run() + if it is associated with a tensor, if it is None the tensor will + be added to the fetches parameter of Session.run(). + """ if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') - feed_dict = {} - for tensor, value in zip(self.inputs, inputs): + self.current_feed_dict = {} if self.feed_dict is None else self.feed_dict + self.feed_to_fetch_count = 0 + self.current_fetches = self.outputs + [self.updates_op] + # self.inputs contains tf tensors, inputs contains feed_dict data. + for tensor, value in zip_longest(self.inputs, inputs, fillvalue=None): + if tensor is None and value is None: + continue + elif tensor is None and value is not None: + raise ValueError('A tensor containing None ' + 'was tied to value ' + str(value) + + 'so Session.run() cannot execute, ' + 'please check your data and Model.') + if is_sparse(tensor): sparse_coo = value.tocoo() indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(sparse_coo.col, 1)), 1) value = (indices, sparse_coo.data, sparse_coo.shape) - feed_dict[tensor] = value + + if value is None and tensor is not None: + self.feed_to_fetch_count += 1 + self.current_fetches.append(tensor) + else: + self.current_feed_dict[tensor] = value + + if self.fetches is not None: + self.current_fetches += self.fetches + session = get_session() - updated = session.run(self.outputs + [self.updates_op], - feed_dict=feed_dict, + updated = session.run(fetches=self.current_fetches, + feed_dict=self.current_feed_dict, **self.session_kwargs) return updated[:len(self.outputs)] diff --git a/keras/engine/topology.py b/keras/engine/topology.py index ba5cbf707f5..61454f14330 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -97,6 +97,7 @@ class Node(object): output_shapes: list of output shape tuples. arguments: dictionary of keyword arguments that were passed to the `call` method of the layer at the call that created the node. + is_placeholder: Specifies if the Node represents a placeholder. `node_indices` and `tensor_indices` are basically fine-grained coordinates describing the origin of the `input_tensors`, verifying the following: @@ -113,7 +114,7 @@ def __init__(self, outbound_layer, input_tensors, output_tensors, input_masks, output_masks, input_shapes, output_shapes, - arguments=None): + arguments=None, is_placeholder=False): # Layer instance (NOT a list). # this is the layer that takes a list of input tensors # and turns them into a list of output tensors. @@ -157,6 +158,8 @@ def __init__(self, outbound_layer, # Optional keyword arguments to layer's `call`. self.arguments = arguments + # Indicates if the node represents a placeholder variable + self._is_placeholder = is_placeholder # Add nodes to all layers involved. for layer in inbound_layers: if layer is not None: @@ -1330,18 +1333,19 @@ def __init__(self, input_shape=None, batch_size=None, self.dtype = dtype if input_tensor is None: - self.is_placeholder = True + self._is_placeholder = True input_tensor = K.placeholder(shape=batch_input_shape, dtype=dtype, sparse=self.sparse, name=self.name) else: - self.is_placeholder = False + self._is_placeholder = False input_tensor._keras_shape = batch_input_shape # Create an input node to add to self.outbound_node # and set output_tensors' _keras_history. input_tensor._uses_learning_phase = False input_tensor._keras_history = (self, 0, 0) + input_tensor._is_placeholder = self._is_placeholder Node(self, inbound_layers=[], node_indices=[], @@ -1351,7 +1355,8 @@ def __init__(self, input_shape=None, batch_size=None, input_masks=[None], output_masks=[None], input_shapes=[batch_input_shape], - output_shapes=[batch_input_shape]) + output_shapes=[batch_input_shape], + is_placeholder=self._is_placeholder) def get_config(self): config = {'batch_input_shape': self.batch_input_shape, @@ -1491,11 +1496,14 @@ def __init__(self, inputs, outputs, name=None): self.inputs = list(inputs) # Tensor or list of tensors. else: self.inputs = [inputs] + if isinstance(outputs, (list, tuple)): self.outputs = list(outputs) else: self.outputs = [outputs] + self.target_configuration = [None] * len(self.outputs) + # Check for redundancy in inputs. if len(set(self.inputs)) != len(self.inputs): raise ValueError('The list of inputs passed to the model ' @@ -1535,6 +1543,8 @@ def __init__(self, inputs, outputs, name=None): self._output_tensor_cache = {} self._output_shape_cache = {} + self._input_placeholders = [] + self._input_yield_op_tensors = [] # User-provided arguments validation. for x in self.inputs: # Check that x is a Keras tensor. @@ -1560,6 +1570,10 @@ def __init__(self, inputs, outputs, name=None): 'instantiated via `tensor = Input(shape)`.\n' 'The tensor that caused the issue was: ' + str(x.name)) + if K.is_placeholder(layer): + self._input_placeholders.append((layer, node_index, tensor_index)) + else: + self._input_yield_op_tensors.append((layer, node_index, tensor_index)) for x in self.outputs: if not hasattr(x, '_keras_history'): cls_name = self.__class__.__name__ @@ -1621,7 +1635,7 @@ def __init__(self, inputs, outputs, name=None): i, layer.__class__.__name__)) self.input_names.append(layer.name) - if layer.is_placeholder: + if K.is_placeholder(layer): self._feed_input_names.append(layer.name) self._feed_inputs.append(layer.input) self._feed_input_shapes.append(self.inputs[i]._keras_shape) diff --git a/keras/models.py b/keras/models.py index 6228fc81b2a..3d9aea608c0 100644 --- a/keras/models.py +++ b/keras/models.py @@ -456,6 +456,7 @@ def add(self, layer): 'For multi-output layers, ' 'use the functional API.') + layer.inbound_nodes[0].output_tensors[0]._is_placeholder = True self.outputs = [layer.inbound_nodes[0].output_tensors[0]] self.inputs = topology.get_source_inputs(self.outputs[0]) @@ -479,6 +480,7 @@ def add(self, layer): 'should have a single output tensor. ' 'For multi-output layers, ' 'use the functional API.') + output_tensor._is_placeholder = True self.outputs = [output_tensor] # update self.inbound_nodes self.inbound_nodes[0].output_tensors = self.outputs