[Question] How to preshard a model for tensor parallelism #2379
Hi @lanking520, for loading this presharded version:

import os
import torch
import deepspeed
from huggingface_hub import snapshot_download

model = "microsoft/bloom-deepspeed-inference-fp16"
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

repo_root = snapshot_download(model)
checkpoints_json = os.path.join(repo_root, "ds_inference_config.json")

model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    base_dir=repo_root,
    dtype=torch.half,
    checkpoint=checkpoints_json,
    replace_method="auto",
    replace_with_kernel_inject=True,
)

@RezaYazdaniAminabadi would be able to give more details about creating presharded versions for other models. |
@mrwyattii thanks for the reply, I know how to load the model. But I'm wondering more about how we can pre-shard the model. This also applies to other models like OPT. Loading them on CPU and sharding them takes a very long time. Maybe it could be done as a one-off: save the sharded model to disk and skip loading it on CPU the next time. I would appreciate any instructions you can share on saving the sharded model with DeepSpeed. |
Here is how I did it (if I correctly understand the issue): save the model in, say, 5-GB shards (this step will require 2x model-size CPU memory and then a bit more) and then use the resulting model like so:
|
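A minimal sketch of the size-based re-sharding presumably being described here, using the standard Hugging Face save_pretrained API (model name and output path are placeholders):

import torch
from transformers import AutoModelForCausalLM

# Load the full model on CPU once (per the comment above, this needs roughly
# 2x the model size in host RAM), then re-save it as ~5 GB shards so later
# loads can skip the monolithic checkpoint file.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-7b1", torch_dtype=torch.float16
)
model.save_pretrained("/data/bloom-5gb-shards", max_shard_size="5GB")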
Nope, this only shards the model based on its size. It doesn't determine which part of the model goes to which GPU in DeepSpeed. Tensor parallelism means vertical sharding, where the shards are defined by the TP degree as well as by size. In fact, if you look at the INT8 BLOOM model, each GPU gets 4 vertical shards (TP4), distributed across 8 GPUs. This cannot be done without using DeepSpeed itself. |
@RezaYazdaniAminabadi I found your PR here really helpful: #2132 |
Hi @lanking520, Thanks for your interest in this part. |
@lanking520 can you try this again? #2547 should address your issue |
nice, will test them today |
@jeffra @RezaYazdaniAminabadi do you happen to have a code sample for the OPT model? |
@lanking520 Here is a small code sample for saving a sharded OPT model:

import os
import torch
import transformers
import deepspeed
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_ckpt", action="store_true")
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--world_size", type=int, default=int(os.getenv("WORLD_SIZE", 1)))
args = parser.parse_args()

model_name = "facebook/opt-1.3b"
inputs = ["DeepSpeed is the"]
ckpt_path = "/data/sharded-opt-model/"

inf_config = {
    "replace_with_kernel_inject": True,
    "dtype": torch.float16,
    "replace_method": "auto",
    "enable_cuda_graph": False,
    "tensor_parallel": {"tp_size": args.world_size},
}

config = transformers.AutoConfig.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

if args.save_ckpt:
    inf_config["save_mp_checkpoint_path"] = ckpt_path
    model = transformers.AutoModelForCausalLM.from_config(
        config, torch_dtype=torch.float16
    )
else:
    inf_config["checkpoint"] = os.path.join(ckpt_path, "ds_inference_config.json")
    with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
        model = transformers.AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16
        )

model = deepspeed.init_inference(model, config=inf_config)

if not args.save_ckpt:
    tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in tokens:
        if torch.is_tensor(tokens[t]):
            tokens[t] = tokens[t].to(f"cuda:{args.local_rank}")
    greedy_output = model.generate(**tokens)
    outputs = tokenizer.batch_decode(greedy_output, skip_special_tokens=True)
    if args.local_rank == 0:
        print(outputs)

To save the checkpoint, run the script with the --save_ckpt flag. Verify the sharded checkpoints were created:

venv ❯ ls /data/sharded-opt-model
ds_inference_config.json tp_00_00.pt tp_00_02.pt tp_00_04.pt tp_00_06.pt tp_01_00.pt tp_01_02.pt tp_01_04.pt tp_01_06.pt
non-tp.pt tp_00_01.pt tp_00_03.pt tp_00_05.pt tp_00_07.pt tp_01_01.pt tp_01_03.pt tp_01_05.pt tp_01_07.pt

Then load the sharded checkpoint and run a query by running the script again without --save_ckpt. |
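For reference, a presumable way to launch the script above (assuming it is saved as example.py; two GPUs match the tp_00_*/tp_01_* shards in the listing):

# Save the TP-sharded checkpoint (loads the real weights once, then writes shards)
$ deepspeed --num_gpus 2 example.py --save_ckpt

# Later runs: load the pre-sharded checkpoint and run the query
$ deepspeed --num_gpus 2 example.py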
@mrwyattii thanks for sharing. I am able to get this file, but it is not loadable at all when I try to get it back. I tried the BLOOM way but it doesn't work that way. |
This is the standard way of loading back a BLOOM model. It is not working for OPT. |
@lanking520 I've updated the code in my previous comment to include loading the model from the saved checkpoint. |
Also check out the example we have in the DeepSpeedExamples repo: #2547 has some information about how to run that script. |
When I used example.py, the model was saved fine, but there was a problem loading it back from ckpt_path.
(+)
With the input ["Deepspeed is the"], I get bad output. For the model to load normally, it must be given |
@slrsnpdla I'm able to reproduce the bad output you are seeing. This appears to only happen for some models. I've extended the unit tests for sharded checkpoints to include a correctness test in #2643. I'm seeing failures for gpt-neo and gpt-j models as well. |
Also reproducible on my end: I tested OPT, GPT-Neo, and GPT-J, and all of them are broken in this way. |
Hi @lanking520, I am working on resolving this issue. I will let you know once I have the solution tested completely. |
Hi @RezaYazdaniAminabadi or @jeffra do we have any weekly/monthly community meeting for DeepSpeed? I would like to attend if there is one. |
Hi @lanking520, I have verified several model architectures with this PR and this test suite. All work fine on my side. Could you please try this on your end and see if the issue is resolved?
Regarding your last question, I don't think there is any meeting currently set up, but I think this is a great idea. I'll let @jeffra or @tjruwase chime in here, and we might be able to set something up. Thanks, |
Will start testing this week. Thanks @RezaYazdaniAminabadi. |
Hi @mrwyattii, thank you for this example! I have two questions regarding your code.
My first question is: is there a way to load the model into RAM only once, instead of 4 times, to save RAM and still save pre-sharded checkpoints? My second question (might be a stupid one) is: when you run
|
When we took your PR and tested it with your test suite, we got the following error for both OPT 1.3B and GPT-J 6B:
|
@RezaYazdaniAminabadi @lekurile have you had any luck avoiding the above error? ^^ |
Hi @lanking520, I think the error that you're seeing may come from the flag |
@Wenhan-Tan this is why the meta tensor is here:
It is a must-have step in order to use DeepSpeed checkpoint loading. You need a placeholder in place of the full model. I think this still holds true with @RezaYazdaniAminabadi's commit. We need to pass the model body in so DeepSpeed can equip it with the full weights. For some reason, the checkpoint weights were not picked up by DeepSpeed. |
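The step being referred to is presumably the meta-device placeholder from the OPT example earlier in the thread; a self-contained restatement (model name and checkpoint path reuse the values from that example):

import os
import torch
import deepspeed
import transformers

model_name = "facebook/opt-1.3b"          # same model as the earlier example
ckpt_path = "/data/sharded-opt-model/"    # directory holding the saved TP shards
config = transformers.AutoConfig.from_pretrained(model_name)

inf_config = {
    "replace_with_kernel_inject": True,
    "dtype": torch.float16,
    "tensor_parallel": {"tp_size": int(os.getenv("WORLD_SIZE", "1"))},
    # Point DeepSpeed at the pre-sharded checkpoint description file:
    "checkpoint": os.path.join(ckpt_path, "ds_inference_config.json"),
}

# Build an empty placeholder on the "meta" device (no real weights allocated);
# init_inference then loads each TP shard directly onto its GPU.
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
model = deepspeed.init_inference(model, config=inf_config)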
@lanking520 You're right! I also ran the script, and the checkpoint weight loading worked on my machine. It makes me wonder whether the weight saving was successful on @sindhuvahinis's machine. |
@Wenhan-Tan @RezaYazdaniAminabadi I verified it again on a larger instance with 8 GPUs. Did you test the GPT-J 6B model? I was able to generate checkpoints, but loading back the generated checkpoints throws the following error.
To reproduce, I am using the DeepSpeed test suite:
|
Hi @sindhuvahinis , I didn't try GPTJ but it did work on GPTNeox for me |
@mrwyattii @slrsnpdla @lanking520, regarding the bad outputs: in the provided script, at the presharding step, there could perhaps be an issue with the use of from_config, which builds a freshly initialized (random-weight) model, so the shards that get saved never contain the pretrained weights. Please see if the following example, which uses from_pretrained at that step instead, makes sense.

$ diff -u example-original.py example-modified.py
--- example-original.py
+++ example-modified.py
@@ -26,8 +26,8 @@
 if args.save_ckpt:
     inf_config["save_mp_checkpoint_path"] = ckpt_path
-    model = transformers.AutoModelForCausalLM.from_config(
-        config, torch_dtype=torch.float16
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_name, torch_dtype=torch.float16
     )
 else:
     inf_config["checkpoint"] = os.path.join(ckpt_path, "ds_inference_config.json")

Using the original script to do the sharding, and then verifying that the outputs are not correct:

$ deepspeed --num_gpus 1 example-original.py --save_ckpt
(..)
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.0009701251983642578 seconds
Saving tp-sharded checkpoints
[] [INFO] [launch.py:350:main] Process 1195148 exits successfully.
$ ls /data/sharded-opt-model/
ds_inference_config.json non-tp.pt tp_00_00.pt tp_00_01.pt tp_00_02.pt tp_00_03.pt tp_00_04.pt tp_00_05.pt tp_00_06.pt tp_00_07.pt
$ deepspeed --num_gpus 1 example-original.py
(..)
Requested memory: 0.375000 (GigaBytes)
Setting maximum total tokens (input + output) to 1024
------------------------------------------------------
['DeepSpeed is the grant grant ► grantrecentElsaElsa grantrecent Observer simplify IndigoElsaElsa Indigo']
[] [INFO] [launch.py:350:main] Process 1200105 exits successfully.

On my end the modified script produces the expected, reproducible results:

$ rm -f /data/sharded-opt-model/*
$ deepspeed --num_gpus 1 example-modified.py --save_ckpt
(..)
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.0009124279022216797 seconds
Saving tp-sharded checkpoints
[] [INFO] [launch.py:350:main] Process 1209907 exits successfully.
$ ls /data/sharded-opt-model/
ds_inference_config.json non-tp.pt tp_00_00.pt tp_00_01.pt tp_00_02.pt tp_00_03.pt tp_00_04.pt tp_00_05.pt tp_00_06.pt tp_00_07.pt
$ deepspeed --num_gpus 1 example-modified.py
(..)
Requested memory: 0.375000 (GigaBytes)
Setting maximum total tokens (input + output) to 1024
------------------------------------------------------
["DeepSpeed is the best.\nI've been using DeepSpeed for a while now. It"]
[] [INFO] [launch.py:350:main] Process 1211738 exits successfully.

DeepSpeed version: 0.8.0 |
I am trying to run this with GPT-NeoX 20B. I managed to save a sharded model with the example above on a single A6000 GPU, and inference looks correct. However, when I try to save the sharded model using 4 A4000s, the script takes a long time (it also needs a very large amount of CPU RAM, as discussed above). The script has been running for more than 2 hours and the GPUs are at 100% usage. Has anyone had the same experience with this? |
Hi @simoroma, for GPT-NeoX 20B it took me a little less than 2 hours to save the sharded model. Did you leave the script running and save the sharded model successfully? If not, it could be that you don't have enough RAM or GPU memory. |
Thanks @Wenhan-Tan, it ran for 4 hours and I stopped it. I have 4 A4000s and 200 GB of RAM. The code first correctly uses about 165 GB of RAM, then starts to use the GPUs at 100% with about 10 GB of VRAM each, but it never gets to saving the sharded model. If I run the same code on a single A6000 it works correctly, and I can also use it for inference on a single GPU. If I take the sharded model saved with a single GPU and move it to the pod with 4 GPUs, the model gets correctly split across the 4 GPUs at inference time, about 10 GB of VRAM each, but I never get inference results: the code gets stuck and the GPUs are at 100% utilization. |
Hi @simoroma , I never tried loading 1-GPU sharded model on 2 GPUs using DeepSpeed. I ran the same script on 2 A100-40GB GPUs and both saving and inference work for GPTNeox 20B. If your 4 A4000s have more GPU memory than your single A6000, then this is probably a bug. |
The sharded model produced on a single A6000 could be loaded correctly on 4 RTX 3090 GPUs, though the results were gibberish. I don't know why, but I had no issues saving and then loading the model with 4 RTX 3090s, while the script was getting stuck with 4 A4000s. |
Hi, thanks for the above suggestions. I managed to pre-shard the OPT model. Now I want to pre-shard a T5 model (like T0 or Flan-T5) but I failed. Here is my code:
And it reports these errors:
Could anyone help me solve this problem? |
I ran into the same problem, but I have no idea how to solve it. |
Hi, I followed this tutorial and succeeded. The code is:
|
the same |
Have you solved this problem? When I use DeepSpeed inference and run
@kevinuserdd Hi, what you're seeing is exactly how it's supposed to be. I looked all over the internet and this appears to be an unsolved problem: you have to have X times the model size in RAM in order to parallelize your model across X GPUs. If you have any ideas or find any solutions, please let me know. This would be incredibly helpful for people who do not have unlimited RAM. |
Hi all, as of DeepSpeed 0.8.3, the following models can be loaded from a Hugging Face checkpoint with low CPU memory:
We have a CI run nightly to verify this. The following models support native sharding in DeepSpeed:
We have tested all four of these and they work. If you need help, please feel free to send us an email and we can work with you on a sharding solution. |
But I have four A100 GPUs, and the GPU memory is sufficient. The BLOOM model I tested has the same parameter count as ChatGLM. When I use BLOOM for inference, the memory of the multiple GPUs is shared, but with ChatGLM it isn't. I suspect it's related to the model itself, and I can't parallelize the model. |
@kevinuserdd Hi, sorry, I read your response wrong. I only see higher RAM usage while parallelizing; the extra RAM is released once my model is successfully parallelized onto the GPUs, and my total GPU memory usage doesn't increase the way RAM does. Maybe ChatGLM is not yet fully supported for parallelization. |
@lanking520 Is Llama 7B/65B supported? |
Yes |
Hi @mrwyattii, would it be possible to use this script for AutoTP rather than kernel injection?
|
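For what it's worth, AutoTP in DeepSpeed is normally selected by leaving kernel injection off while still requesting a tensor-parallel size, so a hypothetical AutoTP variant of the earlier inference config might look like the sketch below. Whether save_mp_checkpoint_path produces loadable shards in AutoTP mode is not confirmed in this thread.

import os
import torch
import deepspeed
import transformers

# Hypothetical AutoTP variant of the earlier inf_config: kernel injection disabled,
# tensor parallelism still requested, so DeepSpeed auto-shards supported layers.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b", torch_dtype=torch.float16
)
inf_config = {
    "replace_with_kernel_inject": False,
    "dtype": torch.float16,
    "tensor_parallel": {"tp_size": int(os.getenv("WORLD_SIZE", "1"))},
}
model = deepspeed.init_inference(model, config=inf_config)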
Currently we are trying to run inference with the pretrained BLOOM model. However, loading takes very long due to DeepSpeed sharding at runtime. Since there is a pre-sharded version of BLOOM:
would it be possible to share the script that produced it, or any guidance on how to preshard a model to speed up the loading experience?