Model() yield op Input internal fixes, plus add K.is_placeholder() #7046
Closed
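The PR title mentions adding K.is_placeholder() as a backend function. For context, here is a minimal, hypothetical sketch of what such a helper could look like; this is an illustrative assumption for a TensorFlow backend, not the implementation added in this PR, and the attribute/op checks are best-effort guesses:

# Hypothetical sketch only: NOT the code added in this PR.
def is_placeholder(x):
    """Return True if `x` looks like a placeholder tensor, else False."""
    try:
        # Keras-style backends may tag tensors explicitly.
        if hasattr(x, '_is_placeholder'):
            return bool(x._is_placeholder)
        # TensorFlow graph tensors expose the op that produced them.
        return x.op.type == 'Placeholder'
    except Exception:
        # Mirrors the commit note "is_placeholder raised exceptions now
        # return False": any failure means "not a placeholder".
        return False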
Changes from all commits (76 commits)
bd9d783
K.is_placeholder() added as common.py backend function
ahundt 0fcec00
tensorflow_backend.py session.run(fetches, feed_dict) params are forw…
ahundt bb3a76c
.gitignore add .vscode developer environment
ahundt 73221c9
K.is_placeholder() added as common.py backend function
ahundt 07057e1
training.py _check_array_lengths(weights=None)
ahundt 4067e5f
training.py _weighted_mask_objective() accesses weights variable corr…
ahundt bc742c9
training.py only create placeholder when tensor is a placeholder
ahundt b2e8cec
tensorflow tests pass
ahundt 745f23d
mnist_tfrecord.py added
ahundt 9a79ebc
mnist_tfrecord.py label op more clearly marked
ahundt b321de6
mnist_tfrecord.py removed debug lines
ahundt 23b5d01
Merge commit 'f4cb8900245a12d03e33a5b76d2e33aa2bdda4f0' into tfrecord
ahundt 66251aa
training.py model.fit() fetches and feed_dict docstring added
ahundt 23b0ebf
Revert "training.py model.fit() fetches and feed_dict docstring added"
ahundt 9fa2747
mnist_tfrecord.py remove extraneous comment
ahundt db16bb2
Container.__init__ new labels tensor parameter supports tensor genera…
ahundt 8ad7ec8
mnist_tfrecord.py remove unused read_and_decode function
ahundt 48f6d5a
mnist_tfrecord.py progbar
ahundt e8d6ad3
mnist_tfrecord.py network is closer to mnist_cnn.py
ahundt bc2a0f4
mnist_tfrecords.py add StagingArea
ahundt 7ecef29
is_placeholder raised exceptions now return False
ahundt a0639e1
mnist_tfrecord.py remove stray comment
ahundt b7ea15f
training.py use append for lists instead of + operator, fixes crash w…
ahundt 36a36a0
tensorflow_backend.py prevent Function kwargs from accumulating inadv…
ahundt 1a1183d
Revert "training.py use append for lists instead of + operator, fixes…
ahundt 9bf9b23
training.py fix crash when ins does not have shape attribute
ahundt 3d323ec
training.py better verification of do_validation
ahundt aa2aa7e
mnist_tfrecord.py remove validation data during model.fit() because i…
ahundt 4f38311
pep8
ahundt 773faaf
mnist_tfrecord.py one import per line
ahundt 6d4b898
tensorflow_backend.py clarify Function.__call__ according to review c…
ahundt 40ac8c7
test_training.py initial TFRecord test in progress.
ahundt 5954c61
training.py more stringent checks of inputs targets and weights
ahundt 98392b2
training.py checks tensor labels in case weights aren't supported and…
ahundt 748c243
training.py Model checks y more stringently
ahundt afcbaa7
training.py cleanup backend function calls
ahundt 42fbdbc
training.py don't slice None entries
ahundt 75edea7
topology.py always insert layer properties into feed
ahundt caad65a
training.py define ins consistently
ahundt 845dd88
tensorflow_backend.py if feed_dict entry is None add it to fetches in…
ahundt 39a7981
trainin.py ins should at least include an empty list
ahundt c0d0cda
training.py first TFRecord test passes!
ahundt 6248b54
predict_on_batch runs
ahundt b8dae76
self._prepare_sample_weights(sample_weight_mode, skip_indices)
ahundt fe2916b
mae=>mse
ahundt 4793149
_prepare_sample_weights initialization & return changes, may be buggy
ahundt af382e8
_make_function implementation changes to better support predict_funct…
ahundt 1e6b9fb
tensorflow_backend.py Function handles varying length inputs
ahundt ea9ffba
TFRecord predict() passes again
ahundt 06b71c4
tensorflow_backend Function __call__ izip_longest explicit fillvalue
ahundt 325acc1
_make_function call & list bugs fixed
ahundt 974886e
is_placeholder and standardize user data error checks fixed
ahundt 3c7dc09
topology.py add missing docstring
ahundt ffb57fd
Merge branch 'master' into tfrecord_merge
ahundt 20077cb
test_multiprocessing.py fix test which actually throws two exceptions
ahundt 9dd1ea9
test_training.py remove extraneous try/except
ahundt c37b82c
test_training.py tfrecord test of predict_on_batch, evaluate, predict
ahundt 75fbfbb
mnist_tfrecord.py _is_placeholder workaround no longer required.
ahundt 6de4e83
tensorflow_backend.py py 2+3 compatibility: from moves.six import zip…
ahundt 3d75649
mnist_tfrecord.py add parens to print
ahundt 0a1606f
test_training.py add parenthesis to print, plus an extra error check …
ahundt b7d44a5
wrappers_test.py fix tolerance in def test_TimeDistributed_learning_p…
ahundt 3be24ba
Merge branch 'tfrecord' into internal_fixes_tfrecord, including `is_p…
ahundt 39d3eba
test_training.py Input yield op tests
ahundt fb27e11
Merge branch 'master' into internal_fixes_tfrecord
ahundt a717d31
Merge branch 'master' into internal_fixes_tfrecord
ahundt ab9fe32
mnist_tfrecord.py added with yield_op workaround (#7046)
ahundt e0ffad4
test_training.py extended yield_op tests, reduce code repetition
ahundt ad9283d
generic_utils.py don't crash when dealing with batched data
ahundt 1a5f6b6
generic_utils.py don't crash when dealing with batched data
ahundt 38cfe25
Progbar() unit test
ahundt c7f6eb3
Fix mnist_tfrecord.py runtime errors
ahundt d1ea4b7
Merge branch 'master' into progbar_batches
ahundt b8019c6
reorder inputs
ahundt 6c42c6c
Merge commit 'de73eda89a916c4dd46ce74058bb2664455ed9db' into internal…
ahundt 99da67d
Merge branch 'progbar_batches' into internal_fixes_tfrecord
ahundt
mnist_tfrecord.py (new file, @@ -0,0 +1,211 @@)
'''MNIST dataset with TensorFlow TFRecords.

Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
'''
import os

import numpy as np

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops
from keras import backend as K
from keras.models import Model
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers import Input
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.callbacks import TensorBoard
from keras.objectives import categorical_crossentropy
from keras.utils import np_utils
from keras.utils.generic_utils import Progbar

from keras.datasets import mnist

if K.backend() != 'tensorflow':
    raise RuntimeError('This example can only run with the '
                       'TensorFlow backend for the time being, '
                       'because it requires TFRecords, which '
                       'are not supported on other platforms.')


def images_to_tfrecord(images, labels, filename):
    """Save image/label data into a TFRecord file."""
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    if not os.path.isfile(filename):
        num_examples = images.shape[0]

        rows = images.shape[1]
        cols = images.shape[2]
        depth = images.shape[3]

        print('Writing', filename)
        writer = tf.python_io.TFRecordWriter(filename)
        for index in range(num_examples):
            image_raw = images[index].tostring()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                'label': _int64_feature(int(labels[index])),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
        writer.close()
    else:
        print('tfrecord %s already exists' % filename)


def read_and_decode_recordinput(tf_glob, one_hot=True, classes=None, is_train=None, batch_shape=[1000, 28, 28, 1]):
    """Return tensors that read a batch of examples from TFRecord files."""
    print('Creating graph for loading %s TFRecords...' % tf_glob)
    with tf.variable_scope("TFRecords"):
        record_input = data_flow_ops.RecordInput(tf_glob, batch_size=batch_shape[0])
        records_op = record_input.get_yield_op()
        records_op = tf.split(records_op, batch_shape[0], 0)
        records_op = [tf.reshape(record, []) for record in records_op]
        progbar = Progbar(len(records_op))

        images = []
        labels = []
        for i, serialized_example in enumerate(records_op):
            progbar.update(i)
            with tf.variable_scope("parse_images", reuse=True):
                features = tf.parse_single_example(
                    serialized_example,
                    features={
                        'label': tf.FixedLenFeature([], tf.int64),
                        'image_raw': tf.FixedLenFeature([], tf.string),
                    })
                img = tf.decode_raw(features['image_raw'], tf.uint8)
                img.set_shape(batch_shape[1] * batch_shape[2])
                img = tf.reshape(img, [1] + batch_shape[1:])

                img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

                label = tf.cast(features['label'], tf.int32)
                if one_hot and classes:
                    label = tf.one_hot(label, classes)

                images.append(img)
                labels.append(label)

        images = tf.parallel_stack(images, 0)
        labels = tf.parallel_stack(labels, 0)
        images = tf.cast(images, tf.float32)

        images = tf.reshape(images, shape=batch_shape)

        # StagingArea will store tensors
        # across multiple steps to
        # speed up execution
        images_shape = images.get_shape()
        labels_shape = labels.get_shape()
        copy_stage = data_flow_ops.StagingArea(
            [tf.float32, tf.float32],
            shapes=[images_shape, labels_shape])
        copy_stage_op = copy_stage.put(
            [images, labels])
        staged_images, staged_labels = copy_stage.get()
        # NOTE: the staged tensors are not consumed below;
        # images/labels are returned directly.

        return images, labels


def save_mnist_as_tfrecord():
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = X_train[..., np.newaxis]
    X_test = X_test[..., np.newaxis]
    images_to_tfrecord(images=X_train, labels=y_train, filename='train.mnist.tfrecord')
    images_to_tfrecord(images=X_test, labels=y_test, filename='test.mnist.tfrecord')


def cnn_layers(x_train_input):
    x = Conv2D(32, (3, 3), activation='relu', padding='valid')(x_train_input)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    x_train_out = Dense(classes,
                        activation='softmax',
                        name='x_train_out')(x)
    return x_train_out


sess = tf.Session()
K.set_session(sess)

save_mnist_as_tfrecord()

batch_size = 1000
batch_shape = [batch_size, 28, 28, 1]
epochs = 6000
classes = 10

x_train_batch, y_train_batch = read_and_decode_recordinput(
    'train.mnist.tfrecord',
    one_hot=True,
    classes=classes,
    is_train=True,
    batch_shape=batch_shape)

x_test_batch, y_test_batch = read_and_decode_recordinput(
    'test.mnist.tfrecord',
    one_hot=True,
    classes=classes,
    is_train=True,
    batch_shape=batch_shape)


x_batch_shape = x_train_batch.get_shape().as_list()
y_batch_shape = y_train_batch.get_shape().as_list()

# Wrap the TFRecord yield op tensors as Keras Inputs, attach the loss
# directly to the graph, and compile with loss=None.
x_train_input = Input(tensor=x_train_batch, batch_shape=x_batch_shape)
x_train_out = cnn_layers(x_train_input)
y_train_in_out = Input(tensor=y_train_batch, batch_shape=y_batch_shape, name='y_labels')
cce = categorical_crossentropy(y_train_batch, x_train_out)
train_model = Model(inputs=[x_train_input], outputs=[x_train_out])
train_model.add_loss(cce)

train_model.compile(optimizer='rmsprop',
                    loss=None,
                    metrics=['accuracy'])
train_model.summary()

tensorboard = TensorBoard()

train_model.fit(batch_size=batch_size,
                epochs=epochs)
# disabled due to Keras bug
# callbacks=[tensorboard])
train_model.save_weights('saved_wt.h5')

K.clear_session()

# Second session, pure Keras: rebuild the same network with placeholder
# inputs and evaluate the saved weights on numpy data.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train[..., np.newaxis]
X_test = X_test[..., np.newaxis]
x_test_inp = Input(batch_shape=(None,) + (X_test.shape[1:]))
test_out = cnn_layers(x_test_inp)
test_model = Model(inputs=x_test_inp, outputs=test_out)

test_model.load_weights('saved_wt.h5')
test_model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
test_model.summary()

loss, acc = test_model.evaluate(X_test, np_utils.to_categorical(y_test, classes))
print('\nTest accuracy: {0}'.format(acc))
I don't think a lot of people knew about this trick. Really nice!
credit: @fchollet in #6928
However, it is extremely limiting: a good portion of the functionality enabled by model.compile() is skipped; see below.
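For readers who have not seen the trick being credited here, a minimal sketch of the pattern, mirroring mnist_tfrecord.py above: wrap a graph tensor with Input(tensor=...), attach the loss with add_loss(), and compile with loss=None. The tensors, sizes, and names below are illustrative stand-ins (assuming a TF 1.x graph and Keras 2 of that era), not code from this PR:

import tensorflow as tf
from keras.layers import Input, Dense
from keras.models import Model
from keras.objectives import categorical_crossentropy

# x_tensor / y_tensor stand in for any TF input pipeline ops,
# e.g. the RecordInput yield op used in mnist_tfrecord.py.
x_tensor = tf.random_uniform([128, 784])
y_tensor = tf.one_hot(tf.zeros([128], dtype=tf.int32), 10)

x_in = Input(tensor=x_tensor, batch_shape=(128, 784))
out = Dense(10, activation='softmax')(x_in)
model = Model(inputs=[x_in], outputs=[out])

# The loss is attached directly to graph tensors instead of being
# passed to compile(), so compile() receives loss=None.
model.add_loss(categorical_crossentropy(y_tensor, out))
model.compile(optimizer='rmsprop', loss=None)
model.fit(batch_size=128, epochs=1)  # no numpy x/y: data comes from the graph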
Additionally, I've discovered that this trick has serious performance problems; see #7075 (comment). It is 20x slower than mnist_tfrecord.py in #6928.