Using data tensors as data sources: action plan #7503

Closed
fchollet opened this issue Aug 2, 2017 · 22 comments

Comments

@fchollet
Collaborator

fchollet commented Aug 2, 2017

We want to add the ability to feed TensorFlow data tensors (e.g. input queues) into Keras models. A few days ago I met with @athundt and we discussed his previous efforts to make it happen. Here is how we will handle it:

First step [Update: done]

The following API:

# Get data tensors
data_tensor, target_tensor = ...

# Build model on top of the data tensor
inputs = Input(tensor=data_tensor)
outputs = Dense(...)(inputs)
model = Model(inputs, outputs)

# Add internal loss
loss = loss_fn(target_tensor, outputs)
model.add_loss(loss)

# Compile without external loss
model.compile(optimizer='sgd', loss=None)

# Fit without external data
model.fit(epochs=10, steps_per_epoch=1000)

This is already 90% supported. What is missing is the steps_per_epoch argument (currently fit would only draw a single batch, so you would have to use it in a loop).

NEEDED:

  • [Update: done] PR introducing the steps_per_epoch argument in fit. Here's how it works:
    • Based on arguments received, we determine whether training should be step-based (like in fit_generator) or sample-based (like in fit currently).
    • We have two independent code branches handling each mode.
  • [Update: done] PR introducing an MNIST example of how to use data tensors for inputs and targets, following the code snippet above. It should use the MNIST data tensors built into TF.
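To make the step-based vs. sample-based split concrete, here is a minimal sketch of the dispatch, assuming only that passing steps_per_epoch selects step-based training. The function name resolve_fit_mode and the error it raises are hypothetical illustrations, not Keras's actual internals.

```python
def resolve_fit_mode(x=None, steps_per_epoch=None):
    """Hypothetical sketch of the fit() dispatch.

    Returns 'step' when fit() should run a fixed number of batches per
    epoch (as fit_generator does), or 'sample' when it should iterate
    over the samples of in-memory arrays (the classic fit behavior).
    """
    if steps_per_epoch is not None:
        return 'step'
    if x is None:
        # Data comes from tensors wired into the model, so there are no
        # samples to count: the number of steps must be given explicitly.
        raise ValueError('steps_per_epoch is required when no input data '
                         'is passed to fit().')
    return 'sample'


print(resolve_fit_mode(steps_per_epoch=1000))  # step
print(resolve_fit_mode(x=[[0.0], [1.0]]))      # sample
```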

Second step

The following API:

# Get data tensors
data_tensor, target_tensor = ...

# Build model on top of the data tensor
inputs = Input(tensor=data_tensor)
outputs = Dense(...)(inputs)
model = Model(inputs, outputs)

# Compile as usual
model.compile(optimizer='sgd', loss='mse')

# Fit by passing the target tensor
model.fit(y=target_tensor, epochs=10, steps_per_epoch=1000)

Main issue: in compile, we create placeholders for the targets. We need to discard them (cache them, actually) and use the provided target tensor instead.

Solution: a model recompilation step inside fit in order to cache the previous target placeholder and replace it with our target tensor.
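A toy sketch of the caching idea, assuming the recompilation keeps the original placeholders so they can be restored later (the plan above only says they are cached). ToyCompiledModel and its methods are hypothetical, and strings stand in for graph tensors; only the bookkeeping of target inputs is modeled.

```python
class ToyCompiledModel:
    """Hypothetical stand-in for a compiled model: tracks which object
    currently acts as the target input for each output's loss."""

    def __init__(self, output_names):
        # compile() creates one target placeholder per model output.
        self.targets = {name: 'placeholder/' + name for name in output_names}
        self._cached = None

    def recompile_with_targets(self, target_tensors):
        """Cache the current placeholders, then use the given target tensors."""
        if self._cached is None:
            self._cached = dict(self.targets)
        self.targets.update(target_tensors)

    def restore_placeholders(self):
        """Put the cached placeholders back (an assumed follow-up step)."""
        if self._cached is not None:
            self.targets = self._cached
            self._cached = None


model = ToyCompiledModel(['dense'])
model.recompile_with_targets({'dense': 'target_tensor'})
print(model.targets)  # {'dense': 'target_tensor'}
model.restore_placeholders()
print(model.targets)  # {'dense': 'placeholder/dense'}
```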

NEEDED:

  • PR adding support for a target tensor in the call to fit for a normally compiled model. Involves a recompilation step.

Third step

The following API:

# Get data tensors
data_tensor, target_tensor = ...

# Build model on top of placeholders
inputs = Input(shape=(...))
outputs = Dense(...)(inputs)
model = Model(inputs, outputs)

# Compile as usual
model.compile(optimizer='sgd', loss='mse')

# Fit by passing the data tensor and target tensor
model.fit(data_tensor, target_tensor, epochs=10, steps_per_epoch=1000)

It's not 100% clear at this point how we will handle it, but we will figure it out. Most likely this will involve building a new TF graph inside fit, running training with it, then transferring weight values back to the initial graph. I'll handle it.
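The weight-transfer part of that plan can be illustrated with a toy NumPy sketch, assuming a linear model whose weights are a list of arrays (mirroring what Keras's get_weights()/set_weights() exchange). train_on_copy is a hypothetical helper, and the copied arrays merely stand in for the separate graph built inside fit; this is not Keras code.

```python
import numpy as np


def train_on_copy(weights, x, y, lr=0.1, steps=500):
    """Hypothetical helper: fit y ~ x @ W + b by gradient descent on
    *copies* of the weights, leaving the caller's arrays untouched."""
    w, b = (v.copy() for v in weights)  # like model.get_weights()
    for _ in range(steps):
        err = x @ w + b - y                # residuals, shape (n, 1)
        w -= lr * x.T @ err / len(x)       # gradient of 0.5 * mean squared error
        b -= lr * err.mean(axis=0)
    return [w, b]


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3))
true_w = np.array([[1.0], [-2.0], [0.5]])
y = x @ true_w + 0.3

# Weights of the "original" model/graph.
original = [np.zeros((3, 1)), np.zeros(1)]

trained = train_on_copy(original, x, y)
assert np.all(original[0] == 0)  # training did not touch the original weights

# Transfer the trained values back in place, like model.set_weights(trained).
for dst, src in zip(original, trained):
    dst[...] = src
```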

CC: @athundt @Dref360 @colinskow @TimZaman

@TimZaman
Contributor

TimZaman commented Aug 3, 2017

LGTM. @athundt, will you take the lead on steps (1) and (2)? It seems you've mostly nailed those already.

@Dref360
Contributor

Dref360 commented Aug 3, 2017

Step 3 seems really "hacky". Could we ask the TF team whether they are willing to support feeding placeholders with tensors?

For step 2: I was away for a while, so I haven't kept up with @athundt's PR. But since data_tensor is already there, I see no problem with doing: model.compile(y=target_tensor, optimizer='sgd', loss='mse').

That would save one compilation. If you've already discussed this in the PR, ignore this.

@TimZaman
Contributor

TimZaman commented Aug 4, 2017

@Dref360

Step 3 seems really "hacky".

Yes, it's a bit dirty. But I think Keras's APIs make it quite easy to clean up the graph-surgery mess: it's a hack in principle, but a great one. We'll see when we get there.

Could we ask the TF team if they are willing to handle feeding placeholder with Tensors?

We did; issue: tensorflow/tensorflow#10837

model.compile(y=target_tensor, optimizer='sgd', loss='mse').

At first glance, that sounds pretty sane to me! I don't recall anyone suggesting this before.

@ahundt
Contributor

ahundt commented Aug 23, 2017

Sorry guys, I didn't see this until now because @ahundt is the account I actually use. I'm not sure I have access to the other one any more.

@Dref360 I submitted a feature request to TensorFlow a couple of months ago: tensorflow/tensorflow#10837.

The graph-editing PR #7505 might be a good way to implement the underlying functionality for API 3.

@PBehr

PBehr commented Sep 20, 2017

Steps 2 and 3 will lead to issues with distributed training. Distributed TensorFlow finalizes the graph, so we get an error if we try to recompile the model. See #3997 for reference.

@fchollet
Collaborator Author

fchollet commented Sep 20, 2017 via email

@ahundt
Contributor

ahundt commented Oct 1, 2017

How should we handle validation data? When a model uses input tensors, the data being loaded is predefined, so the model likely needs to be instantiated a second time, or something like #7505 would be needed to reconnect the input tensors.

Thoughts?

@fchollet
Collaborator Author

fchollet commented Nov 20, 2017 via email

@ahundt
Contributor

ahundt commented Nov 21, 2017

Cool, thanks! I saw that the TF + Keras estimator API is out with 1.4; perhaps there is an example somewhere?

@ahundt
Contributor

ahundt commented Jan 16, 2018

I found an example of estimators in Horovod, and it seems that to convert a Keras model to a TF estimator you use model_to_estimator.

@R-Miner

R-Miner commented May 8, 2018

Is there a fix yet for calling fit/evaluate/predict directly on data tensors for a model built on top of placeholders?

@sekharvth

sekharvth commented May 23, 2018

@R-Miner I tried the third step, where tensors are passed directly as input to the model, but it threw an AttributeError saying that 'Tensor' object has no attribute 'ndim'. I'm running Keras 2.1.6 on top of TensorFlow 1.8.
@fchollet said that the issue would most likely be resolved by TF 1.6, but since there haven't been any further updates on this thread, I'm not sure whether step 3 has been implemented.
It would be great to get an update on this, @ahundt.

UPDATE - I got the following dummy code to work:

import numpy as np
import tensorflow as tf
from keras.layers import Input, Dense, Lambda
from keras.models import Model

# 'embedding' is an existing tensor of shape (num_examples, 512), built elsewhere.
with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  sess.run(tf.global_variables_initializer())
  inp = Input(tensor=embedding)
  # Cast to float32 so the Dense layer's MatMul sees matching dtypes.
  inp1 = Lambda(lambda x: tf.cast(x, tf.float32))(inp)
  dense = Dense(1, activation='sigmoid')(inp1)

  model = Model(inp, dense)

  model.compile(loss='binary_crossentropy', metrics=['accuracy'], optimizer='adam')

  model.fit(embedding, np.array([5]), epochs=10)

The casting to float operation is done to avoid conflicting datatypes in the Matmul operation of the Dense layer. 'embedding' is a tensor of shape (num_examples, 512).

But it still doesn't support a multi-input model where one input is a tensor and the other is an array. It then throws the same error shown earlier ('Tensor' object has no attribute 'ndim').

So it apparently works with tensor-only inputs, but doesn't yet support mixing input types. Is there a temporary workaround for this problem?

@psoulos

psoulos commented Jun 15, 2018

Is there a way to save and load models that use data tensors as data sources? I am able to create the original model and save it, but I'm not sure how to load the model. If I call load_model(), how do I correctly specify the input tensor? I found this Stack Overflow answer for replacing the input tensor, but it creates a dangling input, which prevents me from saving the model.

@TimZaman
Contributor

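@psoulos You cannot. A loaded model is, unfortunately, always built on top of placeholders. The only thing you can do is wrap it on top of your tensor:

from keras.layers import Input
from keras.models import Model, load_model

x = ...  # your input tensor
m1 = load_model(model_path)  # model_path: path to your saved model file
inputs = Input(tensor=x)  # wrap the raw tensor so it can serve as a model input
m2 = Model(inputs=inputs, outputs=m1(inputs))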

@psoulos

psoulos commented Jun 15, 2018

@TimZaman Will that allow me to continue training without losing the state of my optimizer and configuration? Currently I'm re-creating the model architecture and calling model.load_weights but this makes it difficult to continue training.

@was84san

was84san commented Jun 19, 2018

I tried the third step, but then I get this error:
"When feeding symbolic tensors to a model, we expect the tensors to have a static batch size. Got tensor with shape: (None, 32, 64, 64, 3)"

I used the following strategy to fit the model:

import tensorflow as tf
from keras.optimizers import SGD

training_filenames = [.....]
dataset = tf.data.TFRecordDataset(training_filenames)
dataset = dataset.map(_parse_function_all)  # Parse the record into tensors.

dataset = dataset.batch(20)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# videos = next_element[0], labels = next_element[1]

# Each example is a pair: I use the first element for training and the second for validation
# train_videos = next_element[0][:, 0], val_videos = next_element[0][:, 1]
# (same for the labels)

model = create_base_network()
# Input shape of the model above: (None, 32, 64, 64, 3)
# Output shape of the model above: (None, 10)

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)

model.fit(next_element[0][:, 0], next_element[1][:, 0],
          validation_data=(next_element[0][:, 1], next_element[1][:, 1]),
          epochs=10, steps_per_epoch=1000)

Can anyone tell me why I get this error?
Does the third step (fitting the model directly on data tensors) work in Keras now, or are there still issues?

@nmiculinic

Hmmm... how is this supposed to work with a validation dataset? Is it possible to inject both via these APIs, or do I have to resort to TF magic?

@dillondaudert

I wanted to leave a comment here so others could see: as of TensorFlow 1.9, the tf.keras package supports using tf.data.Datasets and tf.data.Iterators as inputs to Model.fit()/evaluate()/predict(). See the documentation here.

For instance, this works as of tf1.9:

import tensorflow as tf
import numpy as np
from tensorflow import keras

# float32 to match the default dtype of the Keras layers.
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(5)

x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)

model = keras.Model(x, y)
model.compile(loss='mse', optimizer='rmsprop')

model.fit(dataset, epochs=1, steps_per_epoch=2, validation_data=dataset, validation_steps=2)

I'm not sure what the exact differences between keras-team/keras and tensorflow/keras are at this point, but it seems that tf.data.Dataset support is further along in the latter.

@was84san

@dillondaudert So that means I can't use it with TensorFlow 1.8?

@lminer

lminer commented Jul 11, 2018

@was84san It seems to work if you call .set_shape((YOUR SHAPE INCLUDING BATCH SIZE)) on the tensors you get from .get_next().

Edit: Setting drop_remainder=True in the batch() method seems to work even better.
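A quick way to see why drop_remainder helps: unless the dataset size is a multiple of the batch size, the last batch is smaller, so the batch dimension cannot be fixed ahead of time and shows up as None in the static shape. The sketch below only reproduces that arithmetic in plain Python; batch_sizes is a hypothetical helper, not a tf.data function.

```python
def batch_sizes(num_examples, batch_size, drop_remainder=False):
    """Return the size of each batch produced from num_examples items."""
    full_batches, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder and not drop_remainder:
        # The final partial batch is what forces a dynamic batch dimension.
        sizes.append(remainder)
    return sizes


print(batch_sizes(45, 20))                       # [20, 20, 5]: sizes vary
print(batch_sizes(45, 20, drop_remainder=True))  # [20, 20]: every batch has 20
```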

@was84san

@lminer I did that and still get this error:
AttributeError: "'Tensor' object has no attribute 'ndim'"

@jandono

jandono commented Nov 14, 2018

What is the current support for Model.predict(some_data) if I have a hard-wired tf.data.Dataset iterator as the input tensor to my model? Namely, I have something similar to the following:

from keras.layers import Input, Dense
from keras.models import Model

# dataset = Some tf.data.Dataset
dataset_iterator = dataset.make_one_shot_iterator()
input_tensor_x, input_tensor_y = dataset_iterator.get_next()
inputs = Input(tensor=input_tensor_x)
outputs = Dense(10)(inputs)
model = Model(inputs=[inputs], outputs=[outputs])
model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'],
        target_tensors=[input_tensor_y]
)

How can I call model.predict(data_to_be_predicted) on such a model?
