Model() yield op Input internal fixes, plus add K.is_placeholder() #7046

Closed
wants to merge 76 commits into from
76 commits
bd9d783
K.is_placeholder() added as common.py backend function
ahundt Jun 12, 2017
0fcec00
tensorflow_backend.py session.run(fetches, feed_dict) params are forw…
ahundt Jun 12, 2017
bb3a76c
.gitignore add .vscode developer environment
ahundt Jun 12, 2017
73221c9
K.is_placeholder() added as common.py backend function
ahundt Jun 12, 2017
07057e1
training.py _check_array_lengths(weights=None)
ahundt Jun 12, 2017
4067e5f
training.py _weighted_mask_objective() accesses weights variable corr…
ahundt Jun 12, 2017
bc742c9
training.py only create placeholder when tensor is a placeholder
ahundt Jun 12, 2017
b2e8cec
tensorflow tests pass
ahundt Jun 13, 2017
745f23d
mnist_tfrecord.py added
ahundt Jun 13, 2017
9a79ebc
mnist_tfrecord.py label op more clearly marked
ahundt Jun 13, 2017
b321de6
mnist_tfrecord.py removed debug lines
ahundt Jun 13, 2017
23b5d01
Merge commit 'f4cb8900245a12d03e33a5b76d2e33aa2bdda4f0' into tfrecord
ahundt Jun 13, 2017
66251aa
training.py model.fit() fetches and feed_dict docstring added
ahundt Jun 13, 2017
23b0ebf
Revert "training.py model.fit() fetches and feed_dict docstring added"
ahundt Jun 13, 2017
9fa2747
mnist_tfrecord.py remove extraneous comment
ahundt Jun 14, 2017
db16bb2
Container.__init__ new labels tensor parameter supports tensor genera…
ahundt Jun 15, 2017
8ad7ec8
mnist_tfrecord.py remove unused read_and_decode function
ahundt Jun 15, 2017
48f6d5a
mnist_tfrecord.py progbar
ahundt Jun 15, 2017
e8d6ad3
mnist_tfrecord.py network is closer to mnist_cnn.py
ahundt Jun 15, 2017
bc2a0f4
mnist_tfrecords.py add StagingArea
ahundt Jun 15, 2017
7ecef29
is_placeholder raised exceptions now return False
ahundt Jun 15, 2017
a0639e1
mnist_tfrecord.py remove stray comment
ahundt Jun 15, 2017
b7ea15f
training.py use append for lists instead of + operator, fixes crash w…
ahundt Jun 16, 2017
36a36a0
tensorflow_backend.py prevent Function kwargs from accumulating inadv…
ahundt Jun 16, 2017
1a1183d
Revert "training.py use append for lists instead of + operator, fixes…
ahundt Jun 16, 2017
9bf9b23
training.py fix crash when ins does not have shape attribute
ahundt Jun 16, 2017
3d323ec
training.py better verification of do_validation
ahundt Jun 16, 2017
aa2aa7e
mnist_tfrecord.py remove validation data during model.fit() because i…
ahundt Jun 16, 2017
4f38311
pep8
ahundt Jun 16, 2017
773faaf
mnist_tfrecord.py one import per line
ahundt Jun 16, 2017
6d4b898
tensorflow_backend.py clarify Function.__call__ according to review c…
ahundt Jun 16, 2017
40ac8c7
test_training.py initial TFRecord test in progress.
ahundt Jun 16, 2017
5954c61
training.py more stringent checks of inputs targets and weights
ahundt Jun 16, 2017
98392b2
training.py checks tensor labels in case weights aren't supported and…
ahundt Jun 16, 2017
748c243
training.py Model checks y more stringently
ahundt Jun 16, 2017
afcbaa7
training.py cleanup backend function calls
ahundt Jun 16, 2017
42fbdbc
training.py don't slice None entries
ahundt Jun 16, 2017
75edea7
topology.py always insert layer properties into feed
ahundt Jun 16, 2017
caad65a
training.py define ins consistently
ahundt Jun 16, 2017
845dd88
tensorflow_backend.py if feed_dict entry is None add it to fetches in…
ahundt Jun 16, 2017
39a7981
training.py ins should at least include an empty list
ahundt Jun 17, 2017
c0d0cda
training.py first TFRecord test passes!
ahundt Jun 17, 2017
6248b54
predict_on_batch runs
ahundt Jun 17, 2017
b8dae76
self._prepare_sample_weights(sample_weight_mode, skip_indices)
ahundt Jun 17, 2017
fe2916b
mae=>mse
ahundt Jun 17, 2017
4793149
_prepare_sample_weights initialization & return changes, may be buggy
ahundt Jun 17, 2017
af382e8
_make_function implementation changes to better support predict_funct…
ahundt Jun 17, 2017
1e6b9fb
tensorflow_backend.py Function handles varying length inputs
ahundt Jun 17, 2017
ea9ffba
TFRecord predict() passes again
ahundt Jun 17, 2017
06b71c4
tensorflow_backend Function __call__ izip_longest explicit fillvalue
ahundt Jun 18, 2017
325acc1
_make_function call & list bugs fixed
ahundt Jun 18, 2017
974886e
is_placeholder and standardize user data error checks fixed
ahundt Jun 18, 2017
3c7dc09
topology.py add missing docstring
ahundt Jun 18, 2017
ffb57fd
Merge branch 'master' into tfrecord_merge
ahundt Jun 18, 2017
20077cb
test_multiprocessing.py fix test which actually throws two exceptions
ahundt Jun 18, 2017
9dd1ea9
test_training.py remove extraneous try/except
ahundt Jun 18, 2017
c37b82c
test_training.py tfrecord test of predict_on_batch, evaluate, predict
ahundt Jun 18, 2017
75fbfbb
mnist_tfrecord.py _is_placeholder workaround no longer required.
ahundt Jun 18, 2017
6de4e83
tensorflow_backend.py py 2+3 compatibility: from moves.six import zip…
ahundt Jun 18, 2017
3d75649
mnist_tfrecord.py add parens to print
ahundt Jun 19, 2017
0a1606f
test_training.py add parenthesis to print, plus an extra error check …
ahundt Jun 19, 2017
b7d44a5
wrappers_test.py fix tolerance in def test_TimeDistributed_learning_p…
ahundt Jun 19, 2017
3be24ba
Merge branch 'tfrecord' into internal_fixes_tfrecord, including `is_p…
ahundt Jun 19, 2017
39d3eba
test_training.py Input yield op tests
ahundt Jun 20, 2017
fb27e11
Merge branch 'master' into internal_fixes_tfrecord
ahundt Jun 20, 2017
a717d31
Merge branch 'master' into internal_fixes_tfrecord
ahundt Jun 20, 2017
ab9fe32
mnist_tfrecord.py added with yield_op workaround (#7046)
ahundt Jun 20, 2017
e0ffad4
test_training.py extended yield_op tests, reduce code repetition
ahundt Jun 21, 2017
ad9283d
generic_utils.py don't crash when dealing with batched data
ahundt Jun 21, 2017
1a5f6b6
generic_utils.py don't crash when dealing with batched data
ahundt Jun 21, 2017
38cfe25
Progbar() unit test
ahundt Jun 21, 2017
c7f6eb3
Fix mnist_tfrecord.py runtime errors
ahundt Jun 21, 2017
d1ea4b7
Merge branch 'master' into progbar_batches
ahundt Jun 21, 2017
b8019c6
reorder inputs
ahundt Jun 21, 2017
6c42c6c
Merge commit 'de73eda89a916c4dd46ce74058bb2664455ed9db' into internal…
ahundt Jun 21, 2017
99da67d
Merge branch 'progbar_batches' into internal_fixes_tfrecord
ahundt Jun 22, 2017
211 changes: 211 additions & 0 deletions examples/mnist_tfrecord.py
@@ -0,0 +1,211 @@
'''MNIST dataset with TensorFlow TFRecords.

Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
'''
import os
import copy
import time

import numpy as np

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops
from keras import backend as K
from keras.models import Model
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.callbacks import EarlyStopping
from keras.callbacks import TensorBoard
from keras.objectives import categorical_crossentropy
from keras.utils import np_utils
from keras.utils.generic_utils import Progbar
from keras import callbacks as cbks
from keras import optimizers, objectives
from keras import metrics as metrics_module

from keras.datasets import mnist

if K.backend() != 'tensorflow':
    raise RuntimeError('This example can only run with the '
                       'TensorFlow backend for the time being, '
                       'because it requires TFRecords, which '
                       'are not supported on other platforms.')


def images_to_tfrecord(images, labels, filename):
    """ Save data into TFRecord """
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    if not os.path.isfile(filename):
        num_examples = images.shape[0]

        rows = images.shape[1]
        cols = images.shape[2]
        depth = images.shape[3]

        print('Writing', filename)
        writer = tf.python_io.TFRecordWriter(filename)
        for index in range(num_examples):
            image_raw = images[index].tostring()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                'label': _int64_feature(int(labels[index])),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
        writer.close()
    else:
        print('tfrecord %s already exists' % filename)


def read_and_decode_recordinput(tf_glob, one_hot=True, classes=None, is_train=None,
                                batch_shape=[1000, 28, 28, 1]):
    """ Return tensor to read from TFRecord """
    print('Creating graph for loading %s TFRecords...' % tf_glob)
    with tf.variable_scope("TFRecords"):
        record_input = data_flow_ops.RecordInput(tf_glob, batch_size=batch_shape[0])
        records_op = record_input.get_yield_op()
        records_op = tf.split(records_op, batch_shape[0], 0)
        records_op = [tf.reshape(record, []) for record in records_op]
        progbar = Progbar(len(records_op))

        images = []
        labels = []
        for i, serialized_example in enumerate(records_op):
            progbar.update(i)
            with tf.variable_scope("parse_images", reuse=True):
                features = tf.parse_single_example(
                    serialized_example,
                    features={
                        'label': tf.FixedLenFeature([], tf.int64),
                        'image_raw': tf.FixedLenFeature([], tf.string),
                    })
                img = tf.decode_raw(features['image_raw'], tf.uint8)
                img.set_shape(batch_shape[1] * batch_shape[2])
                img = tf.reshape(img, [1] + batch_shape[1:])

                img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

                label = tf.cast(features['label'], tf.int32)
                if one_hot and classes:
                    label = tf.one_hot(label, classes)

                images.append(img)
                labels.append(label)

        images = tf.parallel_stack(images, 0)
        labels = tf.parallel_stack(labels, 0)
        images = tf.cast(images, tf.float32)

        images = tf.reshape(images, shape=batch_shape)

        # StagingArea will store tensors
        # across multiple steps to
        # speed up execution
        images_shape = images.get_shape()
        labels_shape = labels.get_shape()
        copy_stage = data_flow_ops.StagingArea(
            [tf.float32, tf.float32],
            shapes=[images_shape, labels_shape])
        copy_stage_op = copy_stage.put(
            [images, labels])
        staged_images, staged_labels = copy_stage.get()

        return images, labels


def save_mnist_as_tfrecord():
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = X_train[..., np.newaxis]
    X_test = X_test[..., np.newaxis]
    images_to_tfrecord(images=X_train, labels=y_train, filename='train.mnist.tfrecord')
    images_to_tfrecord(images=X_test, labels=y_test, filename='test.mnist.tfrecord')


def cnn_layers(x_train_input):
    x = Conv2D(32, (3, 3), activation='relu', padding='valid')(x_train_input)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    x_train_out = Dense(classes,
                        activation='softmax',
                        name='x_train_out')(x)
    return x_train_out


sess = tf.Session()
K.set_session(sess)

save_mnist_as_tfrecord()

batch_size = 1000
batch_shape = [batch_size, 28, 28, 1]
epochs = 6000
classes = 10

x_train_batch, y_train_batch = read_and_decode_recordinput(
    'train.mnist.tfrecord',
    one_hot=True,
    classes=classes,
    is_train=True,
    batch_shape=batch_shape)

x_test_batch, y_test_batch = read_and_decode_recordinput(
    'test.mnist.tfrecord',
    one_hot=True,
    classes=classes,
    is_train=True,
    batch_shape=batch_shape)


x_batch_shape = x_train_batch.get_shape().as_list()
y_batch_shape = y_train_batch.get_shape().as_list()

x_train_input = Input(tensor=x_train_batch, batch_shape=x_batch_shape)
x_train_out = cnn_layers(x_train_input)
y_train_in_out = Input(tensor=y_train_batch, batch_shape=y_batch_shape, name='y_labels')
cce = categorical_crossentropy(y_train_batch, x_train_out)
train_model = Model(inputs=[x_train_input], outputs=[x_train_out])
train_model.add_loss(cce)
Contributor

I don't think a lot of people knew about this trick. Really nice!

Contributor Author @ahundt, Jun 21, 2017

credit: @fchollet in #6928

However, it is extremely limiting: a good portion of the functionality enabled by model.compile() is skipped (see below).

Contributor Author

Additionally, I've discovered that this trick has serious performance problems; see #7075 (comment). It is 20x slower than mnist_tfrecord.py in #6928.


train_model.compile(optimizer='rmsprop',
                    loss=None,
                    metrics=['accuracy'])
train_model.summary()

tensorboard = TensorBoard()

train_model.fit(batch_size=batch_size,
                epochs=epochs)
# disabled due to Keras bug
# callbacks=[tensorboard])
train_model.save_weights('saved_wt.h5')

K.clear_session()

# Second Session, pure Keras
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train[..., np.newaxis]
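X_test = X_test[..., np.newaxis]
# Scale pixels the same way as the TFRecord reader used for training.
X_test = X_test.astype('float32') * (1. / 255) - 0.5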
x_test_inp = Input(batch_shape=(None,) + (X_test.shape[1:]))
test_out = cnn_layers(x_test_inp)
test_model = Model(inputs=x_test_inp, outputs=test_out)

test_model.load_weights('saved_wt.h5')
test_model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
test_model.summary()

loss, acc = test_model.evaluate(X_test, np_utils.to_categorical(y_test), classes)
print('\nTest accuracy: {0}'.format(acc))
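
Note on preprocessing: the TFRecord reader above scales every pixel with `x * (1. / 255) - 0.5` and optionally one-hot encodes the label, whereas arrays returned by `mnist.load_data()` are raw `uint8`. Inputs evaluated against the saved weights would presumably need the same scaling. The following is a minimal pure-Python sketch of the two transforms. The helper names `scale_pixel` and `one_hot` are hypothetical and are not part of this PR.

```python
def scale_pixel(value):
    """Map a uint8 pixel in [0, 255] to roughly [-0.5, 0.5]."""
    return value * (1. / 255) - 0.5


def one_hot(label, num_classes):
    """Return a list with 1.0 at index `label` and 0.0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Darkest and brightest pixels land at the two ends of the range.
print(scale_pixel(0), scale_pixel(255))
# Label 3 out of 10 classes.
print(one_hot(3, 10))
```

With NumPy arrays the same scaling can be written as `x.astype('float32') * (1. / 255) - 0.5`, which matches the expression used in the reader.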
2 changes: 1 addition & 1 deletion examples/variational_autoencoder.py
@@ -45,7 +45,7 @@ def sampling(args):
 # Custom loss layer
 class CustomVariationalLayer(Layer):
     def __init__(self, **kwargs):
-        self.is_placeholder = True
+        self._is_placeholder = True
         super(CustomVariationalLayer, self).__init__(**kwargs)

     def vae_loss(self, x, x_decoded_mean):
2 changes: 1 addition & 1 deletion examples/variational_autoencoder_deconv.py
@@ -109,7 +109,7 @@ def sampling(args):
 # Custom loss layer
 class CustomVariationalLayer(Layer):
     def __init__(self, **kwargs):
-        self.is_placeholder = True
+        self._is_placeholder = True
         super(CustomVariationalLayer, self).__init__(**kwargs)

     def vae_loss(self, x, x_decoded_mean_squash):
1 change: 1 addition & 0 deletions keras/backend/__init__.py
@@ -10,6 +10,7 @@
 from .common import cast_to_floatx
 from .common import image_data_format
 from .common import set_image_data_format
+from .common import is_placeholder

 # Obtain Keras base dir path: either ~/.keras or /tmp.
 _keras_base_dir = os.path.expanduser('~')
2 changes: 2 additions & 0 deletions keras/backend/cntk_backend.py
@@ -2,6 +2,7 @@
 import cntk as C
 import numpy as np
 from .common import _FLOATX, _EPSILON, image_dim_ordering, image_data_format
+from .common import is_placeholder
 from collections import defaultdict
 from contextlib import contextmanager
 import warnings
@@ -256,6 +257,7 @@ def placeholder(
         name=name)
     x._keras_shape = shape
     x._uses_learning_phase = False
+    x._is_placeholder = True
     return x


23 changes: 23 additions & 0 deletions keras/backend/common.py
@@ -108,6 +108,29 @@ def cast_to_floatx(x):
     return np.asarray(x, dtype=_FLOATX)


+def is_placeholder(tensor):
+    """Returns whether a tensor is a placeholder.
+
+    # Arguments
+        tensor: A tensor instance.
+
+    # Returns
+        A boolean.
+
+    # Example
+    ```python
+        >>> from keras import backend as K
+        >>> a = K.placeholder((2, 2), sparse=False)
+        >>> print(K.is_placeholder(a))
+        True
+    ```
+    """
+    try:
+        return tensor._is_placeholder
+    except AttributeError:
+        return False
+
+
 def image_data_format():
     """Returns the default image data format convention ('channels_first' or 'channels_last').

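
Note on `is_placeholder()`: the check in this hunk depends only on a private `_is_placeholder` attribute that the backend `placeholder()` functions set, so any object lacking the attribute is reported as not a placeholder. The sketch below reproduces that behaviour without a backend. `FakeTensor` and `is_placeholder_getattr` are hypothetical names used for illustration only, and the `getattr` variant is an alternative spelling, not what the PR implements.

```python
class FakeTensor(object):
    """Hypothetical stand-in for a backend tensor; not a Keras class."""


def is_placeholder(tensor):
    """Same attribute check as the function added in this hunk."""
    try:
        return tensor._is_placeholder
    except AttributeError:
        return False


def is_placeholder_getattr(tensor):
    """Alternative spelling with getattr and a default value."""
    return getattr(tensor, '_is_placeholder', False)


# A tensor tagged the way the backend placeholder() functions tag theirs.
placeholder = FakeTensor()
placeholder._is_placeholder = True

# A tensor that was never tagged, e.g. one produced by a reader or yield op.
yield_op = FakeTensor()

print(is_placeholder(placeholder), is_placeholder(yield_op))
print(is_placeholder_getattr(placeholder), is_placeholder_getattr(yield_op))
```

Because the check never raises, it can be applied to arbitrary objects such as NumPy arrays or `None` when deciding whether a tensor still needs data fed to it.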
50 changes: 43 additions & 7 deletions keras/backend/tensorflow_backend.py
@@ -10,11 +10,13 @@

 import numpy as np
 import os
+from six.moves import zip_longest

 from .common import floatx
 from .common import _EPSILON
 from .common import image_data_format
 from ..utils.generic_utils import has_arg
+from .common import is_placeholder

 # Legacy functions
 from .common import set_image_dim_ordering
@@ -374,7 +376,7 @@ def is_keras_tensor(x):
     ```python
         >>> from keras import backend as K
         >>> np_var = numpy.array([1, 2])
-        >>> K.is_keras_tensor(np_var) # A numpy array is not a symbolic yensor.
+        >>> K.is_keras_tensor(np_var) # A numpy array is not a symbolic tensor.
         ValueError
         >>> k_var = tf.placeholder('float32', shape=(1,1))
         >>> K.is_keras_tensor(k_var) # A variable created directly from tensorflow/theano is not a Keras tensor.
@@ -432,6 +434,7 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
         x = tf.placeholder(dtype, shape=shape, name=name)
     x._keras_shape = shape
     x._uses_learning_phase = False
+    x._is_placeholder = True
     return x


@@ -2223,9 +2226,12 @@ class Function(object):
         outputs: Output tensors to fetch.
         updates: Additional update ops to be run at function call.
         name: a name to help users identify what this function does.
+        fetches: Parameters forwarded to `tf.Session.run(fetches)`.
+        feed_dict: Parameters forwarded to `tf.Session.run(feed_dict)`.
     """

-    def __init__(self, inputs, outputs, updates=None, name=None, **session_kwargs):
+    def __init__(self, inputs, outputs, updates=None, name=None,
+                 fetches=None, feed_dict=None, **session_kwargs):
         updates = updates or []
         if not isinstance(inputs, (list, tuple)):
             raise TypeError('`inputs` to a TensorFlow backend function '
@@ -2236,8 +2242,11 @@ def __init__(self, inputs, outputs, updates=None, name=None, **session_kwargs):
         if not isinstance(updates, (list, tuple)):
             raise TypeError('`updates` in a TensorFlow backend function '
                             'should be a list or tuple.')
+        # self.inputs holds tf Tensor objects
         self.inputs = list(inputs)
         self.outputs = list(outputs)
+        self.fetches = fetches
+        self.feed_dict = feed_dict
         with tf.control_dependencies(self.outputs):
             updates_ops = []
             for update in updates:
@@ -2252,19 +2261,46 @@ def __init__(self, inputs, outputs, updates=None, name=None, **session_kwargs):
         self.session_kwargs = session_kwargs

     def __call__(self, inputs):
+        """Run the TensorFlow session
+
+        # Arguments
+            inputs: Data and values that will go to the feed_dict of Session.run()
+                if it is associated with a tensor, if it is None the tensor will
+                be added to the fetches parameter of Session.run().
+        """
         if not isinstance(inputs, (list, tuple)):
             raise TypeError('`inputs` should be a list or tuple.')
-        feed_dict = {}
-        for tensor, value in zip(self.inputs, inputs):
+        self.current_feed_dict = {} if self.feed_dict is None else self.feed_dict
+        self.feed_to_fetch_count = 0
+        self.current_fetches = self.outputs + [self.updates_op]
+        # self.inputs contains tf tensors, inputs contains feed_dict data.
+        for tensor, value in zip_longest(self.inputs, inputs, fillvalue=None):
+            if tensor is None and value is None:
+                continue
+            elif tensor is None and value is not None:
+                raise ValueError('A tensor containing None '
+                                 'was tied to value ' + str(value) +
+                                 ' so Session.run() cannot execute, '
+                                 'please check your data and Model.')
+
             if is_sparse(tensor):
                 sparse_coo = value.tocoo()
                 indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                                           np.expand_dims(sparse_coo.col, 1)), 1)
                 value = (indices, sparse_coo.data, sparse_coo.shape)
-            feed_dict[tensor] = value
+
+            if value is None and tensor is not None:
+                self.feed_to_fetch_count += 1
+                self.current_fetches.append(tensor)
+            else:
+                self.current_feed_dict[tensor] = value
+
+        if self.fetches is not None:
+            self.current_fetches += self.fetches
+
         session = get_session()
-        updated = session.run(self.outputs + [self.updates_op],
-                              feed_dict=feed_dict,
+        updated = session.run(fetches=self.current_fetches,
+                              feed_dict=self.current_feed_dict,
                               **self.session_kwargs)
         return updated[:len(self.outputs)]

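
Note on the `Function.__call__` change: the loop pairs each input tensor with a value via `zip_longest`, feeds the pair when a value is present, and otherwise adds the tensor to the fetches so that it is evaluated from the graph. The sketch below shows only that routing, with strings standing in for tensors. `route_inputs` is a hypothetical helper, sparse handling is omitted, and, unlike the diff, the sketch copies the base feed dict so that repeated calls do not accumulate entries.

```python
from itertools import zip_longest  # the PR imports zip_longest from six.moves


def route_inputs(tensors, values, base_fetches=(), base_feed=None):
    """Split (tensor, value) pairs into extra fetches and a feed dict."""
    feed = dict(base_feed or {})  # copy: an assumption, the PR reuses the dict
    fetches = list(base_fetches)
    for tensor, value in zip_longest(tensors, values, fillvalue=None):
        if tensor is None and value is None:
            continue
        if tensor is None:
            # A value with no tensor to receive it cannot be fed.
            raise ValueError('value %r has no tensor to feed' % (value,))
        if value is None:
            # No data supplied: fetch the tensor instead of feeding it.
            fetches.append(tensor)
        else:
            feed[tensor] = value
    return fetches, feed


# 'x' receives data; 'y_labels' has none, so it is fetched from the graph.
fetches, feed = route_inputs(['x', 'y_labels'], [[1.0, 2.0]],
                             base_fetches=['loss'])
print(fetches)
print(feed)
```

In the diff the returned list is still sliced with `updated[:len(self.outputs)]`, so tensors moved into the fetches are evaluated but not returned to the caller.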