Skip to content

Commit

Permalink
(keras) Make image argument required
Browse files Browse the repository at this point in the history
  • Loading branch information
teabolt committed Aug 5, 2019
1 parent dc0dc18 commit 0503324
Show file tree
Hide file tree
Showing 35 changed files with 195 additions and 107 deletions.
90 changes: 57 additions & 33 deletions docs/source/_notebooks/keras-image-classifiers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ Loading our sample image:
# we start from a path / URI.
# If you already have an image loaded, follow the subsequent steps
image = 'imagenet-samples/cat_dog.jpg'
image_uri = 'imagenet-samples/cat_dog.jpg'
# this is the original "cat dog" image used in the Grad-CAM paper
# check the image with Pillow
im = Image.open(image)
im = Image.open(image_uri)
print(type(im))
display(im)
Expand All @@ -96,14 +96,14 @@ dimensions! Let's resize it:
# we could resize the image manually
# but instead let's use a utility function from `keras.preprocessing`
# we pass the required dimensions as a (height, width) tuple
im = keras.preprocessing.image.load_img(image, target_size=dims) # -> PIL image
print(type(im))
im = keras.preprocessing.image.load_img(image_uri, target_size=dims) # -> PIL image
print(im)
display(im)
.. parsed-literal::
<class 'PIL.Image.Image'>
<PIL.Image.Image image mode=RGB size=224x224 at 0x7FD4FC485DD8>
Expand Down Expand Up @@ -143,7 +143,6 @@ Looking good. Now we need to convert the image to a numpy array.
.. code:: ipython3
# one last thing
# `keras.applications` models come with their own input preprocessing function
# for best results, apply that as well
Expand Down Expand Up @@ -171,6 +170,28 @@ inputting
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_13_0.png


One last thing, to explain image based models, we need to pass the image
as a PIL object explicitly. However, it must have mode 'RGBA'

.. code:: ipython3
print(im) # current mode
image = im.convert(mode='RGBA') # add alpha channel
print(image)
display(image)
.. parsed-literal::
<PIL.Image.Image image mode=RGB size=224x224 at 0x7FD4FC485DD8>
<PIL.Image.Image image mode=RGBA size=224x224 at 0x7FD4DB62EF28>
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_15_1.png


Ready to go!

2. Explaining our model's prediction
Expand Down Expand Up @@ -218,12 +239,15 @@ for a dog with ELI5:

.. code:: ipython3
eli5.show_prediction(model, doc)
# we need to pass the network
# the input as a numpy array
# and the corresponding input image (RGBA mode)
eli5.show_prediction(model, doc, image=image)
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_19_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_21_0.png



Expand All @@ -241,12 +265,12 @@ classifier looks to find those objects.
.. code:: ipython3
cat_idx = 282 # ImageNet ID for "tiger_cat" class, because we have a cat in the picture
eli5.show_prediction(model, doc, targets=[cat_idx]) # pass the class id
eli5.show_prediction(model, doc, image=image, targets=[cat_idx]) # pass the class id
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_22_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_24_0.png



Expand All @@ -259,16 +283,16 @@ Currently only one class can be explained at a time.
window_idx = 904 # 'window screen'
turtle_idx = 35 # 'mud turtle', some nonsense
display(eli5.show_prediction(model, doc, targets=[window_idx]))
display(eli5.show_prediction(model, doc, targets=[turtle_idx]))
display(eli5.show_prediction(model, doc, image=image, targets=[window_idx]))
display(eli5.show_prediction(model, doc, image=image, targets=[turtle_idx]))
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_24_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_26_0.png



.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_24_1.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_26_1.png


That's quite noisy! Perhaps the model is weak at classifying 'window
Expand Down Expand Up @@ -345,7 +369,7 @@ Rough print but okay. Let's pick a few convolutional layers that are
for l in ['block_2_expand', 'block_9_expand', 'Conv_1']:
print(l)
display(eli5.show_prediction(model, doc, layer=l)) # we pass the layer as an argument
display(eli5.show_prediction(model, doc, image=image, layer=l)) # we pass the layer as an argument
.. parsed-literal::
Expand All @@ -354,7 +378,7 @@ Rough print but okay. Let's pick a few convolutional layers that are
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_29_1.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_31_1.png


.. parsed-literal::
Expand All @@ -363,7 +387,7 @@ Rough print but okay. Let's pick a few convolutional layers that are
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_29_3.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_31_3.png


.. parsed-literal::
Expand All @@ -372,7 +396,7 @@ Rough print but okay. Let's pick a few convolutional layers that are
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_29_5.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_31_5.png


These results should make intuitive sense for Convolutional Neural
Expand All @@ -393,7 +417,7 @@ better understand what is going on.

.. code:: ipython3
expl = eli5.explain_prediction(model, doc)
expl = eli5.explain_prediction(model, doc, image=image)
Examining the structure of the ``Explanation`` object:

Expand All @@ -417,7 +441,7 @@ Examining the structure of the ``Explanation`` object:
[0. , 0. , 0. , 0. , 0. ,
0. , 0.05308531],
[0. , 0. , 0. , 0. , 0. ,
0.01124764, 0.06864655]]))], feature_importances=None, decision_tree=None, highlight_spaces=None, transition_features=None, image=<PIL.Image.Image image mode=RGBA size=224x224 at 0x7FCA6FD17CC0>)
0.01124764, 0.06864655]]))], feature_importances=None, decision_tree=None, highlight_spaces=None, transition_features=None, image=<PIL.Image.Image image mode=RGBA size=224x224 at 0x7FD4DB62EF28>)
We can check the score (raw value) or probability (normalized score) of
Expand Down Expand Up @@ -446,7 +470,7 @@ We can also access the original image and the Grad-CAM heatmap:
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_39_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_41_0.png


