
Integrated Gradient Explainer for multiple inputs #316

Closed
hugokitano opened this issue Nov 1, 2020 · 4 comments · Fixed by #321
Labels
Type: Method extension

Comments

hugokitano commented Nov 1, 2020
Hi, I'm trying to use alibi's integrated gradients tool with a Keras model that has multiple inputs. When I try to instantiate the IntegratedGradients object, I get an AttributeError: 'list' object has no attribute 'dtype'.

Basically, my model takes in a sequence, which is run through an RNN and concatenated with several other secondary features. Then, that concatenation is run through some dense layers and outputs a probability between 0 and 1.

# assumes: from tensorflow import keras
def single_lstm_model(num_lstm_units=16):
    seq = keras.Input(shape=(None, 7))   # variable-length sequence, 7 features per step
    other = keras.Input(shape=(7,))      # secondary features; shape should be a tuple

    x = keras.layers.Bidirectional(keras.layers.LSTM(num_lstm_units))(seq)

    x = keras.layers.Concatenate()([x, other])
    x = keras.layers.Dense(16)(x)

    for _ in range(2):
        x = recurrent_dense(x, 16)  # defined elsewhere: dense layers + batch normalization

    outputs = keras.layers.Dense(1, activation='sigmoid')(x)
    return keras.Model(inputs=[seq, other], outputs=outputs)

(where recurrent_dense is simply some dense layers with batch normalization).

Because of the multiple-input nature of the model, the training/validation/test datasets are all of peculiar types as well:
<ShuffleDataset shapes: (((None, 30, 7), (None, 7)), (None,)), types: ((tf.float32, tf.float64), tf.float32)>
where the two kinds of inputs are basically in a list, which I guess is causing the error.
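[Editor's note: the error can be reproduced outside alibi. A Python list of arrays has no dtype attribute, even though each element does; the shapes below are hypothetical, matching the model above.]

```python
import numpy as np

# Hypothetical batch matching the model's two inputs
seq = np.zeros((4, 30, 7), dtype=np.float32)   # (batch, timesteps, features)
other = np.zeros((4, 7), dtype=np.float64)     # (batch, secondary features)

X = [seq, other]  # multi-input Keras models take a list of arrays

print(hasattr(X, "dtype"))     # False: lists carry no dtype, hence the AttributeError
print(seq.dtype, other.dtype)  # each element still has its own dtype
```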

Any suggestions for working with this kind of model? Or other tools you think might work well? Thank you!

Hugo

jklaise added the Type: Method extension label Nov 2, 2020
jklaise (Contributor) commented Nov 2, 2020

@hugokitano currently the method doesn't support models with multiple inputs, but we're keen to extend it, and it shouldn't be too much work to support taking gradients with respect to a list of input tensors, as in your model.
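[Editor's note: the extension amounts to accumulating path gradients for each input in the list separately. A minimal NumPy sketch on a toy differentiable function of two inputs, not alibi's implementation; the function, weights, and shapes are made up:]

```python
import numpy as np

# Toy "model": F(seq, other) = sum(w1*seq) + sum(w2*other), so gradients are constant
rng = np.random.default_rng(0)
w1, w2 = rng.normal(size=(30, 7)), rng.normal(size=7)

def F(seq, other):
    return (w1 * seq).sum() + (w2 * other).sum()

def grads(seq, other):
    return w1, w2  # analytic gradients wrt each input in the list

def integrated_gradients(x_list, baseline_list, steps=50):
    """Riemann-sum IG: one attribution array per input in the list."""
    attrs = [np.zeros_like(x) for x in x_list]
    for k in range(1, steps + 1):
        alpha = k / steps
        point = [b + alpha * (x - b) for x, b in zip(x_list, baseline_list)]
        for a, g, x, b in zip(attrs, grads(*point), x_list, baseline_list):
            a += (x - b) * g / steps
    return attrs

x = [rng.random((30, 7)), rng.random(7)]
base = [np.zeros((30, 7)), np.zeros(7)]
attrs = integrated_gradients(x, base)

# Completeness axiom: attributions sum to F(x) - F(baseline)
total = sum(a.sum() for a in attrs)
print(np.isclose(total, F(*x) - F(*base)))  # True (exact for a linear model)
```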

One thing that's a bit more complex is combining multiple-input models with taking gradients wrt layers which are not input layers. E.g. in the text example you have to take gradients wrt the embedding layer, not the token input layer, because the token->embedding lookup is not differentiable; this is quite common with text models and could get complex with multiple inputs.

Is your use case a text model with sequences of tokens as input? Or can we assume that the first layer is differentiable and we can take gradients with respect to the input?
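[Editor's note: the non-differentiable-lookup case can be illustrated without any framework: run the lookup forward, then attribute wrt the resulting embedding vectors. A hand-rolled NumPy sketch with a toy vocabulary, embedding table, and linear head, not alibi code:]

```python
import numpy as np

rng = np.random.default_rng(1)
E = rng.normal(size=(4, 3))   # toy embedding table: 4 tokens, 3 dims
w = rng.normal(size=3)        # toy linear head over mean-pooled embeddings

def model_from_embeddings(emb):
    return emb.mean(axis=0) @ w  # only the differentiable part of the model

tokens = np.array([0, 2, 3, 1])  # integer ids: the lookup itself is not differentiable
emb = E[tokens]                  # forward pass through the lookup
baseline = np.zeros_like(emb)    # attribute wrt embeddings, not token ids

# Gradient of model_from_embeddings wrt emb is constant: w / n at every position
n = emb.shape[0]
grad = np.tile(w / n, (n, 1))
ig = (emb - baseline) * grad     # exact IG for a linear head

print(np.isclose(ig.sum(), model_from_embeddings(emb)))  # completeness holds
```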

hugokitano (Author)

My case is a sequence of one-hot encoded nucleotides ("A", "C", "T", and "G", as in DNA). This is fed into an RNN, and the RNN's output is concatenated with several single-value features; the result is fed into a dense network. So there is no embedding layer, but you do bring up some interesting points about which layer to take gradients from. Would I have to take gradients with respect to both the first layer of the RNN and the first layer of the dense network?
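[Editor's note: for concreteness, a one-hot encoding of such a nucleotide sequence might look like the hypothetical helper below; it covers only the 4 bases, whereas the model above uses 7 features per position, so the real pipeline presumably appends extra per-position features.]

```python
import numpy as np

ALPHABET = "ACGT"

def one_hot(seq):
    """One-hot encode a DNA string into a (len(seq), 4) float array."""
    idx = np.array([ALPHABET.index(c) for c in seq])
    return np.eye(len(ALPHABET), dtype=np.float32)[idx]

print(one_hot("ACGT"))  # the 4x4 identity matrix
```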

hugokitano (Author)

Now that I think about it, the best layer is probably the last dense layer, before the sigmoid function.

jklaise (Contributor) commented Nov 3, 2020

If there is no embedding layer, the graph should be differentiable end-to-end, so it should be fine to take gradients wrt any layer. Which layer to choose is use-case dependent: will attributions for the last dense layer be interpretable?

gipster linked a pull request Nov 25, 2020 that will close this issue