Perplexity Speedup #4108
Conversation
WRT the high values, can you add some unit tests with some [string, model] pairs and their resulting perplexity values, and @TristanThrush can run the same pairs through his version of the code?

The documentation is not available anymore as the PR was closed or merged.
```python
def main():
    print("\n\n\n------------ WIKITEXT: ------------")
    input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", keep_in_memory=False, split="test")["text"][:50]
    print("input len with empty strings:", len(input_texts))
    input_texts = [s for s in input_texts if s != ""]
    print("input len without empty strings:", len(input_texts))

    perplexity = datasets.load_metric("../perplexity")

    results_w_start = perplexity.compute(input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=True)
    print("\n\nRES W START:", results_w_start)

    results_no_start = perplexity.compute(
        input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=False
    )
    print("\n\nRES NO START:", results_no_start)

    print("\n\n\n------------ SMALL SNIPPETS: ------------")
    input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
    print(input_texts)

    results_w_start = perplexity.compute(input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=True)
    print("\n\nRES W START:", results_w_start)

    results_no_start = perplexity.compute(
        input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=False
    )
    print("\n\nRES NO START:", results_no_start)
```
For reference, the results I get from this code are below. Note that they are all extremely high, and many are `inf`:
```
------------ WIKITEXT: ------------
RES W START: {'perplexity': [5.26044112490735e+33, 1754766870839296.0, 4.851017998674971e+21, 1.6137719812847106e+33, 1.123557053605896e+35, 1.4592395342302413e+17, 2.2242306063538344e+21, 1.6256261294134198e+37, 5.6552604843696494e+23, 5.283058871869749e+23, 1.362616559062237e+31, 3.2894723458301157e+30, inf, 2.8323108695319055e+32, 4.517404432663069e+36, 6.06021648422544e+21, 2.7673238221998803e+25, inf, 1.4324680994914391e+32, 1.0707699658437287e+29, 1.3665097609664754e+23, 2.646539983794732e+26, 1.1760327816258613e+21, 5.623264614959543e+28, 8.202277198380542e+22, 1.6075579260687217e+30, 7.034012837768928e+25]}
RES NO START: {'perplexity': [6.347511649255629e+27, 7.746963798754648e+35, 5.412155643337211e+36, 7.443727567431987e+26, 7.610168923774734e+33, 6.160369930664121e+35, 5.744005654242058e+36, 6.396653209350697e+32, inf, inf, 4.597405287035285e+27, 1.5910406783685568e+27, 1.2694701550401124e+26, 9.076024979217206e+25, 2.580195144338176e+31, inf, 1.8766434173654982e+35, 1.4048184593730499e+29, inf, 5.501222136759501e+28, inf, inf, 1.6581890056188304e+38, inf, inf, inf, inf]}
```

For reference, this is the list I get when using my previous method on the same data:

```
[110.6924819946289, 16.05736541748047, 13.696072578430176, 146.76963806152344, 34.011749267578125, 14.06321907043457, 16.367246627807617, 32.93799591064453, 15.872761726379395, 14.444437026977539, 109.6595687866211, 36.309959411621094, 37.09142303466797, 37.722312927246094, 480.6309814453125, 16.176923751831055, 12.493497848510742, 72.48567962646484, 18.401371002197266, 34.41492462158203, 14.89440631866455, 8.929665565490723, 21.79900550842285, 15.352792739868164, 12.775267601013184, 18.014751434326172, 10.050305366516113]
```

```
------------ SMALL SNIPPETS: ------------
RES W START: {'perplexity': [2.6583647714047857e+31, 4.2388171843120376e+26, 3.074649231109679e+28]}
RES NO START: {'perplexity': [2.605767698770934e+26, 7.023913198776316e+20, 1.3889064307219916e+20]}
```

For reference, this is the list I get when using my previous method on the same data:

```
[11.10894775390625, 159.0147705078125, 64.53162384033203]
```
I thought that the perplexity metric should output the average perplexity value of all the strings it gets as input (not a perplexity value per string, as the new version does).

I support this change from Emi. If we have a perplexity function that loads GPT-2 and then returns an average over all of the strings, then it is impossible to get multiple perplexities for a batch of strings efficiently. With this new perplexity function that is built for batching, it is possible to get a batch of perplexities efficiently, and you can still compute the average efficiently afterwards.
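A minimal sketch of the batching argument above: once `compute` returns one perplexity per string, averaging afterwards is a one-liner on the caller's side. The numbers below are the small-snippets reference values quoted later in this thread, reused purely for illustration.

```python
import numpy as np

# Hypothetical per-string perplexities as returned by the batched metric
# (taken from the reference output quoted in this thread).
per_string_ppls = [11.10894775390625, 159.0147705078125, 64.53162384033203]

# Returning the full list loses nothing: the mean is trivially cheap to
# recover, while the reverse (list from mean) is impossible.
mean_ppl = float(np.mean(per_string_ppls))
```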
Hi Emi
Looks like we are almost there, thanks for this! Here are some notes:
- let's not add a bos token if there is none. If we do, the model won't know what that token means, because it wasn't trained with one. Just throw an assertion error if `add_start_token` is true and the model doesn't have a bos token
- let's remove `stride` as an argument, because it doesn't do anything now. We could also add to the documentation what happens now when the text is too long (I think it's truncated?)
- I noticed that every time I run the compute function, I get different results even if I have the same arguments. Is the model being randomly initialized? This could explain the numbers that are way too high
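The first note above could be sketched as a small guard. The helper name is hypothetical; `tokenizer` is assumed to expose a `bos_token` attribute, as Hugging Face tokenizers do (it is `None` when the model has no BOS token).

```python
def check_add_start_token(tokenizer, add_start_token):
    # If the model was never trained with a BOS token, silently adding one
    # would feed it a token it cannot interpret, so fail loudly instead.
    if add_start_token:
        assert tokenizer.bos_token is not None, (
            "add_start_token=True requires a tokenizer with a bos_token"
        )
```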
Thanks a lot for working on this @emibaylor @TristanThrush :) For consistency with the other metrics, I think it's nice if we return the mean perplexity, though I agree that having the separate perplexities per sample can also be useful. What do you think about returning both? `return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}` We're also doing this for the COMET metric.

Thanks! Sounds great to me.
```diff
 model = AutoModelForCausalLM.from_pretrained(model_id)
 model = model.to(device)

-tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
```
To update you all, @TristanThrush, @lhoestq, @sashavor:

The tokenization/special tokens seem to have been the main culprit (@TristanThrush, you were right!).

What I've done as a (temporary and probably not optimal) fix for now is to use one of the other existing special tokens, so as to not have to expand the vocabulary. In theory, this shouldn't mess up any of the computations (as the attention mask masks out all special tokens anyway, I believe), and it indeed gets the outputs to pretty much match (at least to a few decimal places) what I was getting before.

I've also removed `stride` as per @TristanThrush's comment, and changed it to return both the `mean_perplexity` and the list of perplexities, as per @lhoestq's suggestion.
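A sketch of the fix described above. The helper name and the choice of EOS as the reused token are illustrative assumptions; the PR may pick a different existing special token.

```python
def ensure_pad_token(tokenizer):
    # Reuse an existing special token (here EOS) for padding instead of
    # adding a new "<PAD>" token, so the vocabulary and the model's
    # embedding matrix stay the same size. Padded positions are excluded
    # by the attention mask, so this should not affect the loss.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
```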
Cool, thanks! It all looks good; I just left a comment about the testing script.
```python
from datasets import load_dataset


def main():
```
Can you format this file as a pytest script? This way we will be able to include this test in a CI if we want.

To do so:
- you just have to remove the `if __name__ == "__main__": main()` at the end, and rename the `main()` function to something like `test_perplexity()`
- remove the print statements; you can use comments or logging if you want
- use `assert` statements or `numpy.testing.assert_allclose` to make sure the resulting values are the correct ones (you can hardcode the expected values directly in the script)

Then you should be able to do `pytest metrics/perplexity`.
I was mainly using this file to make the examples, and was leaning towards deleting it before merging. Do you think it would be useful to keep it? Or is it ok for me to get rid of it?
Oh OK, I thought it was a test script, sorry. Yes, you can remove it!

The CI failure is unrelated to your PR and has been fixed on master; feel free to merge the master branch into your PR to fix the CI ;)
LGTM, thanks so much!
This PR makes necessary changes to perplexity such that:

Issues:
- Some of the resulting perplexity values are extremely high, and many are `inf`, which is not very useful (see comment below for some of the output values).

Future:
- `stride` is not currently implemented here. I have some thoughts on how to make it happen with batching, but I think it would be better to get another set of eyes to look at any possible errors causing such large values now rather than later.