[Desiderata] Captum-like implementation for Inseq compatibility #1

Open
gsarti opened this issue Feb 12, 2024 · 6 comments
Labels: enhancement (New feature or request)

Comments


gsarti commented Feb 12, 2024

Hi @rachtibat,

Great job on AttnLRP, your LRP adaptation looks very promising for attributing the behavior of Transformer-based LMs!

Assuming you are still working on the codebase, I was wondering whether it would be possible to have an implementation that is interoperable with the Captum LRP class. This would allow us to include the method in the inseq library (reference: inseq-team/inseq#122), enabling out-of-the-box support for:

  • Using AttnLRP on any decoder-only or encoder-decoder model in 🤗 transformers.
  • Easy benchmarking with other attribution approaches.
  • Advanced post-processing of attribution outputs and customization of attribution targets (e.g. contrastive attribution).

inseq has already been used in a number of publications since its release last year, and having an accessible implementation of AttnLRP there would undoubtedly help democratize access to your method.
From an implementation perspective, I'm not an LRP connoisseur, but my understanding is that to ensure Captum compatibility it would be enough to specify your custom propagation rules matching the base class provided here.

Let me know your thoughts, and congrats again on the excellent work!

@rachtibat (Owner)

Hey @gsarti,

I'm glad that you like our paper! You're doing great work with inseq!

While LRP substantially outperforms other methods, it has an initial 'set-up' cost, i.e. there is currently no PyTorch implementation that can automatically apply the rules to all operations in a PyTorch graph. For instance, in LRP we must apply the epsilon rule to every summation operation. This means that if we have a line of code such as
c = a + b
we have to attach our LRP rule to this line of code somehow. In this repository, I am implementing custom PyTorch autograd functions. This means that we have to replace the line of code with
c = epsilon_sum.apply(a, b)
So, the user has to put in some extra effort.
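To make this concrete, here is a minimal sketch of what such a custom autograd function could look like (purely illustrative: the actual rules in this repository differ in the details, and `EpsilonSum`/`epsilon_sum` are just names for exposition):

```python
import torch

class EpsilonSum(torch.autograd.Function):
    """Illustrative epsilon-rule sum: the forward pass is a plain addition,
    while the backward pass redistributes the incoming relevance to the two
    summands proportionally to their contribution."""

    @staticmethod
    def forward(ctx, a, b, epsilon=1e-6):
        c = a + b
        ctx.save_for_backward(a, b, c)
        ctx.epsilon = epsilon
        return c

    @staticmethod
    def backward(ctx, relevance_c):
        a, b, c = ctx.saved_tensors
        # stabilize the denominator with a small epsilon carrying the sign of c
        sign = torch.where(c >= 0, torch.ones_like(c), -torch.ones_like(c))
        denom = c + ctx.epsilon * sign
        # split the incoming relevance proportionally to each summand's share
        return relevance_c * a / denom, relevance_c * b / denom, None


epsilon_sum = EpsilonSum  # usage: c = epsilon_sum.apply(a, b)
```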

I'm not aware of a way to do this kind of code manipulation/graph manipulation on the fly. I just found this tutorial on torch.fx. Maybe this is the solution?
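For completeness, a rough sketch of how such a replacement could look with torch.fx, reusing the `EpsilonSum` function from the previous snippet (symbolic tracing has its own limitations, e.g. with data-dependent control flow, and 🤗 transformers models typically need the dedicated FX tracer shipped with that library, so this is only a guess at the approach):

```python
import operator

import torch
import torch.fx


def attach_epsilon_rule(model: torch.nn.Module) -> torch.fx.GraphModule:
    """Trace the model and swap every tensor addition for the epsilon-rule sum."""
    traced = torch.fx.symbolic_trace(model)
    for node in traced.graph.nodes:
        # `a + b` appears as operator.add in traced graphs, explicit calls as torch.add
        if node.op == "call_function" and node.target in (operator.add, torch.add):
            node.target = EpsilonSum.apply  # reuse the custom function from above
    traced.recompile()
    return traced
```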

As a consequence, I'm implementing LRP right now in the style of zennit, but I'm optimistic that we can somehow integrate it into captum for pre-defined model architectures such as Llama 2, for instance just to run some benchmarks against other methods.
(Captum's LRP implementation is not optimal for our use case because it relies on hooks, and hooks are quite inefficient, but maybe we can agree on a new adaptation of their LRP class?)

Best greetings, and thank you again (:


gsarti commented Feb 13, 2024

Thanks for your prompt reply @rachtibat!

I see the issue with setup costs, thanks for clarifying! I had an in-depth look at torch.fx some time ago for inference-time mid-forward interventions (e.g. for the Value Zeroing method we're adding in PR inseq-team/inseq#173), and I also had a chance to chat about it with Captum lead devs at EMNLP. In general, it is quite cumbersome and counter-intuitive for very targeted interventions, but it might be manageable for replacing all operations of a specific type. I'd be very interested to see if you come up with a solution using torch.fx to make the implementation generalizable!

Would the zennit-style implementation you have in mind support multi-token generation? In my experience, this is the main limiting factor in applying such techniques to autoregressive LMs (which we address by looping attribution in inseq), especially since people usually want to customize generation parameters à la transformers without reinventing the wheel.

@rachtibat (Owner)

Hey @gsarti,

awesome that you already have so much experience with torch.fx. Alright, good to know; maybe it really is manageable with simple operation replacement, without doing fancy manipulations.

I'm not quite sure what you mean by multi-token generation, but I'll try to give you an idea of what is possible if someone wants to explain several tokens at once:
Assume an LLM generated a sequence of N tokens.

  • If we want to explain token N-2, we have to perform a forward pass with N-2 tokens (tokens after N-2 are not necessary) and one backward pass, where we initialize the relevance at output position N-2 with the model's logit output and set all remaining outputs to zero.
  • If we want to explain tokens 2, 5, and N-4 at once, i.e. a superposition of three attributions, we perform a forward pass with N-4 tokens and one backward pass where we initialize the relevance with the output logits at positions 2, 5, and N-4 and set all remaining output relevances to zero (see the sketch after this list).
  • If we want to explain tokens 2, 5, and N-4 separately, we must perform 3 forward and 3 backward passes with 2, 5, and N-4 tokens each, or perform the attribution with batch size 3 and N-4 tokens each (to allow parallelization). Alternatively, we can perform 1 forward pass with N-4 tokens, keep the backward graph in memory, and perform 3 backward passes. This way, we obtain three independent attributions, one for each token in isolation.
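As a minimal sketch of the second bullet, assuming `model` is a decoder-only LM already instrumented with the LRP rules (so that a backward pass propagates relevance rather than plain gradients); `positions`/`target_ids` hold the output positions and token ids to explain, and the function name is purely illustrative:

```python
import torch


def superposed_relevance(model, embedding_layer, input_ids, positions, target_ids):
    # embed the inputs ourselves so relevance can be read off the embedding grads
    embeds = embedding_layer(input_ids).detach().requires_grad_(True)
    logits = model(inputs_embeds=embeds).logits  # (1, seq_len, vocab_size)

    # initialize the relevance with the logit outputs at the selected positions;
    # every other output position starts with zero relevance
    init = torch.zeros_like(logits)
    init[0, positions, target_ids] = logits[0, positions, target_ids].detach()
    logits.backward(gradient=init)

    # read off per-token relevance at the input embeddings (depending on how the
    # rules are implemented this is either the gradient itself or gradient * input;
    # gradient * input is assumed here), summed over the hidden dimension
    return (embeds.grad * embeds.detach()).sum(-1)
```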

I hope this explains it; if it is unclear, you can ask again (:


gsarti commented Feb 15, 2024

Thanks for the response @rachtibat! To clarify, the background to my question was that libraries like Captum typically provide an interface to streamline the attribution of a single forward output (the first bullet point you describe). However, there is no simple abstraction to automate the "one attribution per generation step" process you describe in the third bullet point (although in the case of Captum, they actually added something akin to this in v0.7). The main reason for inseq's existence was precisely to automate this process while enabling full customization of the model.generate method of 🤗 transformers.
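For context, the per-step loop that inseq automates looks roughly like this (`model` and `input_ids` stand in for a 🤗 causal LM and its tokenized prompt, and the `attribute` helper is purely hypothetical, standing in for any single-step attribution method):

```python
# Generate first, then explain each generated token against its own prefix.
generated = model.generate(input_ids, max_new_tokens=20, do_sample=False)

step_attributions = []
for step in range(input_ids.shape[1], generated.shape[1]):
    prefix = generated[:, :step]   # everything preceding the token being explained
    target = generated[:, step]    # the generated token at this step
    step_attributions.append(attribute(model, prefix, target))  # hypothetical helper
```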

The 2nd approach you mention (the one proposing a "superposition" of 3 attributions) looks very interesting, and I think it's the first time I've seen this idea! But I have a doubt: this would effectively mean taking the output logits of previous tokens (e.g. 2 and 5 in your example) when computing the forward pass for token N-4 and using them to propagate relevance back into the model. Don't you think this is a bit unnatural for extracting rationales, given that only the last token is used when computing predictions at each generation step? I'm not sure what information the preceding embeddings would provide in this context. Curious to hear your thoughts!

@rachtibat (Owner)

Hey,

afaik transformers are trained with next-token prediction at every output position. If you look at the Hugging Face implementation of Llama 2, for instance, you see that the labels for the cross-entropy loss are the inputs shifted by one. So the model actually makes a prediction at each output position, not just the last token. Because of the causal masking in the attention heads, each output position N can only see the prior N-1 input tokens and makes an independent prediction.
This is why we can actually do what I've described in bullet point 2 (:
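Concretely, the shifted-label objective looks roughly like this (`model` and `input_ids` stand in for a causal LM and its tokenized inputs; this just mirrors what the Hugging Face implementation does internally):

```python
import torch.nn.functional as F

logits = model(input_ids).logits           # (batch, seq_len, vocab_size)
shift_logits = logits[:, :-1, :]           # the prediction made at position t ...
shift_labels = input_ids[:, 1:]            # ... is scored against the token at t+1
loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
)
```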

I already tried it, and it is equal to computing the attribution for each output token separately and adding them up, or computing a superimposed heatmap all at once. This is due to the fact that LRP is an additive explanatory model, i.e. the attribution can be disentangled into several independent relevance flows. We described this phenomenon in this paper:
https://www.nature.com/articles/s42256-023-00711-8

So, I think this might be a feature only present in additive explanatory models.
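As an illustration of that additivity, reusing the hypothetical `superposed_relevance` sketch from my earlier comment (`emb`, `targets`, and `n` stand for the embedding layer, the target token ids, and the length of the generated sequence):

```python
import torch

positions = [2, 5, n - 4]
r_joint = superposed_relevance(model, emb, input_ids, positions, targets)
r_split = sum(
    superposed_relevance(model, emb, input_ids, [p], [t])
    for p, t in zip(positions, targets)
)
assert torch.allclose(r_joint, r_split, atol=1e-5)  # holds for additive explainers
```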
I hope it is somewhat clear (:


gsarti commented Feb 23, 2024

This is very interesting, you're right! I was thinking of inference, but it is true that at training time the model does indeed predict a token at every position. The fact that the result is a simple sum of independent relevance flows is definitely an upside of additive models, looking forward to testing it out! :)

rachtibat added the enhancement (New feature or request) label on Nov 11, 2024