GPT-J-6B #13022
Conversation
(need to change to torch ops later)
GPT-J fixes
@patrickvonplaten The script currently within Mesh Transformer JAX builds a split checkpoint in the format used by the @finetuneanon fork of
If we feel that the solution I outlined is the best one, I can put that plan into action and update the repo on the Model Hub. Maybe vote 👍/👎 on this?
- Remove unused imports and variables
- Clean up docstrings
- Port experimental parallelization code from GPT-2 into GPT-J
@g-karthik I have ported over the experimental parallelization code from GPT-2 into GPT-J. I wouldn't personally recommend using it to anyone unless they need to, but it should work in a pinch when a better solution is unavailable. (Note that this bears no resemblance to the implementation of model parallelism in Mesh Transformer JAX and should not be thought of as an equivalent implementation or a replacement for it.)
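For reference, a minimal sketch of how the GPT-2-style parallelize API ported here is typically driven. The two-GPU split of GPT-J's 28 transformer blocks is an assumption about the available hardware; adjust the device map to your setup.

```python
from transformers import GPTJForCausalLM

# Hypothetical device map: blocks 0-13 on GPU 0, blocks 14-27 on GPU 1.
device_map = {0: list(range(0, 14)), 1: list(range(14, 28))}

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
model.parallelize(device_map)   # naive layer-wise split, same scheme as GPT-2
# ... run generation / fine-tuning here ...
model.deparallelize()           # moves everything back to the CPU
```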
@EricHallahan: Thank you very much for the awesome work on the issue! There is one thing to remark regarding the 16 GB VRAM configuration: as far as I can tell, even with the floating-point revision set correctly (and also when only the local files are fp16, local_files_only=True), the model will still be loaded in float32 until model.half() is called, thus requiring the 23 GB of RAM to be available before model.half() and before model.to(device) is called. By extension this means that a text-generation pipeline cannot be loaded with device=0 in a VRAM < 23 GB setting, as .half() isn't called automatically anywhere. In this case the model must be loaded, have .half() called on it, and then be passed to the pipeline via the argument. Correct me if this observation is wrong. Is there any way of loading/moving the model iteratively to the GPU so that the 23 GB RAM limitation can be circumvented, similar to what is done in the @finetuneanon repository? (Probably out of scope for this very PR, but likely a problem for larger models in general in the future.) Presumably this can be done using the state dict, but I'm not deep enough into the inner workings to judge this. Also tagging @patrickvonplaten
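If it helps, here is a rough sketch of the workaround described above, assuming the half-precision weights live on a "float16" revision of the Hub repo (the branch name is an assumption): load the model on CPU, call .half(), move it to the GPU, and hand it to the pipeline explicitly.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# Loads in float32 first (~23 GB CPU RAM), then casts; see the caveat above.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
model = model.half().to("cuda:0")

# Pass the already-converted model to the pipeline instead of a model name.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
print(generator("EleutherAI is")[0]["generated_text"])
```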
@oborchers Yes, we have had multiple people test this via Colab and they have reported the same issue. I have verified that choosing the
The multi-part loading scheme used by the @finetuneanon fork was purposefully built to bypass the suboptimal way that
Thanks a lot for the detailed message here. What we currently do in
=> we have this logic mainly for models like BERT, for which one would load the "base" model and then add a randomly initialized head for the specific downstream task. It becomes quite clear, however, that this makes less sense for GPT-like models.
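For illustration, the BERT-style case mentioned above looks roughly like the sketch below; the checkpoint name and label count are just examples.

```python
from transformers import AutoModelForSequenceClassification

# The encoder weights come from the checkpoint, while the classification head
# is newly (randomly) initialized for the downstream task.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
```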
FYI: Model acceleration for GPT-J via DeepSpeed is in the making: microsoft/DeepSpeed#1332
Important: We will now merge GPT-J to master. Note that at the moment GPT-J cannot be run on a free Google Colab GPU, since loading the model weights in fp16 requires too much CPU RAM. At the moment one needs at least 26 GB of CPU RAM in order to load GPT-J in fp16 precision. We are working on fixing the problem so that, as a next step, one can load GPT-J with just 12 GB of CPU RAM.
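As a sketch of the lower-RAM loading path being worked on, assuming a transformers version that already exposes the low_cpu_mem_usage flag (an assumption relative to the state of this PR):

```python
import torch
from transformers import AutoModelForCausalLM

# Avoids materializing a randomly initialized copy of the weights before the
# checkpoint is loaded, which roughly halves peak CPU RAM usage.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
```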
I feel the need to reiterate that there remains a redundant
Thanks for letting me know - is it ok if I put the "correct fp32 weights" in the main branch for now? Or do you prefer "fp16"? Both are fine with us :-) Think we can't completely delete the "main" branch for now (cc @LysandreJik)
That is my understanding.
Putting the single precision weights in
+1 this
Ok great - just uploaded the correct weights to "main". You can see that the sha256 between "main": https://huggingface.co/EleutherAI/gpt-j-6B/blob/main/pytorch_model.bin and "float32": https://huggingface.co/EleutherAI/gpt-j-6B/blob/float32/pytorch_model.bin match now :-)
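(If anyone wants to verify this locally rather than on the Hub pages, a rough check could look like the sketch below; note that it downloads both ~23 GB checkpoints unless they are already cached.)

```python
import hashlib
from huggingface_hub import hf_hub_download

def sha256_of(revision):
    # Download (or reuse the cached copy of) pytorch_model.bin for the given branch.
    path = hf_hub_download("EleutherAI/gpt-j-6B", "pytorch_model.bin", revision=revision)
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

assert sha256_of("main") == sha256_of("float32")
```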
What does this PR do?
Introduces the long-awaited GPT-J model class to HuggingFace! Concurrently with this PR being merged, I will make a GPT-J-6B checkpoint public on the EleutherAI HF page for people to use. The model has been evaluated as being within error tolerances of the GPT-J-6B model we released in JAX two months ago.

@patil-suraj was very helpful in assisting me to understand the HF philosophy and how to bring this PR most in line with the rest of the codebase. Beyond that, the major design consideration was to make the configs compatible with GPT-2 rather than GPT-Neo. GPT-Neo has some usability limitations due to its configs having names unrelated to GPT-2's (see #12183 for details). Given those problems and my hope that GPT-Neo will have its configs updated in the future, it seemed like a clear choice to align GPT-J with GPT-2.
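To make the GPT-2 alignment concrete, here is a small sketch of the config interface; the hyperparameter values shown are the GPT-J-6B defaults and are illustrative only.

```python
from transformers import GPTJConfig, GPTJModel

# GPT-2-style names (n_embd, n_layer, n_head, n_positions) plus the rotary dimension.
config = GPTJConfig(n_embd=4096, n_layer=28, n_head=16, n_positions=2048, rotary_dim=64)
model = GPTJModel(config)  # randomly initialized; use from_pretrained for the released weights
```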
Shout-outs to @finetuneanon, whose implementation this one is based on, as well as @kumuruz for assistance with optimizing and debugging.
Supersedes #12243 #13010 #13022
Closes #12098
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. It was discussed in Slack with @patil-suraj.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?