
Perplexity Speedup #4108

Merged
emibaylor merged 20 commits into master from perplexity-speedup on Apr 20, 2022
Conversation

@emibaylor (Contributor) commented Apr 6, 2022

This PR makes the necessary changes to the perplexity metric so that:

  • it runs much faster (via batching)
  • it throws an error when the input is empty, or when the input is a single word and no start token is added
  • it adds the option to prepend a start token (see the usage sketch after this list)
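For reference, a minimal usage sketch of the batched metric described above. This is a sketch only: the argument names (input_texts, model_id, add_start_token) follow the snippets later in this thread, while the metric path is an assumption.

import datasets

# Load the perplexity metric from the repo's metrics folder (path is an assumption).
perplexity = datasets.load_metric("./metrics/perplexity")

input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]

# Batched computation over all input strings; add_start_token controls whether
# a start-of-sequence token is prepended to each input.
results = perplexity.compute(
    input_texts=input_texts,
    model_id="gpt2",
    add_start_token=True,
)
print(results)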

Issues:

  • The values returned are extremely high, and I'm worried they aren't correct. Even if they are correct, they are sometimes returned as inf, which is not very useful (see comment below for some of the output values).
    • If the values are not correct, can you help me find the error?
    • If the values are correct, it might be worth it to measure something like perplexity per word, which would allow us to get actual values for the larger perplexities, instead of just inf

Future:

  • stride is not currently implemented here. I have some thoughts on how to make it happen with batching, but I think it would be better to get another set of eyes to look at any possible errors causing such large values now rather than later.

@sashavor (Contributor) commented Apr 6, 2022

WRT the high values, can you add some unit tests with some [string, model] pairs and their resulting perplexity scores, so that @TristanThrush can run the same pairs through his version of the code?

@HuggingFaceDocBuilderDev commented Apr 6, 2022

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 5 to 32
def main():
    print("\n\n\n------------ WIKITEXT: ------------")
    input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", keep_in_memory=False, split="test")["text"][:50]
    print("input len with empty strings:", len(input_texts))
    input_texts = [s for s in input_texts if s != ""]
    print("input len without empty strings:", len(input_texts))

    perplexity = datasets.load_metric("../perplexity")

    results_w_start = perplexity.compute(input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=True)
    print("\n\nRES W START:", results_w_start)

    results_no_start = perplexity.compute(
        input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=False
    )
    print("\n\nRES NO START:", results_no_start)

    print("\n\n\n------------ SMALL SNIPPETS: ------------")
    input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
    print(input_texts)

    results_w_start = perplexity.compute(input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=True)
    print("\n\nRES W START:", results_w_start)

    results_no_start = perplexity.compute(
        input_texts=input_texts, model_id="gpt2", device="gpu", add_start_token=False
    )
    print("\n\nRES NO START:", results_no_start)
@emibaylor (Contributor, Author) commented Apr 6, 2022
For reference, the results I get from this code are below. Note that they are all extremely high, and many are inf:

--- WIKITEXT: ---
RES W START: {'perplexity': [5.26044112490735e+33, 1754766870839296.0, 4.851017998674971e+21, 1.6137719812847106e+33, 1.123557053605896e+35, 1.4592395342302413e+17, 2.2242306063538344e+21, 1.6256261294134198e+37, 5.6552604843696494e+23, 5.283058871869749e+23, 1.362616559062237e+31, 3.2894723458301157e+30, inf, 2.8323108695319055e+32, 4.517404432663069e+36, 6.06021648422544e+21, 2.7673238221998803e+25, inf, 1.4324680994914391e+32, 1.0707699658437287e+29, 1.3665097609664754e+23, 2.646539983794732e+26, 1.1760327816258613e+21, 5.623264614959543e+28, 8.202277198380542e+22, 1.6075579260687217e+30, 7.034012837768928e+25]}

RES NO START: {'perplexity': [6.347511649255629e+27, 7.746963798754648e+35, 5.412155643337211e+36, 7.443727567431987e+26, 7.610168923774734e+33, 6.160369930664121e+35, 5.744005654242058e+36, 6.396653209350697e+32, inf, inf, 4.597405287035285e+27, 1.5910406783685568e+27, 1.2694701550401124e+26, 9.076024979217206e+25, 2.580195144338176e+31, inf, 1.8766434173654982e+35, 1.4048184593730499e+29, inf, 5.501222136759501e+28, inf, inf, 1.6581890056188304e+38, inf, inf, inf, inf]}

  • for reference, this is the list I get when using my previous method on the same data:
    [110.6924819946289, 16.05736541748047, 13.696072578430176, 146.76963806152344, 34.011749267578125, 14.06321907043457, 16.367246627807617, 32.93799591064453, 15.872761726379395, 14.444437026977539, 109.6595687866211, 36.309959411621094, 37.09142303466797, 37.722312927246094, 480.6309814453125, 16.176923751831055, 12.493497848510742, 72.48567962646484, 18.401371002197266, 34.41492462158203, 14.89440631866455, 8.929665565490723, 21.79900550842285, 15.352792739868164, 12.775267601013184, 18.014751434326172, 10.050305366516113]

--- SMALL SNIPPETS: ---
RES W START: {'perplexity': [2.6583647714047857e+31, 4.2388171843120376e+26, 3.074649231109679e+28]}

RES NO START: {'perplexity': [2.605767698770934e+26, 7.023913198776316e+20, 1.3889064307219916e+20]}

  • for reference, this is the list I get when using my previous method on the same data:
    [11.10894775390625, 159.0147705078125, 64.53162384033203]

@sashavor @TristanThrush

@sashavor (Contributor) commented Apr 6, 2022

I thought that the perplexity metric should output the average perplexity value of all the strings that it gets as input (not a perplexity value per string, as the new version does).
@lhoestq , @TristanThrush thoughts?

@TristanThrush (Contributor) commented:
> I thought that the perplexity metric should output the average perplexity value of all the strings that it gets as input (not a perplexity value per string, as the new version does). @lhoestq, @TristanThrush thoughts?

I support this change from Emi. If we have a perplexity function that loads GPT2 and then returns an average over all of the strings, then it is impossible to get multiple perplexities of a batch of strings efficiently. If we have this new perplexity function that is built for batching, then it is possible to get a batch of perplexities efficiently and you can still compute the average efficiently afterwards.
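To illustrate that point, a tiny sketch (the numbers are placeholders, not values from this PR): per-string perplexities can still be aggregated afterwards with a single extra call.

import numpy as np

# Hypothetical per-string perplexities returned by the batched metric.
ppls = [110.7, 16.1, 13.7, 146.8]

# Nothing is lost by returning one value per string: the aggregate is
# recovered trivially from the list.
mean_perplexity = float(np.mean(ppls))
print(mean_perplexity)  # 71.825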

@TristanThrush (Contributor) left a comment

Hi Emi

Looks like we are almost there, thanks for this! Here are some notes:

  • let's not add a BOS token if the model doesn't have one. If we do, the model won't know what that token means because it wasn't trained with one. Just throw an assertion error if add_start_token is True and the model doesn't have a BOS token (see the sketch after this list)
  • let's remove stride as an argument because it doesn't do anything now. We could also add in the documentation what happens now when the text is too long (I think it's truncated?)
  • I noticed that every time I run the compute function, I get different results even if I have the same arguments. Is the model being randomly initialized? This could explain the numbers that are way too high
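A minimal sketch of the first point, assuming a standard transformers tokenizer; the helper name load_tokenizer is hypothetical, and this is an illustration of the suggested assertion, not the code from this PR.

from transformers import AutoTokenizer

def load_tokenizer(model_id: str, add_start_token: bool = True):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if add_start_token:
        # Fail loudly rather than inventing a BOS token the model was never trained with.
        assert tokenizer.bos_token is not None, (
            "add_start_token=True requires a tokenizer that defines a BOS token."
        )
    return tokenizer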

@lhoestq (Member) commented Apr 7, 2022

Thanks a lot for working on this @emibaylor @TristanThrush :)

For consistency with the other metrics, I think it's nice if we return the mean perplexity. Though I agree that having the separate perplexities per sample can also be useful. What do you think about returning both?

return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}

We're also doing this for the COMET metric.

@TristanThrush (Contributor) commented:
> Thanks a lot for working on this @emibaylor @TristanThrush :)
>
> For consistency with the other metrics, I think it's nice if we return the mean perplexity. Though I agree that having the separate perplexities per sample can also be useful. What do you think about returning both?
>
> return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
>
> We're also doing this for the COMET metric.

Thanks! Sounds great to me.


  model = AutoModelForCausalLM.from_pretrained(model_id)
  model = model.to(device)

- tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>")
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
@emibaylor (Contributor, Author) commented:

To update you all, @TristanThrush, @lhoestq, @sashavor:
The tokenization/special tokens seem to have been the main culprit (@TristanThrush, you were right!).

What I've done as a (temporary and probably not optimal) fix for now is to reuse one of the other existing special tokens, so as not to have to expand the vocabulary. In theory, this shouldn't mess up any of the computations (since the attention mask masks out all special tokens anyway, I believe), and it indeed gets the outputs to pretty much match (at least to a few decimal places) what I was getting before.

I've also removed stride as per @TristanThrush's comment, and changed the metric to return both the mean_perplexity and the list of perplexities, as per @lhoestq's suggestion.
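For illustration, a sketch of the kind of fix described above: reuse an existing special token as the pad token rather than growing the vocabulary with a new "<PAD>" token. Which special token is reused here (the EOS token) is an assumption, not necessarily the one chosen in this PR.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# GPT-2 has no pad token. Reusing an existing special token keeps the vocabulary
# (and the model's embedding matrix) unchanged; padded positions are excluded
# from the loss via the attention mask.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token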

@lhoestq (Member) left a comment

Cool, thanks! It all looks good, I just left a comment about the testing script.

from datasets import load_dataset


def main():
@lhoestq (Member) commented Apr 12, 2022

Can you format this file as a pytest script? This way we will be able to include this test in a CI if we want.

To do so:

  • you just have to remove the `if __name__ == "__main__": main()` at the end, and rename the main() function to something like test_perplexity()
  • remove the print statements; you can use comments or logging if you want
  • use assert statements or numpy.testing.assert_allclose to make sure the resulting values are the correct ones (you can hardcode the expected values directly in the script)

Then you should be able to do

pytest metrics/perplexity
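A sketch of what the pytest-style script could look like, following the steps above. The expected values are placeholders based on the numbers quoted earlier in this thread, and the "perplexities" output key assumes the return format suggested above.

import datasets
import numpy as np


def test_perplexity():
    perplexity = datasets.load_metric("../perplexity")
    input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]

    results = perplexity.compute(input_texts=input_texts, model_id="gpt2", add_start_token=True)

    # Hardcoded expected values (placeholders); tighten the tolerance once the
    # reference numbers are pinned down.
    expected = [11.1, 159.0, 64.5]
    np.testing.assert_allclose(results["perplexities"], expected, rtol=1e-2)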

@emibaylor (Contributor, Author) commented:
I was mainly using this file to make the examples, and was leaning towards deleting it before merging. Do you think it would be useful to keep it? Or is it ok for me to get rid of it?

@lhoestq (Member) commented:
Oh ok, I thought it was a test script, sorry. You can remove it, yes!

@lhoestq (Member) commented Apr 13, 2022

The CI fail is unrelated to your PR and has been fixed on master, feel free to merge the master branch into your PR to fix the CI ;)

@TristanThrush (Contributor) left a comment

LGTM, thanks so much!

@emibaylor merged commit e90a7d4 into master on Apr 20, 2022
@emibaylor deleted the perplexity-speedup branch on April 20, 2022 at 12:54