Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input Tensors: High Performance Large Datasets via TFRecords #6928

Closed
wants to merge 204 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
204 commits
Select commit Hold shift + click to select a range
bd9d783
K.is_placeholder() added as common.py backend function
ahundt Jun 12, 2017
0fcec00
tensorflow_backend.py session.run(fetches, feed_dict) params are forw…
ahundt Jun 12, 2017
bb3a76c
.gitignore add .vscode developer environment
ahundt Jun 12, 2017
73221c9
K.is_placeholder() added as common.py backend function
ahundt Jun 12, 2017
07057e1
training.py _check_array_lengths(weights=None)
ahundt Jun 12, 2017
4067e5f
training.py _weighted_mask_objective() accesses weights variable corr…
ahundt Jun 12, 2017
bc742c9
training.py only create placeholder when tensor is a placeholder
ahundt Jun 12, 2017
b2e8cec
tensorflow tests pass
ahundt Jun 13, 2017
745f23d
mnist_tfrecord.py added
ahundt Jun 13, 2017
9a79ebc
mnist_tfrecord.py label op more clearly marked
ahundt Jun 13, 2017
b321de6
mnist_tfrecord.py removed debug lines
ahundt Jun 13, 2017
23b5d01
Merge commit 'f4cb8900245a12d03e33a5b76d2e33aa2bdda4f0' into tfrecord
ahundt Jun 13, 2017
66251aa
training.py model.fit() fetches and feed_dict docstring added
ahundt Jun 13, 2017
23b0ebf
Revert "training.py model.fit() fetches and feed_dict docstring added"
ahundt Jun 13, 2017
9fa2747
mnist_tfrecord.py remove extraneous comment
ahundt Jun 14, 2017
db16bb2
Container.__init__ new labels tensor parameter supports tensor genera…
ahundt Jun 15, 2017
8ad7ec8
mnist_tfrecord.py remove unused read_and_decode function
ahundt Jun 15, 2017
48f6d5a
mnist_tfrecord.py progbar
ahundt Jun 15, 2017
e8d6ad3
mnist_tfrecord.py network is closer to mnist_cnn.py
ahundt Jun 15, 2017
bc2a0f4
mnist_tfrecords.py add StagingArea
ahundt Jun 15, 2017
7ecef29
is_placeholder raised exceptions now return False
ahundt Jun 15, 2017
a0639e1
mnist_tfrecord.py remove stray comment
ahundt Jun 15, 2017
b7ea15f
training.py use append for lists instead of + operator, fixes crash w…
ahundt Jun 16, 2017
36a36a0
tensorflow_backend.py prevent Function kwargs from accumulating inadv…
ahundt Jun 16, 2017
1a1183d
Revert "training.py use append for lists instead of + operator, fixes…
ahundt Jun 16, 2017
9bf9b23
training.py fix crash when ins does not have shape attribute
ahundt Jun 16, 2017
3d323ec
training.py better verification of do_validation
ahundt Jun 16, 2017
aa2aa7e
mnist_tfrecord.py remove validation data during model.fit() because i…
ahundt Jun 16, 2017
4f38311
pep8
ahundt Jun 16, 2017
773faaf
mnist_tfrecord.py one import per line
ahundt Jun 16, 2017
6d4b898
tensorflow_backend.py clarify Function.__call__ according to review c…
ahundt Jun 16, 2017
40ac8c7
test_training.py initial TFRecord test in progress.
ahundt Jun 16, 2017
5954c61
training.py more stringent checks of inputs targets and weights
ahundt Jun 16, 2017
98392b2
training.py checks tensor labels in case weights aren't supported and…
ahundt Jun 16, 2017
748c243
training.py Model checks y more stringently
ahundt Jun 16, 2017
afcbaa7
training.py cleanup backend function calls
ahundt Jun 16, 2017
42fbdbc
training.py don't slice None entries
ahundt Jun 16, 2017
75edea7
topology.py always insert layer properties into feed
ahundt Jun 16, 2017
caad65a
training.py define ins consistently
ahundt Jun 16, 2017
845dd88
tensorflow_backend.py if feed_dict entry is None add it to fetches in…
ahundt Jun 16, 2017
39a7981
trainin.py ins should at least include an empty list
ahundt Jun 17, 2017
c0d0cda
training.py first TFRecord test passes!
ahundt Jun 17, 2017
6248b54
predict_on_batch runs
ahundt Jun 17, 2017
b8dae76
self._prepare_sample_weights(sample_weight_mode, skip_indices)
ahundt Jun 17, 2017
fe2916b
mae=>mse
ahundt Jun 17, 2017
4793149
_prepare_sample_weights initialization & return changes, may be buggy
ahundt Jun 17, 2017
af382e8
_make_function implementation changes to better support predict_funct…
ahundt Jun 17, 2017
1e6b9fb
tensorflow_backend.py Function handles varying length inputs
ahundt Jun 17, 2017
ea9ffba
TFRecord predict() passes again
ahundt Jun 17, 2017
06b71c4
tensorflow_backend Function __call__ izip_longest explicit fillvalue
ahundt Jun 18, 2017
325acc1
_make_function call & list bugs fixed
ahundt Jun 18, 2017
974886e
is_placeholder and standardize user data error checks fixed
ahundt Jun 18, 2017
3c7dc09
topology.py add missing docstring
ahundt Jun 18, 2017
ffb57fd
Merge branch 'master' into tfrecord_merge
ahundt Jun 18, 2017
20077cb
test_multiprocessing.py fix test which actually throws two exceptions
ahundt Jun 18, 2017
9dd1ea9
test_training.py remove extraneous try/except
ahundt Jun 18, 2017
c37b82c
test_training.py tfrecord test of predict_on_batch, evaluate, predict
ahundt Jun 18, 2017
75fbfbb
mnist_tfrecord.py _is_placeholder workaround no longer required.
ahundt Jun 18, 2017
6de4e83
tensorflow_backend.py py 2+3 compatibility: from moves.six import zip…
ahundt Jun 18, 2017
3d75649
mnist_tfrecord.py add parens to print
ahundt Jun 19, 2017
0a1606f
test_training.py add parenthesis to print, plus an extra error check …
ahundt Jun 19, 2017
b7d44a5
wrappers_test.py fix tolerance in def test_TimeDistributed_learning_p…
ahundt Jun 19, 2017
3be24ba
Merge branch 'tfrecord' into internal_fixes_tfrecord, including `is_p…
ahundt Jun 19, 2017
39d3eba
test_training.py Input yield op tests
ahundt Jun 20, 2017
cd1c1ba
Merge branch 'master' into tfrecord
ahundt Jun 20, 2017
fb27e11
Merge branch 'master' into internal_fixes_tfrecord
ahundt Jun 20, 2017
a717d31
Merge branch 'master' into internal_fixes_tfrecord
ahundt Jun 20, 2017
02f5a0c
Merge branch 'master' into tfrecord
ahundt Jun 20, 2017
ab9fe32
mnist_tfrecord.py added with yield_op workaround (#7046)
ahundt Jun 20, 2017
9fae1b2
test_training.py extended yield_op tests, reduce code repetition
ahundt Jun 21, 2017
e0ffad4
test_training.py extended yield_op tests, reduce code repetition
ahundt Jun 21, 2017
ad9283d
generic_utils.py don't crash when dealing with batched data
ahundt Jun 21, 2017
1a5f6b6
generic_utils.py don't crash when dealing with batched data
ahundt Jun 21, 2017
d4f5b71
tensorflow_backend Function() corner case handling and error checking
ahundt Jun 21, 2017
de54a1c
tensorflow_backend.py remove self.feed_to_fetch_count
ahundt Jun 21, 2017
54241d5
is_placeholder() implemented
ahundt Jun 21, 2017
800e5a1
import zip_longest from six.moves for #7064
ahundt Jun 21, 2017
b39c97d
training.py _standardize_input_data() cleanup protects against corner…
ahundt Jun 21, 2017
606f30d
test is_placeholder()
ahundt Jun 21, 2017
67aa1cd
is_placeholder theano_backend.py
ahundt Jun 21, 2017
0387a8c
test_training.py add test_standardize_input_data()
ahundt Jun 21, 2017
722e131
extra is_placeholder() tests
ahundt Jun 21, 2017
38cfe25
Progbar() unit test
ahundt Jun 21, 2017
c67cfad
mnist_tfrecord.py added (#7061, #7072, #6928, #7046)
ahundt Jun 21, 2017
04cdf5f
test keras.backend.Function() input handling backend test
ahundt Jun 21, 2017
fa9ce56
backend_test.py reverse change.
ahundt Jun 21, 2017
8f5092d
test_backend.py K.backend.function() error handling test
ahundt Jun 21, 2017
c7f6eb3
Fix mnist_tfrecord.py runtime errors
ahundt Jun 21, 2017
3a53528
typo fix
ahundt Jun 21, 2017
b1839e2
Fix mnist_tfrecord.py runtime errors
ahundt Jun 21, 2017
5a8fb70
mnist_tfrecord.py fix import bug
ahundt Jun 21, 2017
b8085f9
mnist_tfrecord.py pep8
ahundt Jun 21, 2017
3da96e5
mnist_tfrecord.py add parallelism option
ahundt Jun 21, 2017
919fdff
mnist_tfrecord.py add parallelism option
ahundt Jun 21, 2017
0cafa1e
Merge branch 'master' into tfrecord
ahundt Jun 21, 2017
d1ea4b7
Merge branch 'master' into progbar_batches
ahundt Jun 21, 2017
b8019c6
reorder inputs
ahundt Jun 21, 2017
6c776c7
Merge branch 'master' into is_placeholder
ahundt Jun 21, 2017
f56f1fa
backend_test.py rename f to g to fix test error
ahundt Jun 21, 2017
e489734
Merge branch 'tfrecord' of github.com:ahundt/keras into tfrecord
ahundt Jun 21, 2017
837ceeb
mnist_tfrecord remove whitespace
ahundt Jun 21, 2017
750a211
mnist_tfrecord.py indentation fix
ahundt Jun 21, 2017
6c42c6c
Merge commit 'de73eda89a916c4dd46ce74058bb2664455ed9db' into internal…
ahundt Jun 21, 2017
99da67d
Merge branch 'progbar_batches' into internal_fixes_tfrecord
ahundt Jun 22, 2017
0a8c13b
backend_test.py cleanup new TF backend test
ahundt Jun 22, 2017
916eaaf
Merge branch 'master' into standardize_input_data
ahundt Jun 22, 2017
450eda3
Merge branch 'master' into standardize_input_data
ahundt Jun 22, 2017
2a1cd77
lower batch size and epochs
ahundt Jun 23, 2017
a63dc5a
Merge branch 'master' into tfrecord
ahundt Jun 23, 2017
0c43bba
loss defaults to None in compile()
ahundt Jun 21, 2017
40d16c3
mnist_tfrecord.py
ahundt Jun 23, 2017
63b21aa
try a tricky way of getting closer to the "ideal" api
ahundt Jun 24, 2017
b044773
name typo
ahundt Jun 24, 2017
7ced84d
set y tensor values to non-tensor values
ahundt Jun 24, 2017
ea3e926
handle non list case
ahundt Jun 24, 2017
956885f
handle some error cases
ahundt Jun 24, 2017
320cbe9
param fix
ahundt Jun 24, 2017
d4ea224
fix some arg stuff
ahundt Jun 24, 2017
9736151
remove extraneous lines
ahundt Jun 24, 2017
4eb95e3
Improve numpy list handling
ahundt Jun 24, 2017
f3e25de
model.fit(steps_per_epoch) added
ahundt Jun 24, 2017
13867a1
Merge branch 'fit_steps_per_epoch' into ideal_tfrecord_fit_steps_per_…
ahundt Jun 24, 2017
05b3ecf
mnist_tfrecord.py support steps_per_epoch
ahundt Jun 24, 2017
8bac525
Merge branch 'ideal_tfrecord_fit_steps_per_epoch' into tfrecord
ahundt Jun 24, 2017
e8cdd22
fix test errors
ahundt Jun 24, 2017
7d8f0b7
added _check_num_samples for cases when batch_size does not apply
ahundt Jun 26, 2017
9498dba
rename tensors param to better skip_standardizing param name.
ahundt Jun 26, 2017
76b1a2b
mnist_tfrecord.py pep8
ahundt Jun 26, 2017
563c5ff
fix test failures
ahundt Jun 26, 2017
0de3fc6
remove inaccurate warning
ahundt Jun 26, 2017
fb67e06
is_keras_tensor() document new argument expect_other_types
ahundt Jun 27, 2017
2e214c4
Merge commit '59cd1c3994153a66084b00fadcafad2af5a15dd7' into tfrecord
ahundt Jul 3, 2017
b4fb086
remove extraneous imports of is_placeholder
ahundt Jul 25, 2017
d4f0b1c
_standardize_input_data rename skip_standardizing to input_tensors ba…
ahundt Jul 25, 2017
b8ac699
Merge branch 'standardize_input_data' into tfrecord
ahundt Jul 26, 2017
1efd0e5
Merge commit '84ceb94055b831c486dbf4955fdf1ba0f63320d1' into tfrecord
ahundt Jul 26, 2017
4d6e3d4
K.is_placeholder() becomes simple member variable is_placeholder with…
ahundt Aug 1, 2017
5b85132
_standardize_input_data add missing space to exception error string
ahundt Aug 1, 2017
d60f1e2
_standardize_input_data accepts feed inputs
ahundt Aug 1, 2017
c608ea8
Merge branch 'is_placeholder' into is_placeholder_and_standardize_inp…
ahundt Aug 1, 2017
478a680
_standardize_input_data() now only standardizes data that is converti…
ahundt Aug 1, 2017
29efd9f
Merge commit '01e2148732e4083b4850345e5ce4dd499cb5999e' into standard…
ahundt Aug 1, 2017
65a5bce
fix _feed_inputs name
ahundt Aug 1, 2017
620874f
training.py _standardize_input_data() accounts for input tensors (#7067)
ahundt Aug 1, 2017
886fa04
Merge commit '01e2148732e4083b4850345e5ce4dd499cb5999e' into fit_step…
ahundt Aug 1, 2017
e9b63cc
improved fit(steps_per_epoch) with separate internal epoch loop in _f…
ahundt Aug 1, 2017
5bcdc36
fit(steps_per_epoch) initial validation support
ahundt Aug 1, 2017
9f343a1
training.py pep8
ahundt Aug 1, 2017
ef47131
Merge commit '01e2148732e4083b4850345e5ce4dd499cb5999e' into tfrecord
ahundt Aug 1, 2017
a42a113
Merge branch 'fit_steps_per_epoch' into tfrecord
ahundt Aug 4, 2017
d3a917e
mnist_tfrecord.py simpler version
ahundt Aug 4, 2017
f6ae5ad
mnist_tfrecord.py many fewer epochs
ahundt Aug 4, 2017
171e932
mnist_tfrecord.py test at bottom runs
ahundt Aug 4, 2017
3d4e980
Merge branch 'mnist_tfrecord' into fit_steps_per_epoch
ahundt Aug 5, 2017
65246ad
Merge branch 'progbar_batches' into fit_steps_per_epoch
ahundt Aug 5, 2017
be423e3
mnist_tfrecord.py fix key missing lines
ahundt Aug 5, 2017
0b83662
mnist_tfrecord.py add coordinator
ahundt Aug 5, 2017
4cd6bb3
removed extraneous line
ahundt Aug 5, 2017
4a15c44
replace is_placeholder with is_keras_placeholder to resolve conflict …
ahundt Aug 5, 2017
181308c
Merge branch 'fit_steps_per_epoch' into tfrecord
ahundt Aug 5, 2017
8ad3844
(#6928) replace _is_placeholder with is_keras_placeholder to match (#…
ahundt Aug 5, 2017
49a105c
Merge branch 'standardize_input_data' into tfrecord, steps_per_epoch …
ahundt Aug 5, 2017
a73343b
Merge branch "master" into branch "tfrecords"
ahundt Aug 5, 2017
6b2a975
(#6928) remove is_placeholder to match (#7113)
ahundt Aug 5, 2017
c7cc493
merge branch "master" into branch "fit_steps_per_epoch"
ahundt Aug 5, 2017
cd7df51
Merge branch 'fit_steps_per_epoch' into tfrecord
ahundt Aug 5, 2017
48799e8
mnist_tfrecord.py and training.py clean up based on review (#7113)
ahundt Aug 7, 2017
2bf9695
mnist_tfrecord.py extended description
ahundt Aug 8, 2017
f1af666
training.py fix test error
ahundt Aug 8, 2017
351e56b
mnist_tfrecord.py and training.py fixed review comments, docs, and er…
ahundt Aug 8, 2017
d755b80
training.py fix unit test error for steps_per_epoch
ahundt Aug 8, 2017
60a7193
Merge branch 'fit_steps_per_epoch' into tfrecord
ahundt Aug 8, 2017
d88987d
add Model.evaluate(steps) and correct docs according to review. (#7113)
ahundt Aug 8, 2017
3e13d6c
Model.evaluate(steps) and Model.predict(steps)
ahundt Aug 8, 2017
f96f30a
merge branch fit_steps_per_epoch into branch tfrecord
ahundt Aug 8, 2017
56736f3
fix docstring comments from review (#7113)
ahundt Aug 8, 2017
e2e142f
adapting mnist_tfrecord.py for fit(y_train_in)
ahundt Aug 8, 2017
fdd9556
topology.py mnist_tfrecord.py bugfix
ahundt Aug 8, 2017
65d9103
add mnist_tfrecord_recordinput.py
ahundt Aug 8, 2017
6bef8c6
mnist_tfrecord_recordinput.py test with tensors
ahundt Aug 9, 2017
745e02a
training.py create _check_for_recompile() with calls in fit() and eva…
ahundt Aug 9, 2017
1ec1ce8
cntk_backend.py fix is_keras_tensor()
ahundt Aug 9, 2017
3142a25
mnist_tfrecord_recordinput.py use default keras session
ahundt Aug 9, 2017
f9b3874
test_training.py minor variable rename for clarity
ahundt Aug 9, 2017
38225bb
pep8
ahundt Aug 9, 2017
bbf6c3e
Model.train_on_batch() and Model.test_on_batch() support input tensors
ahundt Aug 9, 2017
91bee67
mnist_tfrecord*.py extended documentation
ahundt Aug 9, 2017
c0e6e3b
mnist_tfrecord*.py extended documentation
ahundt Aug 9, 2017
1831064
mnist_tfrecord_recordinput.py evaluate() bugfix
ahundt Aug 9, 2017
86257ee
training.py documentation of methods internal to Model
ahundt Aug 9, 2017
c62af05
training.py fix _step_loop
ahundt Aug 9, 2017
f4e6e2a
mnist_tfrecord_recordinput.py fix accuracy calculation
ahundt Aug 9, 2017
d89f467
pep8
ahundt Aug 9, 2017
b2635a1
training.py _step_loop() fix test error, progbar supports more data l…
ahundt Aug 9, 2017
a8d743d
tensorflow_backend.py remove unused self.feed_to_fectch_count
ahundt Aug 9, 2017
1384d25
Merge branch 'tensorflow_backend_Function' into tfrecord
ahundt Aug 9, 2017
7e0ac46
mnist_tfrecord*.py clean up extraneous imports
ahundt Aug 9, 2017
4ec2e63
mnist_tfrecord_recordinput.py Input(...) -> layers.Input(...)
ahundt Aug 9, 2017
41fdd5a
training.py improve docstrings and error case
ahundt Aug 11, 2017
c536879
Merge commit '41fdd5aa5fe080bf5fa12638b9abc9f50d835a39' into tfrecord
ahundt Aug 11, 2017
353bb12
Allows custom tensors for target based on (#5927) implemented in (#69…
ahundt Aug 11, 2017
6aba048
Merge branch 'master' into tfrecord
ahundt Aug 11, 2017
8fd6fe0
test_training.py major simplification of tensor test
ahundt Aug 11, 2017
0eb2a8a
pep8
ahundt Aug 11, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions examples/mnist_tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from keras import layers
from keras import objectives
from keras.utils import np_utils
from keras import objectives

from tensorflow.contrib.learn.python.learn.datasets import mnist

Expand Down Expand Up @@ -118,25 +117,29 @@ def cnn_layers(x_train_input):
x_batch_shape = x_train_batch.get_shape().as_list()
y_batch_shape = y_train_batch.get_shape().as_list()

# The input tensors are provided directly into the Model network.
# The network is fixed once it is initialized, so it must be
# reconstructed every time a new input data source is needed.
# This is substantially different from typical
# Keras numpy array inputs, and is more like TensorFlow.
x_train_input = layers.Input(tensor=x_train_batch, batch_shape=x_batch_shape)
x_train_out = cnn_layers(x_train_input)
y_train_input = layers.Input(tensor=y_train_batch, batch_shape=y_batch_shape)
train_model = Model(inputs=x_train_input, outputs=x_train_out)

cce = objectives.categorical_crossentropy(y_train_batch, x_train_out)
train_model.add_loss(cce)

# Do not pass the loss directly to model.compile()
# because it is not yet supported for Input Tensors.
train_model.compile(optimizer='rmsprop',
loss=None,
loss='categorical_crossentropy',
metrics=['accuracy'])
train_model.summary()

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
train_model.fit(epochs=epochs,
steps_per_epoch=steps_per_epoch)

# The input data was created with x_train_input,
# so only the label data needs to be provided.
train_model.fit(y=y_train_input,
epochs=epochs,
steps_per_epoch=steps_per_epoch)
train_model.save_weights('saved_wt.h5')

coord.request_stop()
Expand Down
227 changes: 227 additions & 0 deletions examples/mnist_tfrecord_recordinput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
'''Optimized MNIST dataset with TFRecords, the standard TensorFlow data format.

TFRecord is a data format supported throughout TensorFlow.
For a straightforward usage example see mnist_tfrecord.py.
This example demonstrates how to write and read TFRecord data using
Input Tensors in a way that is better optimized for high performance
on large datasets with Keras.

Gets to 99.25% test accuracy after 12 epochs
(there is still some margin for parameter tuning).
'''
import os
import copy
import time

import numpy as np

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops
from keras import backend as K
from keras.models import Model
from keras import layers
from keras.utils.generic_utils import Progbar

from keras.datasets import mnist

if K.backend() != 'tensorflow':
raise RuntimeError('This example can only run with the '
'TensorFlow backend for the time being, '
'because it requires TFRecords, which '
'are not supported on other platforms.')


def images_to_tfrecord(images, labels, filename):
"""Write images and their labels to a TFRecord file.

# Arguments
images: A numpy array or list of image data with
shape (images, rows, cols, depth).
labels: A numpy array of labels, with one for each image.
filename: Path and name for the output dataset TFRecord file.
An example is `'path/to/mnist_train.tfrecord'`
"""
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

""" Save data into TFRecord """
if not os.path.isfile(filename):
num_examples = images.shape[0]

rows = images.shape[1]
cols = images.shape[2]
depth = images.shape[3]

print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw = images[index].tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()
else:
print('tfrecord %s already exists' % filename)


def read_and_decode_recordinput(tf_glob, one_hot=True,
classes=None,
batch_shape=None, parallelism=1):
"""Get a TF Tensor that supplies shuffled batches of images.

# Arguments
tf_glob: File path for selecting one or more tfrecord files.
Examples are `'path/to/data.tfrecord'` and `'path/to/*.tfrecord'`.
one_hot: Use one hot encoding for labels, also known as categorical.
batch_shape: Specify the desired image batch shape, where the first
entry is the batch size. MNIST might be (128, 28, 28, 1).
parallelism: The number of threads to use for loading new data.
A reasonable value is the number of logical cores on your processor.
"""
if batch_shape is None:
batch_shape = [1000, 28, 28, 1]
print 'Creating graph for loading %s TFRecords...' % tf_glob
with tf.variable_scope("TFRecords"):
record_input = data_flow_ops.RecordInput(
tf_glob, batch_size=batch_shape[0], parallelism=parallelism)
records_op = record_input.get_yield_op()
records_op = tf.split(records_op, batch_shape[0], 0)
records_op = [tf.reshape(record, []) for record in records_op]
progbar = Progbar(len(records_op))

images = []
labels = []
for i, serialized_example in enumerate(records_op):
progbar.update(i)
with tf.variable_scope("parse_images", reuse=True):
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['image_raw'], tf.uint8)
img.set_shape(batch_shape[1] * batch_shape[2])
img = tf.reshape(img, [1] + batch_shape[1:])

img = tf.cast(img, tf.float32) * (1. / 255) - 0.5

label = tf.cast(features['label'], tf.int32)
if one_hot and classes:
label = tf.one_hot(label, classes)

images.append(img)
labels.append(label)

images = tf.parallel_stack(images, 0)
labels = tf.parallel_stack(labels, 0)
images = tf.cast(images, tf.float32)
images = tf.reshape(images, shape=batch_shape)

return images, labels


def save_mnist_as_tfrecord():
"""Save one tfrecord file for each of the train and test mnist datasets.
"""
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
images_to_tfrecord(images=x_train, labels=y_train, filename='train.mnist.tfrecord')
images_to_tfrecord(images=x_test, labels=y_test, filename='test.mnist.tfrecord')


def cnn_layers(x_train_input):
"""Create the CNN layers for use with either numpy inputs or tensor inputs.
"""
x = layers.Conv2D(32, (3, 3), activation='relu', padding='valid')(x_train_input)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
x = layers.Dropout(0.25)(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
x_train_out = layers.Dense(classes,
activation='softmax',
name='x_train_out')(x)
return x_train_out


sess = K.get_session()

save_mnist_as_tfrecord()

batch_size = 100
batch_shape = [batch_size, 28, 28, 1]
epochs = 12
steps_per_epoch = 1000
classes = 10
parallelism = 10

x_train_batch, y_train_batch = read_and_decode_recordinput(
'train.mnist.tfrecord',
one_hot=True,
classes=classes,
batch_shape=batch_shape,
parallelism=parallelism)

x_batch_shape = x_train_batch.get_shape().as_list()
y_batch_shape = y_train_batch.get_shape().as_list()

# The input tensors are provided directly into the Model network.
# The network is fixed once it is initialized, so it must be
# reconstructed every time a new input data source is needed.
# This is substantially different from typical
# Keras numpy array inputs, and is more like TensorFlow.
x_train_in = layers.Input(tensor=x_train_batch, batch_shape=x_batch_shape)
x_train_out = cnn_layers(x_train_in)
y_train_in = layers.Input(tensor=y_train_batch, batch_shape=y_batch_shape, name='y_labels')
train_model = Model(inputs=[x_train_in], outputs=[x_train_out])
train_model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])

# The input data was created with x_train_input,
# so only the label data needs to be provided.
train_model.fit(y=y_train_in,
batch_size=None,
epochs=epochs,
steps_per_epoch=steps_per_epoch)
train_model.save_weights('saved_wt.h5')

K.clear_session()

# Second Session, test data
x_test_batch, y_test_batch = read_and_decode_recordinput(
'test.mnist.tfrecord',
one_hot=True,
classes=classes,
batch_shape=batch_shape,
parallelism=parallelism)

x_batch_shape = x_test_batch.get_shape().as_list()
y_batch_shape = y_test_batch.get_shape().as_list()

# Create a completely new network for new input data.
x_test_in = layers.Input(tensor=x_test_batch, batch_shape=x_batch_shape)
x_test_out = cnn_layers(x_test_in)
y_test_in = layers.Input(tensor=y_test_batch, batch_shape=y_batch_shape, name='y_labels')
test_model = Model(inputs=[x_test_in], outputs=[x_test_out])
test_model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
test_model.load_weights('saved_wt.h5')
test_model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

# Take steps for each element of validation data.
validation_samples = 10000
evaluate_steps = validation_samples / batch_size
loss, acc = test_model.evaluate(y=y_test_in, steps=evaluate_steps)
print('\nTest accuracy: {0}'.format(np.mean(acc)))
2 changes: 1 addition & 1 deletion examples/variational_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def sampling(args):
# Custom loss layer
class CustomVariationalLayer(Layer):
def __init__(self, **kwargs):
self.is_placeholder = True
self.is_keras_placeholder = True
super(CustomVariationalLayer, self).__init__(**kwargs)

def vae_loss(self, x, x_decoded_mean):
Expand Down
2 changes: 1 addition & 1 deletion examples/variational_autoencoder_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def sampling(args):
# Custom loss layer
class CustomVariationalLayer(Layer):
def __init__(self, **kwargs):
self.is_placeholder = True
self.is_keras_placeholder = True
super(CustomVariationalLayer, self).__init__(**kwargs)

def vae_loss(self, x, x_decoded_mean_squash):
Expand Down
12 changes: 7 additions & 5 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,16 @@ def placeholder(
name=name)
x._keras_shape = shape
x._uses_learning_phase = False
x.is_keras_placeholder = True
return x


def is_keras_tensor(x):
if not isinstance(x, (C.variables.Constant,
C.variables.Variable,
C.variables.Parameter,
C.ops.functions.Function)):
def is_keras_tensor(x, expect_other_types=False):
if (not expect_other_types and
not isinstance(x, (C.variables.Constant,
C.variables.Variable,
C.variables.Parameter,
C.ops.functions.Function))):
raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) + '`. '
'Expected a symbolic tensor instance.')
return hasattr(x, '_keras_history')
Expand Down
Loading