masking compatible with fullgraph compile #91
base: master
Conversation
ah yea, that does look a bit confusing, needs a tiny bit more work. Do you think you can try fitting all the logic into one function? We can reassess after your refactor.
@theAdamColton have you tried the updated LFQ? curious how you got good results on the previous broken one
With the previous LFQ I set entropy loss and commit loss to very low weights and it did actually work.
I've also been experimenting with the entropy loss from maskgit; it does it slightly differently from the current lfq code here. The one there seems to work pretty well.
Also, this is a different issue, but I think here where the entropy is computed, maybe it should use F.log_softmax to separately compute the log probs from the distances, instead of taking the log of the probs to get the log probs.
@theAdamColton how is that different? can you show me in code?
@lucidrains
this is what I mean:
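Something roughly like this (a made-up minimal sketch with hypothetical tensor names, assuming the negative distances act as the logits; not the exact snippet):

```python
import torch
import torch.nn.functional as F

# hypothetical per-code distances, shape (batch, codebook_size)
distance = torch.randn(8, 512) * 50.

# current lfq-style approach as I read it: softmax into probs, then log with an epsilon
prob = (-distance).softmax(dim = -1)
log_prob_from_probs = torch.log(prob + 1e-5)

# maskgit-style approach: compute the log probs directly from the logits
log_prob = F.log_softmax(-distance, dim = -1)

# entropy from each variant; they diverge when some probabilities underflow
entropy_from_probs = -(prob * log_prob_from_probs).sum(dim = -1)
entropy_direct = -(prob * log_prob).sum(dim = -1)
```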
I don't know if it would make a difference, but it's what the maskgit code does. Using log_softmax might fix the precision issues mentioned in the PyTorch log_softmax docs.
I think the numerical stability is accounted for by the epsilon in the log I have in the file, but do let me know otherwise.
anyways, I've put in my hours today, happy Saturday! See if you can get that mask to go into the masked mean fn and I'll review it again.
Force-pushed from d9967be to 34b9e97
Force-pushed from 0a931ca to d92c330
This adds some slightly confusing masking code, but improves speed by 3x by making the shape of intermediate tensors non-dynamic. The masked_mean code is equivalent, up to floating-point precision, to the old code that used tensor indexing.
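Roughly, the idea is along these lines (a minimal sketch with hypothetical names, not the code from the diff):

```python
import torch

def masked_mean(t, mask, eps = 1e-5):
    # static-shape variant: zero out padded positions and divide by the count
    # of valid elements, so every intermediate tensor keeps a fixed shape
    mask = mask[..., None].to(t.dtype)                      # (b, n) -> (b, n, 1)
    return (t * mask).sum() / (mask.sum() * t.shape[-1]).clamp(min = eps)

def indexed_mean(t, mask):
    # old dynamic-shape variant: boolean indexing produces a tensor whose length
    # depends on the mask contents, which is what caused the graph breaks
    return t[mask].mean()

t = torch.randn(2, 1024, 256)
mask = torch.zeros(2, 1024, dtype = torch.bool)
mask[:, :700] = True

assert torch.allclose(masked_mean(t, mask), indexed_mean(t, mask), atol = 1e-5)
```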
Previously, using LFQ with masking was not compatible with torch.compile with fullgraph=True or dynamic=False. It worked with plain torch.compile, but the masked tensor indexing caused graph breaks.
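For example, something like the following should now compile without graph breaks (a sketch; the exact LFQ constructor arguments and the mask keyword should be checked against the repo's README):

```python
import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(codebook_size = 65536, dim = 16)

# with the static-shape masking, fullgraph compilation no longer trips over
# tensors whose shape depends on the mask contents
compiled = torch.compile(quantizer, fullgraph = True, dynamic = False)

x = torch.randn(2, 1024, 16)
mask = torch.ones(2, 1024, dtype = torch.bool)
mask[:, 700:] = False   # pretend the tail of each sequence is padding

quantized, indices, entropy_aux_loss = compiled(x, mask = mask)
```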
I added an example that uses masked sequences, to make sure it works properly.
I did a benchmark by running the example code that uses masking, on a 3090 GPU.
The speedup might be worth the extra complexity in the code.