
[API DESIGN REVIEW] Keras Input Tensor API #7102

Closed
ahundt opened this issue Jun 23, 2017 · 7 comments
@ahundt
Contributor

ahundt commented Jun 23, 2017

Keras Input Tensor API Design Proposal

Executive Summary

Feeding data into a Model directly via input tensors will make training run faster, enable additional dataset formats, improve usability, and reduce backend lock-in.

# Data Tensors, for example tf yield ops created by RecordInput,
# which supply dataset images and labels
x_train_batch, y_train_batch = read_and_decode_recordinput(...)
 
# batch tensors are added to an Input layer
# Perhaps this aspect of API usage can be improved?
x_train_input = Input(tensor=x_train_batch)
y_train_input = Input(tensor=y_train_batch)
 
# run training
train_model.fit(x_train_input, y_train_input)

I’d appreciate it if you make suggestions and give feedback! This is currently a draft, and comments can be made directly on the Google Doc containing the full proposal.
As this is the very first Keras API design review, please be kind. :-)

Thanks for considering my proposal, and thanks to those who become reviewers or contributors!

HELP WANTED

I need help from CNTK & Theano experts on how those backends might be affected by this proposal. Feel free to add comments directly that can be discussed and incorporated.

REVIEW THE REVIEW OR CREATE A PROPOSAL

Please also either give feedback on the process itself or make your own proposal via the Keras API Design Review Template.

p.s. plaintext link:
https://docs.google.com/document/d/1tf2Nl7wor8rmWPUoxfClLuPLQGqvZryegD7K7-1tTe8/edit?usp=sharing

@fchollet
Collaborator

fchollet commented Jun 23, 2017

Correct me if my summary is wrong, but it looks like the TL;DR is:

  • Ideally, we should support something like model.fit(x_data_tensor, y_data_tensor, epochs=10, steps_per_epoch=10000). However, that is not implementable today due to TF limitations.
  • In the meantime, we have two options: the existing one and the one you propose.

Existing option:

# tf yield ops that supply dataset images and labels
x_train_batch, y_train_batch = read_and_decode_recordinput(...)
 
# create a basic cnn
x_train_input = Input(tensor=x_train_batch)
x_train_out = cnn_layers(x_train_input)

model = Model(inputs=x_train_input, outputs=x_train_out)
loss = keras.losses.categorical_crossentropy(y_train_batch, x_train_out)
model.add_loss(loss)

model.compile(optimizer='rmsprop', loss=None)

Proposed option:

# tf yield ops that supply dataset images and labels
x_train_batch, y_train_batch = read_and_decode_recordinput(...)
 
# create a basic cnn
x_train_input = Input(tensor=x_train_batch)
x_train_out = cnn_layers(x_train_input)
 
# batch tensors are added to an Input layer
# Perhaps this aspect of API usage can be improved?
y_train_input = Input(tensor=y_train_batch)
 
# This is where the label op is supplied so a placeholder isn't created
x_out = Target(y_train_input)(x_train_out)

model = Model(inputs=x_train_input, outputs=x_out)
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

I'll let others comment on pros/cons...

@fchollet
Collaborator

Unless you convince me otherwise, my take on it is that we should:

  • stick with the existing option for now
  • implement the ideal option in the future, when it becomes possible
  • improve existing examples and documentation for people who want to use data tensors
    • and we should point them to using an external TF training loop until we have the ideal option available, mostly because fit(None, None) is not pretty. (A rough sketch of such a loop follows below.)
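
For illustration, a rough sketch of such an external TF training loop, assuming x_train_batch / y_train_batch come from read_and_decode_recordinput as above, cnn_layers builds the network, and num_steps is a hypothetical step count; this is a sketch, not a finalized API:

# Rough sketch of an external TF training loop with Keras layers.
import tensorflow as tf
import keras
from keras import backend as K
from keras.layers import Input

# tf yield ops that supply dataset images and labels
x_train_batch, y_train_batch = read_and_decode_recordinput(...)

x_train_input = Input(tensor=x_train_batch)
x_train_out = cnn_layers(x_train_input)

# Build the loss and train op directly against the label tensor.
loss = tf.reduce_mean(
    keras.losses.categorical_crossentropy(y_train_batch, x_train_out))
train_op = tf.train.RMSPropOptimizer(1e-3).minimize(loss)

sess = K.get_session()
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for step in range(num_steps):
    _, loss_value = sess.run([train_op, loss],
                             feed_dict={K.learning_phase(): 1})

coord.request_stop()
coord.join(threads)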

@ahundt
Contributor Author

ahundt commented Jun 24, 2017

The above works for me! However, I have one last idea... What about re-compiling on demand when tensors are passed to y in model.fit()? I think this is much nicer than all my previous suggestions.

API usage, with working mnist_tfrecord.py implementation:

# tf yield ops that supply dataset images and labels
x_train_batch, y_train_batch = read_and_decode_recordinput(...)

x_train_in = Input(tensor=x_train_batch)
x_train_out = cnn_layers(x_train_in)
y_train_in = Input(tensor=y_train_batch, name='y_labels')
train_model = Model(inputs=[x_train_in], outputs=[x_train_out])
train_model.compile(optimizer='rmsprop',
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])
train_model.fit(None, y_train_in, 
                batch_size=batch_size,
                epochs=epochs)

Working Model() implementation:

class Model(Container):
    # In this version, `compile()` always saves the arguments in case
    # they're needed later. I think the overhead should be acceptable:
    def compile(self, *args, **kwargs):
        # ...snip docs...
        self._saved_kwargs = kwargs
        self._saved_args = args
        # ...snip rest of compile()...

    
    # If any fit() parameters are tensors, we dynamically recompile:
    def fit(self, x=None, y=None, *args, **kwargs):
        # ...snip docs and legacy support check...

        # Proof of concept, will need some tweaks
        # expect_other_types=True means is_keras_tensor
        # will return False on a numpy array, not throw an exception.
        if K.is_keras_tensor(y, expect_other_types=True):
            self.target_configuration = [y]
            y = None
            self._compile(*self._saved_args, **self._saved_kwargs)
        elif y is not None:
            recompile = False
            self.target_configuration = []
            for i, yi in enumerate(y):
                if K.is_keras_tensor(yi, expect_other_types=True):
                    self.target_configuration.append(yi)
                    y[i] = None
                    recompile = True
                else:
                    self.target_configuration.append(None)

            if recompile:
                self._compile(*self._saved_args, **self._saved_kwargs)
        # ...snip rest of fit()...

What do you think?

I like this proposal much better than my previous ones, so I've merged it into my largest PR, #6928, replacing the previous behavior.

@TimZaman
Contributor

TimZaman commented Jul 16, 2017

I like the recompiling on-demand idea. I don't think tensorflow will have transparent graph surgery any time soon. Some questions:

  • Why is the x input to the fit method still None (i.e. fit(None, ...))?
  • Why would you need to pass the tensor to Input(tensor=$)? I guess saying Input(shape=$) is fine, as it creates a placeholder that you'll later recompile.
  • How do we handle other methods like predict() that also take in values (like numpy arrays)? Those would need recompilation to a placeholder, right? @ahundt I ran some training on your tfrecords branch, which went great, but if I do a predict after that, it's still hooked up to the training input and just ignores whatever array arg I pass into predict.

@ahundt
Contributor Author

ahundt commented Jul 16, 2017

@TimZaman The current implementation is the intermediate step where y can be specified at fit time; full x, y support is not yet possible, since it requires transparent graph editing. Item 2 is actually the answer to item 1: the x input is defined by the input tensor, so there is no reason to define it again. Additionally, in this case there are no placeholders, because a placeholder requires going through Python, with the implied performance hit.

Item 3 is due to TensorFlow's design. Since the input tensor is how input is supplied, you must save the current graph weights, then create a new graph with different input tensors (or with placeholders if you go with pure Python), and load the saved weights before making predictions.
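
A rough sketch of that save / rebuild / reload workflow, assuming cnn_layers is the same layer-building function used for training, input_shape is the per-sample shape, and x_test is an ordinary numpy array (hypothetical names, for illustration only):

from keras.layers import Input
from keras.models import Model

# 1. Save the weights learned by the tensor-fed training model.
train_model.save_weights('trained_weights.h5')

# 2. Rebuild the same architecture on a placeholder-backed Input
#    (no tensor= argument), which accepts numpy arrays.
x_pred_in = Input(shape=input_shape)
x_pred_out = cnn_layers(x_pred_in)
pred_model = Model(inputs=x_pred_in, outputs=x_pred_out)

# 3. Load the saved weights and predict on numpy data.
pred_model.load_weights('trained_weights.h5')
predictions = pred_model.predict(x_test)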

Eventually, when graph editing is implemented with a clear and usable API, the placeholders could be replaced with input tensors after being created. However, since creating an input layer without a tensor parameter currently creates a placeholder, we must supply the input tensor when the input layer is created.

Was my explanation clear?

@ahundt
Contributor Author

ahundt commented Aug 5, 2017

Update: #7113 has a proper input tensor example with external loss.

@ahundt
Contributor Author

ahundt commented Aug 23, 2017

Closing in favor of #7503, which has the finalized action plan.

ahundt closed this as completed Aug 23, 2017