
model.fit(steps_per_epoch), mnist_tfrecord.py, progbar np.mean #7113

Merged: 33 commits into keras-team:master, Aug 11, 2017

Conversation

@ahundt (Contributor) commented Jun 24, 2017

Implements external-loss-based support for model.fit(steps_per_epoch) with tensor input, plus a corresponding TFRecord example.

# API 2

model = # on top of the tensor input
model.add_loss()  # involving y_tensor

model.fit(epochs=10, steps_per_epoch=1000)
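
For context, here is a minimal end-to-end sketch of API 2 in the style of mnist_tfrecord.py. The random stand-in pipeline and tensor names are illustrative assumptions, not the PR's exact code:

    import tensorflow as tf
    from keras import backend as K
    from keras.layers import Input, Flatten, Dense
    from keras.models import Model
    from keras.losses import categorical_crossentropy

    # Stand-ins for symbolic tensors coming from a real TFRecord pipeline
    # (e.g. tf.train.shuffle_batch over parsed records):
    x_train_batch = tf.random_uniform((128, 28, 28, 1))
    y_train_batch = tf.one_hot(
        tf.random_uniform((128,), maxval=10, dtype=tf.int32), 10)

    x = Input(tensor=x_train_batch)   # model built on top of the tensor input
    out = Dense(10, activation='softmax')(Flatten()(x))
    model = Model(inputs=x, outputs=out)

    # External loss involving the label tensor, added before compilation.
    model.add_loss(K.mean(categorical_crossentropy(y_train_batch, out)))
    model.compile(optimizer='rmsprop', loss=None)  # losses come from add_loss

    model.fit(epochs=10, steps_per_epoch=1000)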

Runtime results when running mnist_tfrecord.py with Python TF tensors:

Test accuracy: 0.990999997973
python mnist_tfrecord.py  57.21s

The progbar change is needed for mnist_tfrecord.py, where a step is a batch.

if ins and hasattr(ins[0], 'shape'):
    num_train_samples = ins[0].shape[0]
if steps_per_epoch is not None:
    num_train_samples = steps_per_epoch
Collaborator:

num_train_samples != steps_per_epoch, though. A step is a batch, not a single sample.

Contributor Author:

Ah yes, I forgot to divide by batch size. I'll do that when I'm at a computer.

@fchollet (Collaborator):

I think what we should do is:

  • if the number of samples can be recovered from the data, use that. If steps_per_epoch is set, raise a warning that it's ignored. If batch_size is not set, raise a ValueError.
  • if it can't be recovered, raise a ValueError in case steps_per_epoch is not set. If batch_size is set, raise a warning that it's ignored.

There are serious issues with blending step-based counting and sample-based counting (the last batch, variable batch sizes, etc.), so we can't mix them without complications.
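
A rough sketch of that dispatch logic (the helper name is hypothetical, not code from this PR):

    import warnings

    def _resolve_epoch_length(ins, batch_size=None, steps_per_epoch=None):
        # Sample-based counting: the data itself defines the epoch length.
        if ins and hasattr(ins[0], 'shape'):
            if steps_per_epoch is not None:
                warnings.warn('`steps_per_epoch` is ignored when the number '
                              'of samples can be inferred from the data.')
            if batch_size is None:
                raise ValueError('`batch_size` must be set when the number '
                                 'of samples can be inferred from the data.')
            return ins[0].shape[0], batch_size
        # Step-based counting: no sample count can be recovered.
        if steps_per_epoch is None:
            raise ValueError('`steps_per_epoch` must be set when the number '
                             'of samples cannot be inferred from the data.')
        if batch_size is not None:
            warnings.warn('`batch_size` is ignored when training by steps.')
        return steps_per_epoch, None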

@ahundt (Contributor Author) commented Jun 26, 2017

> if the number of samples can be recovered from the data, use that. If steps_per_epoch is set, raise a warning that it's ignored. If batch_size is not set, raise a ValueError.

I agree with your suggestion. I just made a small tweak, since the steps_per_epoch default is None and it is a new parameter: I provide the warning you suggest, but otherwise assume the user meant to override the defaults when they provide steps_per_epoch. I also added it to the other relevant functions.

@ahundt (Contributor Author) commented Jun 26, 2017

It seems necessary to allow both to be specified, because batch_size is not actually ignored: it is used in some of the totals and other calculations, so I'm not sure the warning is appropriate without more substantial internal changes. Example error:

___________________________ test_model_with_input_feed_tensor ___________________________
[gw0] darwin -- Python 2.7.13 /usr/local/opt/python/bin/python2.7
@pytest.mark.skipif(K.backend() != 'tensorflow', reason='Requires TF backend')
    @keras_test
    def test_model_with_input_feed_tensor():
        """We test building a model with a TF variable as input.
        We should be able to call fit, evaluate, predict,
        by only passing them data for the placeholder inputs
        in the model.
        """
        import tensorflow as tf

        input_a_np = np.random.random((10, 3))
        input_b_np = np.random.random((10, 3))

        output_a_np = np.random.random((10, 4))
        output_b_np = np.random.random((10, 3))

        a = Input(tensor=tf.Variable(input_a_np, dtype=tf.float32))
        b = Input(shape=(3,), name='input_b')

        a_2 = Dense(4, name='dense_1')(a)
        dp = Dropout(0.5, name='dropout')
        b_2 = dp(b)

        model = Model([a, b], [a_2, b_2])
        model.summary()

        optimizer = 'rmsprop'
        loss = 'mse'
        loss_weights = [1., 0.5]
        model.compile(optimizer, loss, metrics=['mean_squared_error'],
                      loss_weights=loss_weights,
                      sample_weight_mode=None)

        # test train_on_batch
        out = model.train_on_batch(input_b_np,
                                   [output_a_np, output_b_np])
        out = model.train_on_batch({'input_b': input_b_np},
                                   [output_a_np, output_b_np])
        out = model.test_on_batch({'input_b': input_b_np},
                                  [output_a_np, output_b_np])
        out = model.predict_on_batch({'input_b': input_b_np})

        # test fit
        out = model.fit({'input_b': input_b_np},
                        [output_a_np, output_b_np], epochs=1, batch_size=10)
        out = model.fit(input_b_np,
                        [output_a_np, output_b_np], epochs=1, batch_size=10)

        # test evaluate
        out = model.evaluate({'input_b': input_b_np},
                             [output_a_np, output_b_np], batch_size=10)
        out = model.evaluate(input_b_np,
                             [output_a_np, output_b_np], batch_size=10)

        # test predict
        out = model.predict({'input_b': input_b_np}, batch_size=10)
        out = model.predict(input_b_np, batch_size=10)
        assert len(out) == 2

        # Now test a model with a single input
        # i.e. we don't pass any data to fit the model.
        a = Input(tensor=tf.Variable(input_a_np, dtype=tf.float32))
        a_2 = Dense(4, name='dense_1')(a)
        a_2 = Dropout(0.5, name='dropout')(a_2)
        model = Model(a, a_2)
        model.summary()

        optimizer = 'rmsprop'
        loss = 'mse'
        model.compile(optimizer, loss, metrics=['mean_squared_error'])

        # test train_on_batch
        out = model.train_on_batch(None,
                                   output_a_np)
        out = model.train_on_batch(None,
                                   output_a_np)
        out = model.test_on_batch(None,
                                  output_a_np)
        out = model.predict_on_batch(None)
        out = model.train_on_batch([],
                                   output_a_np)
        out = model.train_on_batch({},
                                   output_a_np)

        # test fit
        out = model.fit(None,
                        output_a_np, epochs=1, batch_size=10)
        out = model.fit(None,
                        output_a_np, epochs=1, batch_size=10)

        # test evaluate
        out = model.evaluate(None,
                             output_a_np, batch_size=10)
        out = model.evaluate(None,
                             output_a_np, batch_size=10)

        # test predict
>       out = model.predict(None, batch_size=None, steps=1)

tests/keras/engine/test_training.py:519:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras/engine/training.py:1518: in predict
    verbose=verbose, steps=steps)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <keras.engine.training.Model object at 0x112d30f90>
f = <keras.backend.tensorflow_backend.Function object at 0x11306cd90>
ins = [0.0], batch_size = None, verbose = 0, steps = 1

    def _predict_loop(self, f, ins, batch_size=32, verbose=0, steps=None):
        """Abstract method to loop over some data in batches.

            # Arguments
                f: Keras function returning a list of tensors.
                ins: list of tensors to be fed to `f`.
                batch_size: integer batch size.
                verbose: verbosity mode.
                steps: Total number of steps (batches of samples)
                    before declaring _predict_loop finished.
                    Ignored with the default value of `None`.

            # Returns
                Array of predictions (if the model has a single output)
                or list of arrays of predictions
                (if the model has multiple outputs).
            """
        samples = self._check_num_samples(ins, batch_size, steps, 'steps')
        outs = []
        if verbose == 1:
            progbar = Progbar(target=samples)
        batches = _make_batches(samples, batch_size)
        index_array = np.arange(samples)
        for batch_index, (batch_start, batch_end) in enumerate(batches):
            batch_ids = index_array[batch_start:batch_end]
            if ins and isinstance(ins[-1], float):
                # Do not slice the training phase flag.
                ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
            else:
                ins_batch = _slice_arrays(ins, batch_ids)

            batch_outs = f(ins_batch)
            if not isinstance(batch_outs, list):
                batch_outs = [batch_outs]
            if batch_index == 0:
                for batch_out in batch_outs:
                    shape = (samples,) + batch_out.shape[1:]
                    outs.append(np.zeros(shape, dtype=batch_out.dtype))

            for i, batch_out in enumerate(batch_outs):
>               outs[i][batch_start:batch_end] = batch_out
E               ValueError: could not broadcast input array from shape (10,4) into shape (1,4)

@@ -962,10 +962,21 @@ def _make_predict_function(self):
                                               name='predict_function',
                                               **kwargs)

    def _check_num_samples(self, ins, batch_size=None, steps=None, steps_name='steps'):
        if steps is not None:
            num_samples = steps
@fchollet (Collaborator), Jun 29, 2017:

That should be, on average, steps * batch_size, not steps. However, since the data length may not be a multiple of batch_size, and because in the generator case there may be batches of different sizes, it is actually impossible to determine num_samples.

@ahundt (Contributor Author), Jun 29, 2017:

Perhaps something should be renamed in this case, or the documentation modified a bit, so that everything is consistent?

Also, I wasn't sure whether you're saying I should make it steps * batch_size here, or that that is also inaccurate and we need to do something else.

Collaborator:

We need to separate two use cases cleanly:

  • training for a specific number of steps (if the data has no length), like in fit_generator.
  • training for a specific number of samples (if the data has a length), like in current fit.

We can't unify both because no reliable conversion exists between the two.
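
A quick illustration of why no reliable conversion exists (illustrative numbers, not from the PR):

    # 1050 samples at batch_size=100 gives 11 batches, the last of size 50,
    # so steps * batch_size == 1100 != 1050. With a generator, the per-step
    # batch size may also vary, so no sample count is recoverable at all.
    samples, batch_size = 1050, 100
    steps = (samples + batch_size - 1) // batch_size  # 11 (last batch partial)
    assert steps * batch_size != samples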

Contributor Author:

Would the best approach then be to create a second code path in fit() that works more like fit_generator()? Or perhaps to create a new function like fit_tensor()?

Collaborator:

Good question. Can we make the expanded API work with fit without refactoring or adding new methods?

@ahundt (Contributor Author), Jul 8, 2017:

It should be pretty simple to do with the same methods, just by checking whether either x or y is a Keras tensor in fit(x, y, ...), so long as steps_per_epoch and epochs from fit_generator() are added to fit(x, y, ..., steps_per_epoch=None, epochs=None).

Unless you have another suggestion, I propose checking for Keras tensors because it is the simplest mechanism by which multiple platforms, not just TF, could eventually support the API. This means wrapping TF tensors with Input(), as in #6928, or see the more atomic PR merge order outlined in #7072 (comment).
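
A sketch of the kind of check meant here; the helper name and the reliance on the _keras_history attribute (which Keras attaches to tensors produced by its layers and Input()) are illustrative assumptions, not the PR's implementation:

    def _has_symbolic_input(x, y):
        # Collect x and y, whether given as single values or lists.
        args = []
        for arg in (x, y):
            if arg is None:
                continue
            args.extend(arg if isinstance(arg, (list, tuple)) else [arg])
        # Keras tensors carry `_keras_history`; numpy arrays do not.
        return any(hasattr(a, '_keras_history') for a in args)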

ahundt added 4 commits July 31, 2017 22:25
…s_per_epoch

* commit '01e2148732e4083b4850345e5ce4dd499cb5999e': (98 commits)
  Fix common backend styles (keras-team#7476)
  Allow dynamic shape for repeat_elements (keras-team#7422)
  Fix test style (keras-team#7473)
  Crossentropy backend API consistency cleanup (keras-team#7199)
  Replace the reserved word `input` with `inputs` (keras-team#7474)
  Increase test coverage (keras-team#7264)
  Fix l2_normalize
  Add default value for `l2_normalize`.
  Remove legacy axis handling in TF backend.
  Update save_model function (keras-team#7455)
  Switch to pydot 1.2.x (keras-team#7448)
  Add save, evaluate and predict functionality for example (keras-team#7430)
  Docs fix: `pointwise` instead of `depthwise` for SeparableConv2D (keras-team#7444)
  Fix conv reccurent test
  Style fix in conv recurrent tests.
  Support return_state parameter in ConvRecurrent2D (keras-team#7407)
  Small simplification in ResNet50 architecture
  Update FAQ with info about custom object loading.
  add example for passing in custom objects in load_model (keras-team#7420)
  Update applications.md (keras-team#7428)
  ...
* mnist_tfrecord:
  mnist_tfrecord.py
  loss defaults to None in compile()
  lower batch size and epochs
  mnist_tfrecord.py indentation fix
  mnist_tfrecord.py add parallelism option
  mnist_tfrecord.py pep8
  Fix mnist_tfrecord.py runtime errors
  mnist_tfrecord.py added (keras-team#7061, keras-team#7072, keras-team#6928, keras-team#7046)
@ahundt (Contributor Author) commented Aug 8, 2017

Large batch sizes can run in much less wall-clock time and still get better performance. For example, batch_size=600, steps_per_epoch=300, epochs=6 ran in 30% less time and predicted 0.1% more accurately. Not really relevant to this PR, so I just set it to approximately match mnist_cnn.py.

@ahundt (Contributor Author) commented Aug 8, 2017

@fchollet hopefully everything should be addressed now

ahundt added a commit to ahundt/keras that referenced this pull request Aug 8, 2017
* fit_steps_per_epoch:
  training.py fix unit test error for steps_per_epoch
  mnist_tfrecord.py and training.py fixed review comments, docs, and error messages (keras-team#7113)
  training.py fix test error
  mnist_tfrecord.py extended description
  mnist_tfrecord.py and training.py clean up based on review (keras-team#7113)

# Conflicts:
#	examples/mnist_tfrecord.py
                              initial_epoch=initial_epoch)
                              initial_epoch=initial_epoch,
                              steps_per_epoch=steps_per_epoch,
                              validation_steps=validation_steps)

    def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
Contributor:

Do you want to add the tensor-input support for evaluate immediately as well, or do that in another PR?

Collaborator:

Good point. Let's add data tensor support to evaluate and predict in this CL.

Contributor Author:

I'd prefer to hold off and get these changes in. We can cherry-pick changes from #6928 in a separate PR for that, if you think it is a good idea.

Contributor Author:

Ah, the most recent post didn't load; I had also changed my mind and added evaluate, since it was only a couple of lines. I'll look at predict too.

@@ -1513,8 +1611,8 @@ def predict(self, x, batch_size=32, verbose=0):
            ins = x
        self._make_predict_function()
        f = self.predict_function
        return self._predict_loop(f, ins,
                                  batch_size=batch_size, verbose=verbose)
        return self._predict_loop(f, ins, batch_size=batch_size,
Contributor:

leave as-is?

@@ -734,7 +739,7 @@ def test_model_with_external_loss():
    out = model.predict_on_batch(None)

    # test fit
    out = model.fit(None, None, epochs=1, batch_size=10)
    out = model.fit(None, None, epochs=1, batch_size=None, steps_per_epoch=1)
Contributor:

I verified. Looks great!

@TimZaman (Contributor) commented Aug 8, 2017

LGTM

@@ -1154,58 +1222,70 @@ def _predict_loop(self, f, ins, batch_size=32, verbose=0):
            return outs[0]
        return outs

    def _test_loop(self, f, ins, batch_size=32, verbose=0):
    def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):
        """Abstract method to loop over some data in batches.

        # Arguments
            f: Keras function returning a list of tensors.
            ins: list of tensors to be fed to `f`.
            batch_size: integer batch size.
Contributor:

or None

                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o
            else:
                if shuffle == 'batch':
Contributor:

Maybe add a warning that shuffle=True has no effect if steps_per_epoch is not None. (No test required, since pytest is flaky when catching warnings.)

Contributor Author:

Adding a comment to the docs instead; I'd rather not see that warning all the time by default.

                        epochs=1, batch_size=4, validation_split=0.5)
                        epochs=1, batch_size=None, validation_split=0.5,
                        steps_per_epoch=1)
        out = model.fit({'input_a': input_a_np, 'input_b': input_b_np},
Collaborator:

predict and evaluate with data tensors should have unit tests as well.


test_model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
test_model.summary()

loss, acc = test_model.evaluate(x_test, np_utils.to_categorical(y_test), classes)
Collaborator:

Could we use evaluate with a data tensor in this example?

@ahundt (Contributor Author), Aug 8, 2017:

I just realized we've started down the path of reimplementing #6928, which would require changes that separate numpy arrays from tensors. That PR allows the labels y to be passed as a tensor. For that reason I think we should revert support for evaluate(steps) and predict(steps), fix the docstring issues, and limit this PR to fit().

Full support can be integrated with what we called API 3 via #6928.

Collaborator:

(i.e. all we need is to add steps to these methods)

Contributor Author:

The simplicity of the test hides the lack of underlying implementation.

In this test there is no label y to evaluate against. You'll see in _test_loop() that the output of f(ins) is not saved in the steps case, which is OK for this PR as it was originally conceived.

This seemingly simple missing component actually requires some implementation to work. I'm not sure if supplying a Tensor as the parameter y will give results in the format described by # Returns. Perhaps simply retaining a list of the outputs is OK, as in _step_loop()? (The link is to the #6928 version of the changes I reverted here.)

I'm not sure about indexing correctly to calculate the summaries for each output in this case, e.g. something equivalent to outs[i][batch_start:batch_end] = batch_out in _predict_loop(); in _test_loop() there is outs[i] += batch_out * len(batch_ids) and then outs[i] /= samples.
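
For illustration, one possible step-based aggregation for _test_loop, sketching the open question rather than quoting this PR's code:

    # Accumulate a running sum of each output over `steps` calls, then take
    # an unweighted mean; without per-batch sizes there is no way to weight
    # by len(batch_ids) the way the sample-based path does.
    outs = []
    for step_num in range(steps):
        batch_outs = f(ins)
        if not isinstance(batch_outs, list):
            batch_outs = [batch_outs]
        if step_num == 0:
            outs = [0.] * len(batch_outs)
        for i, batch_out in enumerate(batch_outs):
            outs[i] += batch_out
    outs = [out / steps for out in outs]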

@ahundt force-pushed the fit_steps_per_epoch branch from 3e13d6c to 56736f3 on August 8, 2017 at 20:54
@ahundt (Contributor Author) commented Aug 8, 2017

I merged the changes to predict(steps) and evaluate(steps) into #6928 and reverted them here, because such a change will require adding support for an input tensor as a label, which is in scope for #6928 and out of scope for this PR.

@fchollet (Collaborator) commented Aug 8, 2017

I think it's preferable to have them here for API consistency. Like in fit, it just requires your model to be compiled with a None loss (and a call to add_loss before compilation).
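
Under that scheme the calls would look like this (a sketch; as discussed below, steps support for these methods was ultimately deferred to #6928):

    # Model compiled with loss=None plus an add_loss involving the labels:
    loss = model.evaluate(None, None, steps=100)
    preds = model.predict(None, steps=100)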

@ahundt (Contributor Author) left a comment:

Discussing the requirements of predict(steps) and evaluate(steps) below. Perhaps I misunderstood, and the accuracy is calculated somewhere in the mnist_tfrecord.py call to fit()?

I believe only the loss is calculated and no accuracy data is available.

            if verbose == 1:
                progbar = Progbar(target=steps)
            for step_num in range(steps):
                f(ins)
@ahundt (Contributor Author), Aug 8, 2017:

Nothing is saved out in this case. Can the results be collated in a predictable manner? I'm not yet sure whether the answer depends on the input tensor the user supplies, such as RecordInput vs. tf.train.shuffle_batch.


@ahundt (Contributor Author) commented Aug 9, 2017

@fchollet @TimZaman @Dref360 Considering the need for predict & evaluate, which then require labels from input tensors, could we simply continue the review over at #6928, or is there a better approach?

Everything from this PR #7113 is merged into #6928. I know #6928 is rather large, but it is the only PR that can provide a smooth transition in a single merge with decent overall API consistency. I've got everything working now in #6928 as follows:

# API 2

model = # on top of the tensor input
model.add_loss()  # involving y_tensor
model.fit(epochs=10, steps_per_epoch=1000)

# API 3

model = # on top of the tensor input
model.compile()
model.fit(y=y_tensor, epochs=10, steps_per_epoch=1000)

There are extended unit tests and support for Input Tensors with each of the following APIs:

    # train_on_batch
    out = model.train_on_batch(x, y)

    # test_on_batch
    out = model.test_on_batch(x, y)

    # predict_on_batch
    out = model.predict_on_batch(x)

    # fit
    out = model.fit(x, y, epochs=1, batch_size=batch_size,
                    steps_per_epoch=steps_per_epoch)

    # evaluate
    out = model.evaluate(x, y, batch_size=batch_size,
                         steps=steps_per_epoch)

    # predict
    out = model.predict(x, batch_size=batch_size,
                        steps=steps_per_epoch)

I've also incorporated and merged all feedback from the split atomic PRs into #6928, so it is the canonical version with the most complete functionality for the purposes discussed here, but I'm OK with an alternative solution that works as well. :-)

            steps: Total number of steps (batches of samples)
                before declaring `_predict_loop` finished.
                Ignored with the default value of `None`.
            steps_name: The public API's parameter name for `steps`.
Collaborator:

Docstring should have a Returns section and a Raises section.

            ins: list of tensors to be fed to the Keras function.
            batch_size: integer batch size or None if unknown.
            steps: Total number of steps (batches of samples)
                before declaring _predict_loop finished.
Collaborator:

Use code markers around code keywords (e.g. `_predict_loop`).

            if hasattr(ins[0], 'shape'):
                validation_steps = steps_per_epoch
            else:
                raise ValueError('When `steps_per_epoch` validation '
Collaborator:

"When using steps_per_epoch, ..."


        self.history = cbks.History()
        callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
        if verbose:
            callbacks += [cbks.ProgbarLogger()]
        if steps_per_epoch:
Collaborator:

Prefer `if steps_per_epoch is not None`.

@ahundt (Contributor Author) commented Aug 11, 2017

@fchollet Addressed the review items here.

ahundt added a commit to ahundt/keras that referenced this pull request Aug 11, 2017
* commit '41fdd5aa5fe080bf5fa12638b9abc9f50d835a39':
  training.py improve docstrings and error case
  fix docstring comments from review (keras-team#7113)
@fchollet (Collaborator) left a comment:

LGTM, thanks!

@fchollet fchollet merged commit d687c6e into keras-team:master Aug 11, 2017
@TimZaman (Contributor) commented Aug 11, 2017 via email

ahundt added a commit to ahundt/keras that referenced this pull request Aug 11, 2017
* master:
  model.fit(steps_per_epoch), 	mnist_tfrecord.py, progbar np.mean (keras-team#7113)
  Change default size to allow different MobileNet sizes (keras-team#7586)
  cntk backend: fix the reversed rnn bug (keras-team#7593)
  Fix mask for multi output --> multi inputs (keras-team#7591)
  [RELNOTES] Subtract merge layer (keras-team#7573)
  update docker with cntk 2.1 (keras-team#7565)
  [RELNOTES] Move constraint management to be based on variable attributes (like TF). (keras-team#7564)
  Add handling for `dtype` arg in initializer config.
  Fix keras-team#7550. (keras-team#7552)
  remove unnecessary function definition (keras-team#7542)
  refactor the function - _convert_string_dtype (keras-team#7540)
  Add batchnorm tests for CNTK (keras-team#7534)
  Clean up RNN tests (keras-team#7529)
  Add merge tests (keras-team#7533)

# Conflicts:
#	examples/mnist_tfrecord.py
#	keras/engine/training.py
#	keras/utils/generic_utils.py
#	tests/keras/engine/test_training.py
@ahundt (Contributor Author) commented Aug 11, 2017

Yay!
