Basic Attention attribution #148
Conversation
…o model initialization
…bution function now)
… torch to ~1.12.1 again because of platform issue: pytorch/pytorch#88826
Hello @lsickert, thank you for submitting a PR! We will respond as soon as possible.
Note: it's good to have the summary issue linked here, but we don't want to close it just yet! :) |
@gsarti I was working on the decoder-only models and came across some inconsistencies between the different models. For example, both GPT and Transformer XL will only include |
Hi @lsickert, good question! The cross-attentions are defined for GPT2 and other decoder-only models to support their usage as components of an encoder-decoder as part of the |
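A minimal sketch (using Hugging Face transformers directly, outside of Inseq) of the behavior described above for GPT-2: the self-attentions are returned whenever requested, while cross-attentions remain None unless the model is actually used as the decoder of an encoder-decoder setup. The model name and inputs are only illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Attention is all you need", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True)

print(len(out.attentions), out.attentions[0].shape)  # one tensor per layer: (batch, heads, seq, seq)
print(out.cross_attentions)  # None here: no encoder states were passed to the decoder-only model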
Ah perfect. Yes I assumed something like that but was not entirely sure. I think then the decoder-only support is now done for the basic attention functions. I will still need to write tests tomorrow and finish up the docstrings and other small things, but apart from that I think the branch is ready for merging. |
Also, some usage issues I identified:
out = model.attribute("The cafeteria had 23 apples. They used 20 for lunch. How many apples do they have left?")
/usr/local/lib/python3.8/dist-packages/inseq/data/attribution.py in <listcomp>(.0)
115 sources = None
116 if attr.source_attributions is not None:
--> 117 sources = [drop_padding(attr.source[seq_id], pad_id) for seq_id in range(num_sequences)]
118 targets = [
119 drop_padding([a.target[seq_id][0] for a in attributions], pad_id) for seq_id in range(num_sequences)
TypeError: 'NoneType' object is not subscriptable

Is it possible that you are not setting source attributions to
|
Yes, the second point was my bad and should be fixed already. I did not notice that the changes to |
The first error is fixed by making sure that the dimensions of the tensors stay the same for all tokens. Very interesting behavior, though, since as far as I remember I specifically had to put in the |
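As an aside, here is a standalone sketch of the kind of fix described above, i.e. keeping tensor dimensions consistent across generation steps so they can be stacked; pad_to_length is a hypothetical helper for illustration, not Inseq code.

import torch
import torch.nn.functional as F

def pad_to_length(t: torch.Tensor, length: int) -> torch.Tensor:
    # right-pad the last (source) dimension with zeros up to `length`
    return F.pad(t, (0, length - t.shape[-1]))

# hypothetical per-step attention weights with growing source lengths
steps = [torch.rand(4, n) for n in (3, 4, 5)]
max_len = max(t.shape[-1] for t in steps)
stacked = torch.stack([pad_to_length(t, max_len) for t in steps])
print(stacked.shape)  # torch.Size([3, 4, 5])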
… range of layers to be specified for averaging in AggregatedAttention
@gsarti I think all open points should be addressed now. Please feel free to start testing while I clean up a bit and work on updating and creating the docstrings tomorrow. |
Thank you for the update! After giving it some more thought, I decided to opt for a single centralized class for basic attention attribution. The decision was mainly driven by the wish to avoid confusion about which class to use, and aimed at enabling more flexibility in the choice of heads and layers for aggregation. The

Example of default usage:

import inseq

model = inseq.load_model("facebook/wmt19-en-de", "attention")
out = model.attribute("The developer argued with the designer because her idea cannot be implemented.")

The default behavior is set to minimize unnecessary parameter definitions: in the default case above, the result is the average across all attention heads of the final layer. Here's a more complex usage:

import inseq

model = inseq.load_model("facebook/wmt19-en-de", "attention")
out = model.attribute(
    "The developer argued with the designer because her idea cannot be implemented.",
    layers=(0, 5),
    heads=[0, 2, 5, 7],
    aggregate_heads_fn="max",
)

In the case above, the outcome is a matrix of the maximum attention weights of heads 0, 2, 5 and 7 after averaging their weights across the first 5 layers of the model (a tensor-level sketch of this aggregation follows below). Remaining todos:
|
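For readers following along, a tensor-level sketch of the aggregation described in the more complex example above: average over a layer range, then take the maximum over a subset of heads. The (layers, heads, target_len, source_len) layout is assumed for illustration and is not necessarily Inseq's internal representation.

import torch

# hypothetical attention weights for a 12-layer, 8-head model
attn = torch.rand(12, 8, 10, 10)

layer_avg = attn[0:5].mean(dim=0)     # average layers 0-4 -> (heads, target, source)
selected = layer_avg[[0, 2, 5, 7]]    # keep heads 0, 2, 5 and 7
result = selected.max(dim=0).values   # max across the selected heads -> (target, source)
print(result.shape)                   # torch.Size([10, 10])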
Added some tests for attention attribution, fixed the typing issue of |
@gsarti I think we were working on the same remaining points just now. I got a bit confused about the

I am currently still finishing a test for those aggregation functions specifically (outside of the normal pipeline), but once that is done I would also agree that we are good to go. |
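A minimal sketch of what such a standalone aggregation-function test could look like; aggregate_heads is a hypothetical stand-in used for illustration, not Inseq's actual helper.

import torch

def aggregate_heads(attn: torch.Tensor, fn: str = "average") -> torch.Tensor:
    # attn: (heads, target_len, source_len); reduce the head dimension
    return attn.mean(dim=0) if fn == "average" else attn.max(dim=0).values

def test_max_head_aggregation():
    attn = torch.rand(8, 5, 5)
    out = aggregate_heads(attn, fn="max")
    assert out.shape == (5, 5)
    # the element-wise max over heads can never be below the mean over heads
    assert torch.all(out >= attn.mean(dim=0))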
@lsickert feel free to merge as soon as CI is passing! 🎉 |
Description
This PR adds the base class for attribution methods based on attention, as well as two basic attention attribution methods (aggregated attention and last-layer attention).
It also includes a small fix for the rounding of outputs in the CLI table view.
It reverts the previous upgrade of PyTorch to ^1.13.0 because of an issue with installing the dependency on certain platforms such as macOS (see the related PyTorch issues: issue1, issue2).
Related Issue
#108
Type of Change
Checklist
- CODE_OF_CONDUCT.md document
- CONTRIBUTING.md guide
- make codestyle