fix(inference): remove weight norm on inference so mps backend will work without CPU fallback #783
Codecov Report
Patch coverage:

@@            Coverage Diff             @@
##             main     #783      +/-   ##
==========================================
- Coverage   19.43%   19.38%   -0.06%
==========================================
  Files          39       39
  Lines        3452     3467      +15
  Branches      484      489       +5
==========================================
+ Hits          671      672       +1
- Misses       2763     2777      +14
  Partials       18       18
Will this change work as before on other devices? If you aren't very confident, I would like you to add …, even though it will make the code messy.
I tried it with the CPU backend as well, and I'm pretty confident it will work without issues on CUDA too, but I haven't tested it on CUDA because I'm traveling and don't have access to a CUDA machine at the moment.
I don't think there are that many MPS users compared to CUDA users, so I'd rather wait patiently until someone reports that it works. Alternatively, you could enclose the changes in `if` statements. (Sorry, but I don't have the energy to test this myself.)
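The maintainer's alternative, gating the change on the target device, could look roughly like the minimal sketch below. It assumes the models use PyTorch's legacy `torch.nn.utils.weight_norm` hook; `prepare_for_inference` and the toy model are hypothetical names for illustration, not code from this PR.

```python
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm


def prepare_for_inference(model: nn.Module, device: torch.device) -> nn.Module:
    """Hypothetical helper: strip weight norm only when targeting the MPS backend.

    On CPU and CUDA the model is returned untouched, so those devices keep
    exactly the code path they had before the change.
    """
    if device.type != "mps":
        return model
    for module in model.modules():
        try:
            remove_weight_norm(module)  # folds weight_g / weight_v into a plain `weight`
        except ValueError:
            pass  # this submodule has no weight norm hook
    return model


# Toy model standing in for the real networks (assumption, not the PR's models).
model = nn.Sequential(weight_norm(nn.Linear(3, 4)), nn.ReLU())
prepare_for_inference(model, torch.device("cpu"))
kept_on_cpu = hasattr(model[0], "weight_g")  # CPU path leaves weight norm in place
prepare_for_inference(model, torch.device("mps"))
stripped_for_mps = not hasattr(model[0], "weight_g")  # MPS path removed it
```

The sketch does not move the model; calling `model.to(device)` afterwards is left to the caller. Constructing a `torch.device("mps")` object is only a descriptor and should not require MPS hardware, which is why the toy example can run on any machine.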
I'm ashamed to say that I didn't understand weight norm. I guess it should be removed at inference time, as you implied.
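As background for why removing it at inference is safe: weight normalization (Salimans & Kingma, 2016) reparameterizes each weight vector as a learned magnitude $g$ times a learned direction $\mathbf{v}$. A brief sketch of the argument:

$$
\mathbf{w} = g \, \frac{\mathbf{v}}{\lVert \mathbf{v} \rVert}
$$

During training, $g$ and $\mathbf{v}$ are the parameters being optimized and $\mathbf{w}$ is recomputed before every forward pass. At inference both are frozen, so $\mathbf{w}$ is a constant: it can be computed once and stored as an ordinary weight without changing the layer's output. With PyTorch's default `dim=0`, the norm is taken separately for each output channel (for example, each row of a `Linear` weight).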
@allcontributors add shenberg userTesting, ideas, code
I've put up a pull request to add @shenberg! 🎉
Description of change
Added code to remove weight norms from `hubert` and `net_g` (inference only for `net_g`), because, at least as of PyTorch 2.0.1, the `mps` backend in PyTorch does not support weight norm.

Pull-Request Checklist
- … `main` branch
- `pre-commit run -a` passes with this change or CI passes
- `poetry run pytest` passes with this change or CI passes
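For context, here is a minimal sketch of what removing weight norm before inference can look like in general. It assumes the legacy `torch.nn.utils.weight_norm` hook; `strip_weight_norm` is a hypothetical helper and the toy model only stands in for `hubert` / `net_g`. This is not the code added in this PR.

```python
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm


def strip_weight_norm(model: nn.Module) -> int:
    """Fold weight norm back into plain weights for every submodule that uses it.

    Hypothetical helper; assumes the legacy torch.nn.utils.weight_norm hook.
    Returns how many submodules were modified.
    """
    removed = 0
    for module in model.modules():
        try:
            # Replaces weight_g / weight_v with a single ordinary `weight` Parameter.
            remove_weight_norm(module)
        except ValueError:
            continue  # this submodule was never wrapped with weight_norm
        removed += 1
    return removed


# Toy model standing in for hubert / net_g (not the real models).
toy = nn.Sequential(weight_norm(nn.Linear(3, 4)), nn.ReLU(), weight_norm(nn.Linear(4, 2)))
toy.eval()
x = torch.randn(5, 3)
with torch.no_grad():
    before = toy(x)
    count = strip_weight_norm(toy)
    after = toy(x)  # same outputs, but no weight norm hook left to run on MPS
```

If a model instead uses the newer `torch.nn.utils.parametrizations.weight_norm` (PyTorch 2.1 and later), the corresponding removal call is `torch.nn.utils.parametrize.remove_parametrizations(module, "weight")`.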