
How to get intermediate values of pretrained models? #524

Closed
snie2012 opened this issue Feb 19, 2019 · 4 comments
Labels
feature A new feature question Further information is requested wontfix This will not be worked on

Comments

@snie2012

Is it possible to get the intermediate values of the hidden layers for the pretrained sequence tagging models when making prediction? For example, the output of the LSTM layer of a pretrained POS tagging model. If so, how?

@snie2012 snie2012 added the question Further information is requested label Feb 19, 2019
@snie2012 snie2012 changed the title Intermediate values of pretrained models How to get intermediate values of pretrained models? Feb 19, 2019
@alanakbik alanakbik added the feature A new feature label Feb 19, 2019
@alanakbik
Collaborator

Hello @snie2012, that's a good idea. Unfortunately there is currently no built-in way to do this, but it is something we should really add. I'll add a 'feature' tag to this issue.

For now, one workaround is to modify the forward() method of the SequenceTagger so that it also returns the intermediate tensors, and read the embeddings from there.

@snie2012
Author

Thanks for your reply @alanakbik. It would be great if there were a built-in method for this. For now, I managed to use register_forward_hook to capture the intermediate values. It works something like this:

import torch
from flair.data import Sentence
from flair.models import SequenceTagger

sentence = Sentence('I love Berlin .')
pos_tagger = SequenceTagger.load('pos')

output = torch.zeros(4, 1, 53)  # The intermediate value will be stored here (shape matches the linear layer's output for this sentence)
def hook(m, i, o): output.copy_(o.data)  # The hook copies the intermediate value during the forward pass
pos_tagger.linear.register_forward_hook(hook)  # Register the hook on the layer of interest

pos_tagger.predict(sentence)  # Run the model on a sentence
print(output)  # The intermediate value is now in `output`

@alanakbik
Collaborator

Ah, this is great! It could serve as a blueprint for a convenience method that uses hooks like this to retrieve the embeddings. Thanks for sharing this solution!
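For readers landing here later: a shape-agnostic variant of the hook pattern above can avoid pre-allocating a fixed-size tensor by appending each captured output to a list. This sketch uses a plain PyTorch module as a stand-in (flair not required); the same pattern applies to any nn.Module layer, including those inside a SequenceTagger.

```python
import torch
import torch.nn as nn

# A small stand-in model; the same hook pattern works on any nn.Module.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# Capture intermediate outputs without hardcoding their shape:
# the hook appends a detached copy of each forward output to a list.
captured = []
def hook(module, inputs, output):
    captured.append(output.detach().clone())

# Register the hook on the layer of interest (here the first Linear layer).
handle = model[0].register_forward_hook(hook)

model(torch.randn(2, 8))  # run a forward pass; the hook fires once
handle.remove()           # detach the hook when done to avoid leaks

print(captured[0].shape)  # torch.Size([2, 16])
```

Keeping the returned handle and calling handle.remove() afterwards matters when the model will be reused, since registered hooks otherwise fire on every subsequent forward pass.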

@stale

stale bot commented Apr 30, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix This will not be worked on label Apr 30, 2020
@stale stale bot closed this as completed May 7, 2020