
Native TensorFlow input queue implemented with fit, predict, evaluate (TFRecords) #7208

Closed · wants to merge 8 commits

Conversation

colinskow commented Jul 2, 2017

Adds support for native multi-threaded QueueRunner / TFRecord input using the TensorFlow backend through the fit, predict, and evaluate functions. This includes data validation.

Working demo: https://github.com/colinskow/keras/blob/fit-tensor-input/examples/mnist_tfrecord.py

Each batch of input is evaluated inside the training loop and fed into the trainer. There are ways to improve performance by feeding input directly into the graph, but that would mean a LOT of work to shave perhaps a few percent off training time. Let's get the functionality implemented and work on performance under the hood down the road.
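Roughly, the intended call pattern looks like the sketch below. It is only a sketch: the argument names are approximate, and read_and_decode_recordinput / cnn_model are helpers from the linked mnist_tfrecord.py example, which remains the authoritative version.

x_train_batch, y_train_batch = read_and_decode_recordinput(
    'train.mnist.tfrecord', one_hot=True, classes=10,
    is_train=True, batch_shape=[128, 28, 28, 1])

model = cnn_model(input_shape=(28, 28, 1), num_classes=10)
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])

# fit() consumes the queue tensors directly; only a steps-per-epoch count
# is needed, as with fit_generator.
model.fit(x_train_batch, y_train_batch, epochs=5, steps_per_epoch=300)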

WHY THIS IS NEEDED:

Processing input via Python/NumPy is very slow. Especially when paying for a high-end GPU, inputting through Python can become a bottleneck that takes longer than the actual training. Keras' input pipeline is not cloud-friendly.

TensorFlow has a very good solution to these problems, but it does not work with Keras now.
https://www.tensorflow.org/programmers_guide/reading_data

This PR allows Keras to train, evaluate, and predict using TF's super-fast, cloud-friendly, multi-threaded input pipeline. This is essential for running models in the cloud without wasting money when paying by the minute for high-end GPUs. And it does this transparently, with no breaking changes to the API. (Other backends can be accommodated in the same way if they offer similar functionality.)

HELP REQUESTED:

  1. I'm relatively new to Python, so please tell me if I've used any bad coding practices.
  2. TODO: automated testing of TFRecord input, including possible edge cases

@fchollet @ahundt @TimZaman

tensor_input = _is_tf_tensor(ins[0])
if ins and tensor_input:
    if not num_steps:
        raise ValueError('`validation_steps` must be specified when using tensor input.')
Author:
Change validation_steps to num_steps

to evaluate from input tensors before declaring one epoch
finished and starting the next epoch. It should typically
be equal to the number of unique samples of your dataset
divided by the batch size.
Contributor:
This falls flat when the division does not yield an integer.

Author:
This is copied directly from the fit_generator docs.

Contributor:
Yet the TFRecords pipe in your example does not support a non-fixed batch size, and you can't be sure your dataset size is cleanly divisible by the batch size. Having a super-fast data pipe in TF while supporting mixed-size batches is hard, which is why optimized TensorFlow models often work with steps rather than epochs.

Contributor:
this is discussed in #7113
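(A common workaround, noted here for reference rather than taken from the PR: round the step count up and let the last batch be smaller where the pipeline allows it.)

num_samples = 60000   # hypothetical dataset size
batch_size = 128
steps_per_epoch = (num_samples + batch_size - 1) // batch_size   # ceiling division -> 469 steps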

def fit(self, x, y, batch_size=32, epochs=10, verbose=1, callbacks=None,
        validation_split=0., validation_data=None, shuffle=True,
        class_weight=None, sample_weight=None, initial_epoch=0, **kwargs):
def fit(self, x=None,
Contributor:
Why the style changes?

Author:
I copied and pasted from the fit function of training.py. After adding parameters the long lines were awkward.



def cnn_model(input_shape, num_classes):
    model = Sequential()
Contributor:
Try using the functional model API.

Author:
Copied straight from examples/mnist_cnn.py. I appreciate constructive feedback, but I feel you are criticizing just to criticize.

Contributor:
Nonsense. Just trying to help.

Contributor:
My other example shows how to use the functional API: #7075
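(For reference, a minimal functional-API version of such a model could look like the sketch below; it is illustrative only, not the architecture from the example.)

from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from keras.models import Model


def cnn_model(input_shape, num_classes):
    # Same idea as the Sequential version, expressed with the functional API.
    inputs = Input(shape=input_shape)
    x = Conv2D(32, (3, 3), activation='relu')(inputs)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    return Model(inputs, outputs)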

    # Returns
        A boolean: whether the argument is a native tensor.
    """
    return isinstance(x, tf.Tensor)
Contributor:
What about tf.SparseTensor, tf.Variable?
Anyway, you can just use the existing is_keras_tensor().

Author:
is_keras_tensor() doesn't work. This does. tf.train.shuffle_batch returns type tf.Tensor.

Contributor:
Yeah, but your function will get used in the future by others, and they expect it to work for all types of tensors.

Author:
Understood, I will copy the functionality from _is_keras_tensor() but remove the check for Keras metadata.
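(A minimal sketch of what that broadened check might look like, assuming tf.SparseTensor and tf.Variable should also pass, as the reviewer suggests:)

import tensorflow as tf


def is_native_tensor(x):
    # Unlike is_keras_tensor(), no `_keras_history` metadata is required,
    # so tensors coming straight from an input queue also qualify.
    return isinstance(x, (tf.Tensor, tf.SparseTensor, tf.Variable))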

@@ -263,6 +263,19 @@ def is_keras_tensor(x):
    return hasattr(x, '_keras_history')


def is_native_tensor(x):
Contributor:
use is_keras_tensor()

@@ -201,6 +201,18 @@ def is_keras_tensor(x):
    return hasattr(x, '_keras_history')


def is_native_tensor(x):
Contributor:
use is_keras_tensor()

Author:
DOESN'T WORK. There is no Keras metadata in the input tensor.

Contributor:
InputLayer will have some metadata once #7066 is merged.

    print('tfrecord %s already exists' % filename)


def read_and_decode_recordinput(tf_glob, one_hot=True, classes=None, is_train=None, batch_shape=[128, 28, 28, 1]):
Author:

@ahundt is the expert in this area.
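(For readers unfamiliar with the queue-based pipeline, a generic TF 1.x-style read-and-decode sketch is shown below. The feature names and shapes are hypothetical; the example's actual read_and_decode_recordinput differs in its details.)

import tensorflow as tf


def read_and_decode(filename_queue, batch_size=128):
    # One serialized record in, one decoded (image, label) pair out.
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized,
        features={'image_raw': tf.FixedLenFeature([], tf.string),
                  'label': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(tf.cast(image, tf.float32) / 255.0, [28, 28, 1])
    label = tf.cast(features['label'], tf.int32)
    # Multi-threaded shuffling queue that assembles whole batches.
    return tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                  capacity=2000, min_after_dequeue=1000,
                                  num_threads=4)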

Dref360 (Contributor) commented Jul 2, 2017

This PR is really TF-specific and adds a lot of TF-only code to the main files (training.py, for example). Not convinced this is needed.

colinskow (Author):

@Dref360 see the discussion toward the bottom of #6928.

TL;DR processing input via Python/NumPy is very slow. Especially when paying for a high-end GPU, inputting through Python can become a bottleneck that takes longer than the actual training. Keras' input pipeline is not cloud-friendly.

TensorFlow has a very good solution to these problems, but it does not work with Keras now.
https://www.tensorflow.org/programmers_guide/reading_data

This PR allows Keras to train, evaluate, and predict using TF's super-fast, cloud-friendly, multi-threaded input pipeline. This is essential for running models in the cloud without wasting money when paying by the minute for high-end GPUs. And it does this transparently, with no breaking changes to the API.

Unfortunately I haven't seen that Theano or CNTK support similar functionality. If they do we can definitely accommodate them all with the same API.

I do agree training.py is an ugly place for TF-specific code. I'm open to feedback on how to best refactor this into the backend. But this is very important functionality for running models in production where performance is critical.

TimZaman (Contributor) commented Jul 2, 2017

Agree with @Dref360.

TL;DR processing input via Python/NumPy is very slow.

Yet you are making a tf->np->tf roundtrip, where ->np-> is in the critical path.
For feeding the GPU, it's all about the critical path. Keras's Python-multiprocessing pipeline feeds fine and can keep up in principle. The problem you'll encounter when training on big systems like a DGX-1 is that the GPUs are extremely fast. The GPUs don't want to wait for anything to get the next batch, and the samples should be staged on the GPU. They should not be staged on the CPU, and they should not be waiting on a tf->np->tf roundtrip to finish.

The first step in doing this is having a TensorFlow tensor input, which is exactly what #6928 was in the process of solving. This is not dependent on TFRecords, yield ops, etc.
The feature we want is: "Having Keras cleanly support input $framework tensors instead of numpy arrays".

Having this PR together with #6928 is also a bit confusing, as they are trying to solve the same core issue.

colinskow (Author):

@TimZaman I agree with you that #6928 is the most performance-efficient solution. The problem is that it requires refactoring the user's model and doesn't work with most features of the fit API, such as data validation and sample weights. There's a lot of work and problem solving ahead to make it merge-ready.

My solution is ready to go (except for some fine-tuning and cleanup) and works out of the box with all existing models. The only API change is specifying the number of steps per epoch (as with fit_generator). It gives 95% of the performance benefits of a native input queue. And it will be easy to implement with other backends when they develop input queue solutions as well.

I think you are exaggerating the performance impact of tf->np->tf once per batch. The training loop is doing this anyway. Just look how much Python code already gets executed inside _fit_loop between batches. The GPU is waiting anyway. The TF QueueRunner has the input processed and ready to go in memory when Python calls for it. The cost of feeding a NumPy array back into TF is minuscule. And this bottleneck can be compensated for simply by increasing the batch size.

Really getting the performance you are talking about means eliminating Python from the training loop completely and going all-native. And the performance benefit will likely be only a few percent. I don't think that is going to be implemented anytime soon.

I think it's a good idea to actually benchmark my solution against #6928 to see what the difference really is. I personally only have a CPU, so perhaps someone with a good GPU can try too.

I do suggest we merge this first (after cleaning it up, adding tests, etc.). Then, when the technical kinks are worked out of #6928, we can merge that down the road to really get the extra performance.

And I don't agree with depriving users of a very useful feature for months while waiting on a slightly better implementation when we can release now and then transparently tune under the hood.

colinskow (Author) commented Jul 3, 2017

I benchmarked this pull request against @ahundt's solution in #6928 on my 2012 MacBook Pro (CPU only). Here are the results:

This uses the same data, same model architecture, and same TFRecords input pipeline. I used a batch size of 200 across 60,000 samples for 5 epochs. I disabled data validation to be fair, since #6928 doesn't support it.

[ahundt/keras@tfrecords] (counts data samples)
Epoch 1/5
60000/60000 [==============================] - 120s - loss: 0.2943 - acc: 0.9110      
Epoch 2/5
60000/60000 [==============================] - 111s - loss: 0.0868 - acc: 0.9742     
Epoch 3/5
60000/60000 [==============================] - 111s - loss: 0.0648 - acc: 0.9807     
Epoch 4/5
60000/60000 [==============================] - 111s - loss: 0.0556 - acc: 0.9839     
Epoch 5/5
60000/60000 [==============================] - 111s - loss: 0.0490 - acc: 0.9850


[colinskow/keras@fit-tensor-input] (counts batches)
Epoch 1/5
300/300 [==============================] - 115s - loss: 0.2585 - acc: 0.9215     
Epoch 2/5
300/300 [==============================] - 114s - loss: 0.0869 - acc: 0.9745     
Epoch 3/5
300/300 [==============================] - 114s - loss: 0.0639 - acc: 0.9815     
Epoch 4/5
300/300 [==============================] - 113s - loss: 0.0515 - acc: 0.9848     
Epoch 5/5
300/300 [==============================] - 114s - loss: 0.0452 - acc: 0.9867

So 111 seconds vs 114 seconds. That's a difference of about 2.7%. Going from TF->NP->TF costs about 10 ms/batch on my MacBook. And for me it is a cost worth paying to have full use of the fit, predict, and evaluate APIs.

My proposal is that we go with this API since it is much simpler, and down the road figure out how to plug the inputs directly into the graph under the hood to get the 2.7% performance boost.

TimZaman (Contributor) commented Jul 3, 2017

Who uses a CPU for DL? This MR won't benefit anyone with one CPU or one or two GPUs. The tensor-native input mostly makes sense for GPU-heavy systems, as Keras's multiprocessing pipeline is already quite good at feeding one or two GPUs. I am repeating myself, so I will refrain from further comment.

colinskow (Author):

@TimZaman I agree with you 100% that native tensor input is ideal. All I'm suggesting is that we get the functionality implemented first and then make it performant later, since that is the hard part.

ahundt (Contributor) left a comment

Unfortunately I don't think I can recommend these changes. I absolutely would have done so if I found them to be a big improvement because I'm always happy to maintain less code. :-)

More specifically, I think many of these changes are both slower and more platform-specific than my roughly equivalent changes in #6928. You can also see #7072 for a more specific, itemized, and atomic split of the various parts of #6928 that you could reuse in your own code if you wish to do another iteration of these changes, though I think the performance difference for large datasets will be a fairly fundamental issue.

For others' reference, #6928 (comment) has some reasons for the design choices that were made and why. Other discussion including #6928 (comment) (a different comment) explains the underlying TensorFlow changes needed to implement a more ideal tensor input API.



def cnn_model(input_shape, num_classes):
    model = Sequential()
Contributor:
My other example shows how to use the functional API: #7075

to evaluate from input tensors before declaring one epoch
finished and starting the next epoch. It should typically
be equal to the number of unique samples of your dataset
divided by the batch size.
Contributor:
this is discussed in #7113

@@ -26,6 +26,18 @@
from ..legacy import interfaces


def _is_tf_tensor(x):
Contributor:
This and everything that uses it is too TF-specific for this file. See how I handled these same issues in #6928.
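(Sketch of that point: the TF-specific check could live in the backend module so the engine only talks to K. The names here are hypothetical, and the actual layout in #6928 differs.)

# keras/backend/tensorflow_backend.py (sketch)
import tensorflow as tf

def is_native_tensor(x):
    return isinstance(x, (tf.Tensor, tf.SparseTensor, tf.Variable))

# keras/engine/training.py would then stay framework-neutral:
#   from .. import backend as K
#   if ins and K.is_native_tensor(ins[0]):
#       ...  # tensor-input code path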

fchollet (Collaborator) commented Jul 6, 2017

At some point, we want to be able to call models' fit on data tensors.

But there are only two reasonable ways to make it happen:

    1. Graph rewiring: replacing the graph's placeholders with the data tensors for the duration of fit (then restoring them back)
    2. Rebuilding the same graph on top of the data tensors

The present proposal, at a high level, extracts batches from the data tensors (one memory copy) then feeds them to the placeholders. It's not as crazy as it sounds, but still very inefficient.

Importantly, this approach can be achieved with nearly no changes to the codebase (compared to the fairly involved changes proposed here) by simply structuring them as a Python generator that would be used with fit_generator. The generator would run the data tensors in a different session, and yield the resulting Numpy arrays -- that's it. This would be self-contained and would only involve a few lines of code.
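(A rough sketch of that generator idea, assuming TF 1.x queue runners produce the data tensors; the names are hypothetical.)

import tensorflow as tf


def tensor_batch_generator(x_tensor, y_tensor):
    # Run the data tensors in their own session and yield NumPy batches
    # that fit_generator can consume.
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                x_np, y_np = sess.run([x_tensor, y_tensor])
                yield x_np, y_np
        finally:
            coord.request_stop()
            coord.join(threads)

# model.fit_generator(tensor_batch_generator(x_batch, y_batch),
#                     steps_per_epoch=300, epochs=5)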

Overall, we should not proceed with this approach and we should keep exploring 1) and 2).

ahundt (Contributor) commented Jul 8, 2017

ii. Rebuilding the same graph on top of the data tensors

@fchollet This is what is attempted in the current approach of #6928 after the latest commits (with a more atomic merge process outlined in #7072 (comment)).
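(Very roughly, "rebuilding the graph on top of the data tensors" looks like the sketch below, assuming a Keras version where Input accepts an external tensor; the details in #6928 differ.)

from keras.layers import Input, Flatten, Dense
from keras.models import Model

# x_batch is the image tensor coming out of the TFRecord queue.
inputs = Input(tensor=x_batch)    # the model input is wired to the data tensor
x = Flatten()(inputs)
x = Dense(128, activation='relu')(x)
outputs = Dense(10, activation='softmax')(x)
model = Model(inputs, outputs)
# The label tensor (y_batch) likewise has to be wired into the loss rather
# than fed in as a NumPy array.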

colinskow (Author):

Thanks for the feedback, @fchollet. I agree with your points. Extracting the feed data inside a fit_generator is a great idea; I didn't think of that. Although the ideal solution is definitely wiring the tensors directly into the graph.

For my personal projects I'm looking into plugging directly into the TensorFlow estimator and experiment APIs since these are built for easily taking models into distributed production.

I'll go ahead and close this so I'm not cluttering up your pending PRs.

colinskow closed this Jul 10, 2017
ahundt (Contributor) commented Jul 10, 2017

@colinskow While you're exploring, could you consider creating an MNIST example (or any other kind) with Keras models, an Estimator & Experiment, and submitting a PR here or directly in TF? I think that would help many users!
