This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

How to get intermediate values of pretrained models? #2527

Closed
snie2012 opened this issue Feb 19, 2019 · 9 comments

@snie2012

snie2012 commented Feb 19, 2019

Is it possible to get the intermediate values of the hidden layers of the pretrained sequence tagging models when making predictions? For example, the output of the LSTM layer of a pretrained tagging model. If so, how?

@joelgrus
Contributor

I have an experimental PR that does this:

https://github.com/allenai/allennlp/pull/2211/files#diff-a64f3186684b9877702c7c4e2540950aR63

Basically, you can register a hook on the LSTM layer and use that hook to grab its output. That PR is still somewhat half-baked, but it should point you in the right direction.
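
A minimal sketch of the hook approach, using plain PyTorch's `register_forward_hook` rather than the code in that PR. `ToyTagger` and its `encoder` attribute are hypothetical stand-ins for a pretrained tagger; on a real model you would locate the layer by inspecting `model.named_modules()`. The value passed to the hook as `output` is whatever the hooked module returns: a bare `nn.LSTM` returns a tuple, so this hook keeps `output[0]`.

```python
import torch
import torch.nn as nn


class ToyTagger(nn.Module):
    """Hypothetical stand-in for a pretrained tagger: embedding -> LSTM -> tag projection."""

    def __init__(self, vocab_size: int = 100, embed_dim: int = 8,
                 hidden_dim: int = 16, num_tags: int = 5) -> None:
        super().__init__()
        self.embedder = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.tag_projection = nn.Linear(hidden_dim, num_tags)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        encoded, _ = self.encoder(self.embedder(tokens))  # nn.LSTM returns (output, (h_n, c_n))
        return self.tag_projection(encoded)


captured = {}


def save_lstm_output(module, inputs, output):
    # For nn.LSTM, `output` is (output, (h_n, c_n)); keep the per-token hidden states.
    captured["lstm"] = output[0].detach()


model = ToyTagger().eval()
handle = model.encoder.register_forward_hook(save_lstm_output)

tokens = torch.randint(0, 100, (2, 7))  # batch of 2 sentences, 7 tokens each
with torch.no_grad():
    logits = model(tokens)
handle.remove()  # detach the hook once the values have been captured

print(logits.shape)            # torch.Size([2, 7, 5])
print(captured["lstm"].shape)  # torch.Size([2, 7, 16])
```

Because the hook lives outside the model class, the pretrained model's code does not need to change.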

@HarshTrivedi
Contributor

If the model is your own code, you can also change its `forward` method to put the hidden state in the output dict, e.g. `output_dict["hidden_state"] = my_lstm_state` (make sure the first dimension of `my_lstm_state` is `batch_size`). Even if the pretrained model was generated with older model code, the new code will be used when it is loaded. If your predictor uses the default `predict_instance`, its output will already contain the sanitized hidden state. If not, make sure the `hidden_state` key is actually passed through in the JSON dict that `predict_instance` returns.
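
A plain-PyTorch sketch of the output-dict approach, assuming a model you control. `TaggerWithStates`, the `split_by_instance` helper, and all dimensions are hypothetical; this is not AllenNLP's `Model` class. The helper only illustrates why the batch dimension must come first: a batched output dict is split along dim 0 into one JSON-friendly dict per instance, roughly as a predictor does before returning results.

```python
from typing import Dict, List

import torch
import torch.nn as nn


class TaggerWithStates(nn.Module):
    """Hypothetical tagger whose forward() also returns the LSTM hidden states."""

    def __init__(self, vocab_size: int = 100, embed_dim: int = 8,
                 hidden_dim: int = 16, num_tags: int = 5) -> None:
        super().__init__()
        self.embedder = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.tag_projection = nn.Linear(hidden_dim, num_tags)

    def forward(self, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        encoded, _ = self.encoder(self.embedder(tokens))
        output_dict = {"logits": self.tag_projection(encoded)}
        # batch_first=True, so dimension 0 of the hidden state is batch_size
        output_dict["hidden_state"] = encoded
        return output_dict


def split_by_instance(output_dict: Dict[str, torch.Tensor]) -> List[Dict[str, list]]:
    """Hypothetical helper: split batched outputs along dim 0 and convert to lists."""
    batch_size = next(iter(output_dict.values())).size(0)
    return [{key: value[i].tolist() for key, value in output_dict.items()}
            for i in range(batch_size)]


model = TaggerWithStates().eval()
with torch.no_grad():
    outputs = model(torch.randint(0, 100, (3, 4)))  # 3 sentences, 4 tokens each
instances = split_by_instance(outputs)

print(len(instances))                        # 3 (one dict per sentence)
print(len(instances[0]["hidden_state"]))     # 4 (tokens)
print(len(instances[0]["hidden_state"][0]))  # 16 (hidden_dim)
```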

@HarshTrivedi
Contributor

However, what @joelgrus is suggesting is obviously cleaner, since it avoids hard-coding changes in the model just to change what the predictor returns at prediction time.

@snie2012
Author

Thanks for your replies, @joelgrus @HarshTrivedi. I am also looking at using forward hooks to get intermediate values, mainly from pretrained models. It works pretty well :)

@MeiqiGuo

@snie2012 Hi, could you please explain how you solved this issue? I am also looking for the intermediate values of a pretrained model on new data. Thanks in advance!

@joelgrus
Contributor

I have a new PR that's even cleaner, but it's not merged yet:

https://github.com/allenai/allennlp/pull/2581/files
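
One way to make the hook approach tidier is a context manager that registers hooks on the submodules you name and removes them again on exit, so the model is left unchanged afterwards. This is a plain-PyTorch sketch; `capture_outputs` is a hypothetical helper, not the code from that PR, and the submodule names are whatever `model.named_modules()` reports.

```python
from contextlib import contextmanager
from typing import Dict, Iterator

import torch
import torch.nn as nn


@contextmanager
def capture_outputs(model: nn.Module, *names: str) -> Iterator[Dict[str, object]]:
    """Hypothetical helper: record the forward output of each named submodule."""
    modules = dict(model.named_modules())
    captured: Dict[str, object] = {}
    handles = []
    for name in names:
        def hook(module, inputs, output, name=name):  # default arg binds this iteration's name
            captured[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    try:
        yield captured
    finally:
        for handle in handles:
            handle.remove()  # always detach the hooks, even if the forward pass fails


# Usage: in nn.Sequential, submodules are named "0", "1", "2", ...
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
with torch.no_grad(), capture_outputs(model, "0", "1") as internals:
    model(torch.randn(3, 4))

print(internals["1"].shape)  # torch.Size([3, 8]), the ReLU output
```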

@snie2012
Author

@MeiqiGuo As @joelgrus says, using hooks does the job.
In addition to the example above, here is another one: flairNLP/flair#524 (comment)

@MeiqiGuo

@snie2012 @joelgrus Thanks for your help! I tried changing the forward function and it works. I will try this cleaner version with hooks later.

@schmmd
Member

schmmd commented Mar 15, 2019

Closing since #2581 is merged.

@schmmd schmmd closed this as completed Mar 15, 2019