
Input Tensors: High Performance Large Datasets via TFRecords #6928

Closed
wants to merge 204 commits

Conversation

ahundt
Contributor

@ahundt ahundt commented Jun 10, 2017

Implements the Input Tensor API detailed in #7102 (comment)

Update (2017-08-08): Two supported use cases based on reviews:

# API 2

model = # on top of a tensor input
model.add_loss()  # involving y_tensor
model.fit(epochs=10, steps_per_epoch=1000)

# API 3

model = # on top of a tensor input
model.compile()
model.fit(y=y_tensor, epochs=10, steps_per_epoch=1000)

API usage, with working mnist_tfrecord.py implementation.
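
For concreteness, here is a minimal end-to-end sketch of API 3. The names x_tensor and y_tensor are hypothetical stand-ins for TF ops that each yield one batch per graph execution:

from keras.layers import Input, Dense
from keras.models import Model

# x_tensor and y_tensor are assumed to be TF ops that yield a batch when run
x_in = Input(tensor=x_tensor)
preds = Dense(10, activation='softmax')(x_in)
model = Model(inputs=x_in, outputs=preds)

model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# no x and no batch_size: the input tensor feeds itself,
# so steps_per_epoch alone defines an epoch
model.fit(y=y_tensor, epochs=10, steps_per_epoch=1000)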

Summary

This PR adds support for yield ops to Keras, plus an example utilizing TFRecords. Correct support for yield ops in Model adds valuable functionality that Keras does not currently provide, for the reasons detailed below.

Keras re-compiles the model on demand when tensors are passed to y in model.fit().

Yield ops

Yield ops, a.k.a. data tensors, such as RecordInput (test code), differ from tf.Variable because they provide data entirely on the C++ side when run without fetches or a feed_dict, and are thus extremely efficient for large data such as images.
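
For reference, a minimal sketch of constructing such a yield op with RecordInput in TF 1.x (the file path and parameter values here are hypothetical):

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

# RecordInput reads TFRecord files through a C++ queue; get_yield_op()
# returns an op that emits a fresh batch of serialized records each time
# the graph runs it, so the data never passes through a feed_dict.
record_input = data_flow_ops.RecordInput(
    file_pattern='train.tfrecords',  # hypothetical path
    batch_size=32,
    parallelism=4)
records_op = record_input.get_yield_op()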

Changes

Here are the changes, marked with ==bugfix== and ==new param== in the comments below:

# tf yield ops that supply dataset images and labels
x_train_batch, y_train_batch = read_and_decode_recordinput(...)

# create a basic cnn
x_train_input = Input(tensor=x_train_batch)
x_train_out = cnn_layers(x_train_input)

# y label batch is input & output
# Perhaps this aspect of API usage can be improved?
y_train_in = Input(tensor=y_train_batch)

# ==bugfix==
# This call causes a crash without this patch because
# an invalid call is made that is equivalent to:
# K.placeholder(dtype=x_train_input)
train_model = Model(inputs=[x_train_input], outputs=[x_train_out])

# ==bugfix==
# This call will crash without this patch because
# it is assumed the parameters `x` and `y` are
# provided here and not via the ops
# x_train_batch and y_train_batch 
train_model.compile(optimizer='rmsprop',
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])

# ==bugfix== + ==new param==
# This call will crash without this patch because
# the changes in tensor order caused by the
# constructor update, which accepts yield ops,
# were not previously accounted for.
#
# A new param steps_per_epoch is added
# which works just like in fit_generator()
train_model.fit(None, y_train_in, 
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=10000)

There are extended unit tests and support for input tensors with each of the following APIs, given input tensors (a.k.a. yield ops) x and y:

    # train_on_batch
    out = model.train_on_batch(x, y)

    # test_on_batch
    out = model.test_on_batch(x, y)

    # predict_on_batch
    out = model.predict_on_batch(x)

    # fit
    out = model.fit(x, y, epochs=1, batch_size=batch_size,
                    steps_per_epoch=steps_per_epoch)

    # evaluate
    out = model.evaluate(x, y, batch_size=batch_size,
                         steps=steps_per_epoch)

    # predict
    out = model.predict(x, batch_size=batch_size,
                        steps=steps_per_epoch)

TFRecord

TFRecord support is both a side effect and a key motivator of yield op support; examples/mnist_tfrecord.py demonstrates usage.

Update: I've moved the fetches and feed_dict public API design into #6974. This PR now focuses more narrowly on supporting an input tensor yield_op as a parameter.

Performance Update

This latest version runs mnist_tfrecord.py twice as fast as it did previously!

@drewrice2

drewrice2 commented Jun 10, 2017

watching this PR closely, thanks for your efforts @ahundt

@fchollet
Collaborator

We already support fitting from tensors / TFRecords. Please make a case for this PR.

@ahundt
Contributor Author

ahundt commented Jun 10, 2017

We already support fitting from tensors / TFRecords. Please make a case for this PR.

@fchollet Is there a TFRecords example using the model.fit() API somewhere that I missed? Perhaps I also misunderstood the API usage; my notes below assume changes are needed to support TFRecords.

Features

  • TFRecords example
  • Several small changes needed so the example does not crash
  • Users can pass additional ops and a feed dict to session.run() via model.fit()

Why modify model.fit() to accept TFRecords instead of writing my own training loop?

  • The Keras API design is excellent, so I'd prefer to reuse it
  • Avoid maintaining a fork and keeping it compatible as Keras evolves
  • Use other libraries built on Keras
  • Minimize the effort to switch backends in user code subsets that don't depend on TFRecords.
  • I know of ~6 others who have searched or asked about this functionality. I hope this will help them too. :-)

I can implement things in a different manner as needed. I can also split this into several smaller pull requests according to the headline changes below.

model.fit(x=None, y=None)

In several places model.fit() assumes that either x or y is not None, which isn't the case when the data is supplied by a tensor op.

Disabling sample weights

The sample_weight logic in model.fit() assumes x is not None and y is not None in many places. Additionally, when a tensor supplies the samples, I don't think the information necessary to initialize the right number, dimensions, and values of the sample weight placeholders is available to this part of the code.

I chose the 'disabled' parameter value for backwards compatibility, but I know that naming scheme doesn't fit the rest of the API.

is_placeholder()

I use is_placeholder() to differentiate numpy inputs from tensorflow op inputs, which allows y_true and y_pred to be correctly assigned to the appropriate tf op. Without this change the program crashes by incorrectly passing an op to K.placeholder.

is_placeholder can also be useful when enumerating the layers, inputs, and outputs of an arbitrary model, or when transforming one. For example, anything that walks the graph may or may not want to modify placeholders, such as an automated converter from a labeling model to a segmentation model.

The implementation of is_placeholder was based on _uses_learning_phase and _keras_shape.
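
A minimal sketch of that attribute-based approach (illustrative, not the exact patch):

import tensorflow as tf

# mirror the _keras_shape / _uses_learning_phase pattern: tag the tensor
# with a plain Python attribute where K.placeholder() creates it
x = tf.placeholder(dtype='float32', shape=(None, 784))
x.is_placeholder = True

def is_placeholder(tensor):
    # data tensors and other ops simply lack the attribute
    return getattr(tensor, 'is_placeholder', False)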

TFRecords Example

There is quite a bit of interest in training models from TFRecords, with a bunch of comments and thumbs-up on the keras-contrib PR and the tensorflow issue. Loading from TFRecords will certainly be faster than loading individual files from disk, and it is suggested in the high performance models docs. The dataset I'm using was released as a TFRecord, so it is quite convenient.

I didn't really like this workaround:

Model(inputs=[x_train_input, y_train_in_out], outputs=[x_train_out, y_train_in_out])

Is there a better way to supply label values from a TFRecord?
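
One option, the pattern that API 2 in the summary above sketches, is to keep the label tensor out of inputs/outputs entirely and attach the loss directly. A sketch, reusing the tensor names above and assuming x_train_out is a softmax output:

import tensorflow as tf
from keras.models import Model

model = Model(inputs=x_train_input, outputs=x_train_out)
# categorical crossentropy written in plain TF, so the label tensor
# can be referenced without routing it through an Input
loss = -tf.reduce_mean(
    tf.reduce_sum(y_train_batch * tf.log(x_train_out + 1e-7), axis=-1))
model.add_loss(loss)
# assumes a Keras version where compile() accepts loss=None
model.compile(optimizer='rmsprop', loss=None)
model.fit(epochs=10, steps_per_epoch=1000)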

class Function(fetches, feed_dict)

Before this change, auxiliary tf tensor data and operations could almost be provided to the session.run() call, but not quite, since the feed_dict and fetches params end up duplicated. The return path of the updated variable might still need some tweaking. Why might this be useful? API and algorithm extensions that need to be fast can be added without modifying Keras models or training functions.

My motivating API extension example is training on a realtime data source like a robot. I want to supply live images via an op while actively fitting, and log the images via another op. There is no need for the logging op to affect my model/training, but the source/sink should run in the same session, and the overhead of moving the images to numpy arrays and serializing them would cause frames to drop.

@fchollet
Collaborator

So currently you are supposed to be able to:

  • build models on top of tensors
  • train with fit with partial Numpy data (i.e. any tensor input will not expect Numpy data), e.g. fit(x=None, y)

What changes?

@ahundt
Contributor Author

ahundt commented Jun 11, 2017

What changes?

Here are the changes, marked with ==bugfix== or ==new param== in the comments below:

# tf ops that supply dataset images and labels
x_train_batch, y_train_batch = read_and_decode_recordinput(...)

# create a basic cnn
x_train_input = Input(tensor=x_train_batch)
x_train_out = cnn_layers(x_train_input)

# y label batch is input & output
y_train_in_out = Input(tensor=y_train_batch)

# ==bugfix== 
# these lines cause a crash without the patch
# because an invalid call equivalent to this is made:
# K.placeholder(dtype=x_train_input)
train_model = Model(inputs=[x_train_input, y_train_in_out],
                    outputs=[x_train_out, y_train_in_out])

# ==new param option==
# model.compile(sample_weight_mode='disabled')

# ==bugfix==
# This call will crash without the patch,
# even if sample_weight_mode=None.
train_model.compile(optimizer='rmsprop',
                    loss='categorical_crossentropy',
                    metrics=['accuracy'],
                    sample_weight_mode='disabled')

# ==new param==
# sess.run(fetches, feed_dict) parameters handled correctly in backend
hello = tf.constant('Hello, TensorFlow!')
train_model.fit(batch_size=batch_size, epochs=300, fetches=hello)

@ahundt changed the title from "TFRecord and expanded tf tensor support, initial implementation." to "model.fit() TFRecord and expanded tf tensor support, initial implementation." on Jun 11, 2017
@ahundt
Contributor Author

ahundt commented Jun 11, 2017

It seems right now I've broken more cases than I've fixed in the tests... I knew this code section would get pretty complicated. It might be a few days before I can address further code fixes.

@fchollet
Collaborator

fchollet commented Jun 11, 2017

If sample weighting cannot be supported, then it should be disabled automatically and a ValueError should be raised if the user attempts to pass sample weights. This should not be handled by the user via sample_weight_mode='disabled'.
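
A minimal sketch of that check (the helper and flag names are hypothetical):

def _check_sample_weights(sample_weight, feeds_from_tensors):
    # sample weighting is unavailable when data arrives via tensors,
    # so passing weights should fail loudly rather than silently
    if feeds_from_tensors and sample_weight is not None:
        raise ValueError('sample_weight is not supported when input '
                         'data is fed from TensorFlow tensors.')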

read_and_decode_recordinput

What does it do exactly and why is it necessary?

fetches

Can you explain the fetches API in further detail?

Thanks!

@ahundt
Contributor Author

ahundt commented Jun 11, 2017

If sample weighting cannot be supported, then it should be disabled automatically and a ValueError should be raised if the user attempts to pass sample weights. This should not be handled by the user via sample_weight_mode='disabled'.

Great idea, why didn't I think of that? Thanks!

read_and_decode_recordinput

What does it do exactly and why is it necessary?

It is not part of the API; it is user code that creates the tf ops feeding data in mnist_tfrecord.py: read_and_decode_recordinput(). I'll explain in more detail in the datasets pull request, where it is most relevant.
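
For reference, the general shape of such a function in TF 1.x; the feature keys and decoding details below are illustrative rather than the exact example code:

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

def read_and_decode_recordinput(tfrecord_filename, batch_size=32):
    # yield op emitting a batch of serialized Example protos per run
    record_input = data_flow_ops.RecordInput(
        file_pattern=tfrecord_filename, batch_size=batch_size)
    records_op = record_input.get_yield_op()
    records_op = tf.split(records_op, batch_size, 0)
    images, labels = [], []
    for serialized in records_op:
        features = tf.parse_single_example(
            tf.reshape(serialized, []),
            features={'image_raw': tf.FixedLenFeature([], tf.string),
                      'label': tf.FixedLenFeature([], tf.int64)})
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        images.append(tf.cast(image, tf.float32) / 255.0)
        labels.append(tf.one_hot(features['label'], 10))
    # the stacked tensors can be passed straight to Input(tensor=...)
    return tf.stack(images), tf.stack(labels)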

fetches

Can you explain the fetches API in further detail?

You may slap your head when you read this one, but no worries I'm happy to clarify! :-)

fetches is the first parameter of tf.Session.run():

run(
    fetches,
    feed_dict=None,
    options=None,
    run_metadata=None
)
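
For reference, a standalone TF 1.x illustration of what fetches accepts:

import tensorflow as tf

a = tf.constant(2)
b = tf.constant(3)
total = a + b

with tf.Session() as sess:
    # fetches can be a single tensor, a list, or a dict of graph elements
    print(sess.run(total))                   # 5
    print(sess.run([a, total]))              # [2, 5]
    print(sess.run({'a': a, 'sum': total}))  # {'a': 2, 'sum': 5}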

@fchollet
Collaborator

In sess.run, fetches is a list (or otherwise nested data structure) of tensors to return (nodes to fetch from the computational graph). But fit does not return tensors, hence my confusion. What does fetches achieve in fit? How are users supposed to understand its usage? How would you write docs about it?

@ahundt
Contributor Author

ahundt commented Jun 13, 2017

how are users supposed to understand its usage? How would you write docs about it?

What it does is forward fetches and feed_dict to session.run().

Here are some first-pass doc lines:

history = model.fit(fetches, feed_dict)
"""
# Arguments

    fetches: TensorFlow backend only. A single TensorFlow graph element, 
        a list of graph elements, or a dictionary whose values are graph elements 
        or lists of graph elements. Passes tensor ops to `tf.session.run()` and 
        returns an additional tensor tuple upon completion. This is for advanced 
        users that require auxiliary processing as fit runs. When provided,
        additional tensor values are returned:
            `history, tensors = model.fit(fetches, feed_dict)`
        See https://www.tensorflow.org/api_docs/python/tf/Session for more details
        on this parameter.
    feed_dict: TensorFlow backend only. A dictionary that maps TensorFlow graph
        elements to values.  This is for advanced users that require auxiliary 
        processing as `model.fit()` runs. 
        See https://www.tensorflow.org/api_docs/python/tf/Session for more details on this parameter.
"""

Motivation: I want to supply live images via an op while actively fitting, and log the image via another op. Details are in my previous post #6928 (comment).
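
As a concrete sketch of that use case under this PR's proposed fetches parameter (the logging op and image tensor names are hypothetical):

import tensorflow as tf

# hypothetical logging op recording incoming frames; it rides along
# inside the same session.run() calls that execute the training ops
log_op = tf.summary.image('live_frames', x_train_batch)

history, tensors = model.fit(epochs=10,
                             steps_per_epoch=100,
                             fetches=[log_op])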

ahundt added 4 commits August 9, 2017 02:54
* tensorflow_backend_Function:
  backend_test.py cleanup new TF backend test
  backend_test.py rename f to g to fix test error
  typo fix
  test_backend.py K.backend.function() error handling test
  backend_test.py reverse change.
  test keras.backend.Function() input handling backend test
  import zip_longest from six.moves for keras-team#7064
  tensorflow_backend.py remove self.feed_to_fetch_count
  tensorflow_backend Function() corner case handling and error checking

# Conflicts:
#	keras/backend/tensorflow_backend.py
#	tests/keras/backend/backend_test.py
ahundt added 2 commits August 11, 2017 11:19
* commit '41fdd5aa5fe080bf5fa12638b9abc9f50d835a39':
  training.py improve docstrings and error case
  fix docstring comments from review (keras-team#7113)
fchollet pushed a commit that referenced this pull request Aug 11, 2017
* generic_utils.py don't crash when dealing with batched data

* Progbar() unit test

* mnist_tfrecord.py added (#7061, #7072, #6928, #7046)

* Fix mnist_tfrecord.py runtime errors

* mnist_tfrecord.py pep8

* mnist_tfrecord.py add parallelism option

* reorder inputs

* mnist_tfrecord.py indentation fix

* lower batch size and epochs

* loss defaults to None in compile()

* mnist_tfrecord.py

* model.fit(steps_per_epoch) added

* added _check_num_samples for cases when batch_size does not apply

* fix test failures

* remove inaccurate warning

* improved fit(steps_per_epoch) with separate internal epoch loop in _fit_loop

* fit(steps_per_epoch) initial validation support

* training.py pep8

* mnist_tfrecord.py fix key missing lines

* mnist_tfrecord.py add coordinator

* removed extraneous line

* mnist_tfrecord.py and training.py clean up based on review (#7113)

* mnist_tfrecord.py extended description

* training.py fix test error

* mnist_tfrecord.py and training.py fixed review comments, docs, and error messages (#7113)

* training.py fix unit test error for steps_per_epoch

* fix docstring comments from review (#7113)

* training.py improve docstrings and error case
ahundt added 4 commits August 11, 2017 16:25
* master:
  model.fit(steps_per_epoch), 	mnist_tfrecord.py, progbar np.mean (keras-team#7113)
  Change default size to allow different MobileNet sizes (keras-team#7586)
  cntk backend: fix the reversed rnn bug (keras-team#7593)
  Fix mask for multi output --> multi inputs (keras-team#7591)
  [RELNOTES] Subtract merge layer (keras-team#7573)
  update docker with cntk 2.1 (keras-team#7565)
  [RELNOTES] Move constraint management to be based on variable attributes (like TF). (keras-team#7564)
  Add handling for `dtype` arg in initializer config.
  Fix keras-team#7550. (keras-team#7552)
  remove unnecessary function definition (keras-team#7542)
  refactor the function - _convert_string_dtype (keras-team#7540)
  Add batchnorm tests for CNTK (keras-team#7534)
  Clean up RNN tests (keras-team#7529)
  Add merge tests (keras-team#7533)

# Conflicts:
#	examples/mnist_tfrecord.py
#	keras/engine/training.py
#	keras/utils/generic_utils.py
#	tests/keras/engine/test_training.py

Review comment on the predict() signature change in training.py:

-    def predict(self, x, batch_size=32, verbose=0):
+    def predict(self, x=None, batch_size=32, verbose=0, steps=None):
         """Generates output predictions for the input samples.

         Computation is done in batches.

         # Arguments
             x: the input data, as a Numpy array

Review comment (Contributor): or None (i.e. the docstring should read "the input data, as a Numpy array or None")

Review comment on backend_test.py:

@@ -159,6 +159,8 @@ def test_is_keras_tensor(self):
         assert k.is_keras_tensor(keras_var) is False
         keras_placeholder = k.placeholder(shape=(2, 4, 5))
         assert k.is_keras_tensor(keras_placeholder) is False
+        assert getattr(keras_placeholder, 'is_keras_placeholder', False) is True

Review comment (Contributor): add a test for K.is_keras_tensor(..., expect_other_types=True).

@fchollet (Collaborator) left a review:

This PR is rather massive (1000+ lines) and I think it contains a lot of changes that don't seem necessary in order to implement the "data tensor action plan". In particular:

  • is_keras_tensor(x, expect_other_types=False) (it's also a rather strange API -- try asking a random user what they expect this function to do, given its signature)
  • Changes related to the is_keras_placeholder attribute on both layers and tensors
  • Changes to the Function API aren't immediately related to data tensors support

I would recommend simplifying and focusing on:

  • Supporting the steps argument in evaluate/predict. This is a self-contained change.
  • Supporting recompilation if a target data tensor is passed to fit. This is also a self-contained change.

These can be done in two PRs. Both would only involve changes in training.py as far as I can tell. I would expect these changes to end up at ~30-50% of the total size of this PR.

@ahundt
Contributor Author

ahundt commented Aug 22, 2017

Much of the functionality in this PR has been reimplemented and merged on master as of 2.0.7. I don't have cycles right now, but I plan on combing through this to separate out any key tests and other functionality. However, at this point I suggest using the examples in master, and if anything is broken we can make a unit test so the changes can be reimplemented.

Update (2017-08-24): After further testing, master is still broken for my real use cases. At this time I still recommend using this PR over master if you wish to use input tensors.

@fchollet
Collaborator

fchollet commented Aug 22, 2017 via email

@ahundt
Contributor Author

ahundt commented Aug 23, 2017

@fchollet For some reason I didn't see your recent comment #6928 (review) until now, I'm sorry about that.

I think it contains a lot of changes that don't seem necessary in order to implement the "data tensor action plan".

There is definitely a small amount of code cleanup that could be left out. I may be mistaken, but I suspect that something like ~80% of these changes will prove necessary in practice. This is because there are several use cases not yet supported in master, which I may not have elucidated very clearly yet; I apologize for that.

is_keras_tensor(x, expect_other_types=False) (it's also a rather strange API -- try asking a random user what they expect this function to do, given its signature)

I agree it is not pretty; what about renaming the parameter to is_keras_tensor(no_exception=False)?

I believe the most correct behavior of is_keras_tensor() is to simply return False if the parameter is not a Keras tensor, eliminating exceptions. An API changeover period like the one from the Keras 1 to Keras 2 API could be used, and the parameter dropped entirely after 6 months.
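
A sketch of the no-exception behavior being proposed, based on the _keras_history attribute Keras already uses to tag its tensors:

import tensorflow as tf

def is_keras_tensor(x):
    # proposed: return False for non-tensor types instead of raising
    if not isinstance(x, (tf.Tensor, tf.Variable, tf.SparseTensor)):
        return False
    return hasattr(x, '_keras_history')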

This could easily be a separate PR, are you open to that possibility?

Changes to the Function API aren't immediately related to data tensors support

The zipping of lists and the moving of lists between the tensor inputs and the feed dict are required for it to work, even with basic use cases. Only the feed_dict and fetches pass-through is not required. See the fake mnist tests in test_training.py for details, or try running the following gist on master:

https://gist.github.com/ahundt/5c2c7f5a324171bbbc8a5b622f2162c5. Running it will hit this error first (tested on master 2017-08-22 at 9pm EST):

  File "mnist_tfrecord_recordinput.py", line 196, in <module>
    steps_per_epoch=steps_per_epoch)
  File "build/bdist.linux-x86_64/egg/keras/engine/training.py", line 1521, in fit
  File "build/bdist.linux-x86_64/egg/keras/engine/training.py", line 1388, in _standardize_user_data
  File "build/bdist.linux-x86_64/egg/keras/engine/training.py", line 561, in _standardize_weights
AttributeError: 'NoneType' object has no attribute 'shape'

The fix for the error at the above trace is #7067, but as soon as that is fixed another issue is revealed. Pull on that thread long enough and you will eventually be patching keras.Function() to deal with None being incorrectly fed into the feed_dict and fetches parameters of session.run().

These can be done in two PRs. Both would only involve changes in training.py as far as I can tell. I would expect these changes to end up at ~30-50% of the total size of this PR.

Unfortunately, I think a number of the test_training.py tests in this PR will fail with only those two changes.

@TimZaman
Contributor

How do we handle multi-tensor-output models? model.fit(y=[y_tensor1, y_tensor2], epochs=10, steps_per_epoch=1000)?

@ahundt
Contributor Author

ahundt commented Aug 23, 2017

@TimZaman You should save each feature of your example proto separately, and read each feature out of the feature map separately. See how there are separate image, height, and width features in mnist_tfrecord_recordinput.py for an example of that aspect (but not the whole thing, since only one is used in training). You then supply each separate tensor, much as you would a numpy array, to the Input(tensor=tfrecord_input_tensor) or compile(input_tensors=[label1_tensor, label2_tensor]) parameter.
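
A sketch of the separate-feature reads (the keys and serialized_example are hypothetical, and batching is omitted):

import tensorflow as tf

# each label lives in its own feature of the Example proto
features = tf.parse_single_example(
    serialized_example,
    features={'image_raw': tf.FixedLenFeature([], tf.string),
              'label1': tf.FixedLenFeature([], tf.int64),
              'label2': tf.FixedLenFeature([], tf.int64)})
y_tensor1 = tf.one_hot(features['label1'], depth=10)
y_tensor2 = tf.one_hot(features['label2'], depth=10)

# after batching, each tensor is supplied like a separate numpy target:
# model.fit(y=[y_tensor1, y_tensor2], epochs=10, steps_per_epoch=1000)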

@stale

stale bot commented Nov 21, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

@stale stale bot added the stale label Nov 21, 2017
@ahundt
Contributor Author

ahundt commented Nov 23, 2017

It is possible to make do with the model.compile(input_tensors) parameter, and I don't have the free cycles to keep merging master on an ongoing basis, so I'm closing this. If someone else is interested in getting these features integrated, please let me know and I'll be happy to help.