.. parsed-literal::
Expand Down Expand Up @@ -476,7 +500,7 @@ Visualizing the heatmap:
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_41_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_43_0.png


That's only 7x7! This is the spatial dimensions of the
Expand All @@ -494,7 +518,7 @@ resampling):
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_43_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_45_0.png


Now it's clear what is being highlighted. We just need to apply some
Expand All @@ -508,7 +532,7 @@ colors and overlay the heatmap over the original image, exactly what
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_45_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_47_0.png


6. Extra arguments to ``format_as_image()``
Expand All @@ -525,7 +549,7 @@ colors and overlay the heatmap over the original image, exactly what
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_48_0.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_50_0.png


The ``alpha_limit`` argument controls the maximum opacity that the
Expand Down Expand Up @@ -554,7 +578,7 @@ check the explanation:
# first check the explanation *with* softmax
print('with softmax')
display(eli5.show_prediction(model, doc))
display(eli5.show_prediction(model, doc, image=image))
# remove softmax
Expand All @@ -566,7 +590,7 @@ check the explanation:
model = keras.models.load_model('tmp_model_save_rmsoftmax')
print('without softmax')
display(eli5.show_prediction(model, doc))
display(eli5.show_prediction(model, doc, image=image))
.. parsed-literal::
Expand All @@ -575,7 +599,7 @@ check the explanation:
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_51_1.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_53_1.png


.. parsed-literal::
Expand All @@ -584,7 +608,7 @@ check the explanation:
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_51_3.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_53_3.png


We see some slight differences. The activations are brighter. Do
Expand All @@ -610,9 +634,9 @@ loading another model and explaining a classification of the same image:
nasnet.preprocess_input(doc2)
print(model.name)
display(eli5.show_prediction(model, doc))
display(eli5.show_prediction(model, doc, image=image))
print(model2.name)
display(eli5.show_prediction(model2, doc2))
display(eli5.show_prediction(model2, doc2, image=image))
.. parsed-literal::
Expand All @@ -621,7 +645,7 @@ loading another model and explaining a classification of the same image:
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_54_1.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_56_1.png


.. parsed-literal::
Expand All @@ -630,7 +654,7 @@ loading another model and explaining a classification of the same image:
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_54_3.png
.. image:: ../_notebooks/keras-image-classifiers_files/keras-image-classifiers_56_3.png


Wow ``show_prediction()`` is so robust!
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Diff not rendered.
Diff not rendered.
6 changes: 6 additions & 0 deletions docs/source/libraries/keras.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ Important arguments to :func:`eli5.explain_prediction` for ``Model`` and ``Seque

- Check ``model.input_shape`` to confirm the required dimensions of the input tensor.

* ``image`` Pillow image, corresponds to doc input.

- **Must be passed for image explanations.**

- **Must have mode "RGBA".**

* ``target_names`` are the names of the output classes.

- *Currently not implemented*.
Expand Down
22 changes: 16 additions & 6 deletions eli5/keras/explain_prediction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from typing import Union, Optional, Callable, Tuple, List
from typing import Union, Optional, Callable, Tuple, List, TYPE_CHECKING
if TYPE_CHECKING:
import PIL # type: ignore

import numpy as np # type: ignore
import keras # type: ignore
Expand All @@ -22,10 +24,11 @@
@explain_prediction.register(Model)
def explain_prediction_keras(estimator, # type: Model
doc, # type: np.ndarray
image=None, # type: Optional['PIL.Image.Image']
target_names=None,
targets=None, # type: Optional[list]
layer=None, # type: Optional[Union[int, str, Layer]]
):
):
# type: (...) -> Explanation
"""
Explain the prediction of a Keras image classifier.
Expand Down Expand Up @@ -63,6 +66,16 @@ def explain_prediction_keras(estimator, # type: Model
:raises TypeError: if ``doc`` is not a numpy array.
:raises ValueError: if ``doc`` shape does not match.
:param image:
Pillow image over which to overlay the heatmap.
Corresponds to the input ``doc``.
Must have mode 'RGBA'.
:type image: PIL.Image.Image, optional
:param target_names:
*Not Implemented*.
Names for classes in the final output layer.
Expand Down Expand Up @@ -108,6 +121,7 @@ def explain_prediction_keras(estimator, # type: Model
* ``target`` ID of target class.
* ``score`` value for predicted class.
"""
assert image is not None
_validate_doc(estimator, doc)
activation_layer = _get_activation_layer(estimator, layer)

Expand All @@ -119,10 +133,6 @@ def explain_prediction_keras(estimator, # type: Model
weights, activations, grads, predicted_idx, predicted_val = values
heatmap = gradcam(weights, activations)

doc, = doc # rank 4 batch -> rank 3 single image
image = keras.preprocessing.image.array_to_img(doc) # -> RGB Pillow image
image = image.convert(mode='RGBA')

return Explanation(
estimator.name,
description=DESCRIPTION_KERAS,
Expand Down
157 changes: 99 additions & 58 deletions notebooks/keras-image-classifiers.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def test_get_target_prediction_invalid(simple_seq):


def test_explain_prediction_score(simple_seq):
expl = explain_prediction(simple_seq, np.zeros((1, 32, 32, 1)))
expl = explain_prediction(simple_seq,
np.zeros((1, 32, 32, 1)),
image=True) # TODO: dummy image
assert expl.targets[0].score is not None
assert expl.targets[0].proba is None

Expand Down
Loading

0 comments on commit 0503324

Please sign in to comment.