
where can I download the 176B checkpoint in deepspeed format? #319

Open
xuyifan-0731 opened this issue Jul 23, 2022 · 17 comments

@xuyifan-0731

Hello, I used the 176B checkpoint of bloom-176B (https://huggingface.co/bigscience/bloom), but had a problem resolving the layer files. Should I download a different type of checkpoint to use with this repo, or what code should I use to run the evaluation based on the bloom-176B checkpoint?
Thanks a lot.

@xuyifan-0731
Author

@stas00

@stas00
Contributor

stas00 commented Jul 23, 2022

What are you after - the bf16 weights split across TP ranks?

Or the optimizer states? That's 2.3TB of data!

I don't know what "had problem in resolving the layer files" means, or how you were loading the model - it works with just `AutoModel.from_pretrained("bigscience/bloom")`.
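
For reference, a minimal sketch of that loading path (this just follows the `AutoModel.from_pretrained("bigscience/bloom")` call mentioned above; dtype, device placement, and offloading options are omitted, and in practice the 176B model needs a lot of CPU RAM or extra arguments):

```python
# Minimal sketch of loading the HF-format BLOOM checkpoint with transformers.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
model = AutoModel.from_pretrained("bigscience/bloom")
```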

@xuyifan-0731
Author

I used this:

```bash
CHECKPOINT_PATH="/mnt/yrfs/aohan/checkpoints/bloom"
VARIANT="bloom-176B"

TP_SIZE=1
PP_SIZE=1

EVAL_MICRO_BATCH_SIZE=6
SEQ_LEN=2048

MEGATRON_REQUIRED_ARGS="
--num-layers -1
--hidden-size -1
--num-attention-heads -1
--seq-length -1
--max-position-embeddings -1
"

CMD="./tasks/eval_harness/evaluate.py
--load $CHECKPOINT_PATH
--results_path $VARIANT-results.json
--tensor-model-parallel-size $TP_SIZE
--pipeline-model-parallel-size $PP_SIZE
--micro-batch-size $EVAL_MICRO_BATCH_SIZE
--no-load-optim
--no-load-rng
--inference
--deepspeed
--deepspeed_config ds_config.json
--seq-length $SEQ_LEN
--adaptive_seq_len
--eval_fp32
--task_list copa
$MEGATRON_REQUIRED_ARGS
"

N_GPUS=1
LAUNCHER="deepspeed --num_gpus $N_GPUS"
echo $LAUNCHER $CMD

$LAUNCHER $CMD
```

I tried to evaluate the 176B model, but I got this error:
```
dir /mnt/yrfs/aohan/checkpoints/bloom
self.layer_files ['/mnt/yrfs/aohan/checkpoints/bloom/pytorch_model.bin.index.json']
LAYER_FILE_PREFIX layer_
self.original_tp_degree 0
Traceback (most recent call last):
  File "./tasks/eval_harness/evaluate.py", line 461, in <module>
    main()
  File "./tasks/eval_harness/evaluate.py", line 416, in main
    model = load_ds_checkpoint_and_setup_megatron(args)
  File "./tasks/eval_harness/evaluate.py", line 297, in load_ds_checkpoint_and_setup_megatron
    ds_checkpoint = DeepSpeedCheckpoint(args.load,
  File "/root/bloom/Megatron-DeepSpeed-main/tools/convert_checkpoint/deepspeed_checkpoint.py", line 39, in __init__
    self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
ZeroDivisionError: integer division or modulo by zero
```
I think it means I cannot use this checkpoint (from https://huggingface.co/bigscience/bloom), because the layer files do not exist.

@xuyifan-0731
Author

I would like to evaluate bloom-176B on MMLU, Big-Bench, and FewCLUE.

@stas00
Contributor

stas00 commented Jul 23, 2022

OK, so you do want the original weights - got it. We have a script that converts from Meg-DS to HF, but not the other way around.

I will ask if we can release the Meg-DS weights on the hub.

@xuyifan-0731
Author

Thanks a lot. It looks like I can only use the HF checkpoint now.
But I face a problem while using it.
I have a machine with 8×A100, and I would like to use bloom-176B to generate text and run evaluation on different datasets. I use the code from the bigscience-workshop/bigscience repo, bigscience/evaluation/generation/generate.py. But this code only runs the model on a single GPU.
Apparently a single A100 is not enough:
```
RuntimeError: CUDA out of memory. Tried to allocate 1.53 GiB (GPU 0; 79.17 GiB total capacity; 77.14 GiB already allocated; 1.20 GiB free; 77.14 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```

How can I use bloom-176B to generate or evaluate on multiple GPUs? Should I change the code in generate.py, or use other code?
I have tried using accelerate:

```python
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForCausalLM
import torch
from accelerate import Accelerator
def generate_from_text(model, text, tokenizer, max_length=200, greedy=False, top_k=0):
    input_ids = tokenizer.encode(text, return_tensors='pt').to("cuda:0")
    max_length = input_ids.size(-1) + max_length
    
    greedy_output = model.generate(
        input_ids.to('cuda:0'),
        max_length=max_length,
        do_sample=not greedy,
        top_k=None if greedy else top_k,
    )
    return tokenizer.decode(greedy_output[0], skip_special_tokens=True)

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
print("load tokenizer")
accelerator = Accelerator()
device = accelerator.device
model = AutoModelForCausalLM.from_pretrained(
    arg.checkpoints,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    offload_folder="./offload",
)
model = accelerator.prepare(model)
print("load model")


text = 'hello'
output = generate_from_text(model, text, tokenizer, max_length=500, greedy=True, top_k=1)
print(output)
```

and it shows:

```
load tokenizer
Traceback (most recent call last):
  File "/mnt/yrfs/xuyifan/bloom/bigscience-sorteval/eva.py", line 28, in <module>
    model = accelerator.prepare(model)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/accelerate/accelerator.py", line 545, in prepare
    result = tuple(self._prepare_one(obj, first_pass=True) for obj in args)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/accelerate/accelerator.py", line 545, in <genexpr>
    result = tuple(self._prepare_one(obj, first_pass=True) for obj in args)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/accelerate/accelerator.py", line 441, in _prepare_one
    return self.prepare_model(obj)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/accelerate/accelerator.py", line 565, in prepare_model
    model = model.to(self.device)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 927, in to
    return self._apply(convert)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 579, in _apply
    module._apply(fn)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 579, in _apply
    module._apply(fn)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 579, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 602, in _apply
    param_applied = fn(param)
  File "/mnt/yrfs/qinkai/miniconda3/envs/bloom12/lib/python3.9/site-packages/torch/nn/modules/module.py", line 925, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA out of memory. Tried to allocate 1.53 GiB (GPU 0; 79.17 GiB total capacity; 77.14 GiB already allocated; 1.20 GiB free; 77.14 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```

I posted the same question at https://huggingface.co/bigscience/bloom/discussions/62
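
A sketch that may be relevant here, assuming the OOM above comes from `accelerator.prepare(model)` calling `model.to(device)` and trying to place all of the already-dispatched 176B parameters on a single GPU (this reading of the traceback is an assumption, not something confirmed in the thread): rely on `device_map="auto"` alone and skip `accelerator.prepare`:

```python
# Sketch (assumption): keep the device_map="auto" dispatch from from_pretrained
# and do NOT call accelerator.prepare(), which moves the whole model to one GPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom",          # or a local checkpoint path
    device_map="auto",           # shard layers across the 8 GPUs, offload the rest
    torch_dtype=torch.bfloat16,
    offload_folder="./offload",
)

# Inputs go to the first device; generate() handles the cross-device forward pass.
inputs = tokenizer("hello", return_tensors="pt").to("cuda:0")
output = model.generate(inputs.input_ids, max_length=50, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```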

@stas00
Contributor

stas00 commented Jul 24, 2022

Please see: #308

@stas00
Contributor

stas00 commented Jul 25, 2022

The original Meg-DS checkpoint is here: https://huggingface.co/bigscience/bloom-optimizer-states

@mayank31398
Collaborator

mayank31398 commented Aug 7, 2022

@xuyifanbupt if you are trying to deploy BLOOM 176B as a server, you can find it here

@asafkar

asafkar commented Aug 10, 2022

Hi @stas00,
Besides the optimizer states, is there an equivalent checkpoint for inference in the Meg-DS format that I can use with DeepSpeed inference, in order to run inference with pipeline parallelism?
As far as I understand, with the HF weights I would currently have to rewrite the model so it can be pipelined, and possibly write a custom script to load the weights into the altered model.

Is there an alternative method, or does such a checkpoint exist publicly?

Thanks

@mayank31398
Collaborator

@asafkar, the DS-inference script is compatible with HF checkpoints.

@asafkar

asafkar commented Aug 10, 2022

@mayank31398 I was actually referring to the other way around -
i.e. I want to do DS-inference while using pipeline parallelism.

For that to work (if I understand correctly), I would have two options:

  1. Load the HF checkpoint, and then restructure it so the layers are sequential (as done here). I'm not sure it's that easy, since the HF forward function would also have to be altered to fit this.
  2. Load the Meg-DS checkpoint into the BLOOM model instantiated as a GPTModelPipe, and then I can easily use pipeline parallelism (although I'm not sure how this would work with DS-inference).

Not sure which is easier to do, or which one would actually work...

@stas00
Contributor

stas00 commented Aug 10, 2022

> is there an equivalent checkpoint for inference that is in the Meg-DS format

https://huggingface.co/bigscience/bloom-optimizer-states is the full Meg-DS checkpoint.

edit: hmm, I think you're correct it's incomplete. I will push the rest of the files in - will take a while. I will update here when it's done.

it will appear here once uploaded: https://huggingface.co/bigscience/bloom-megatron-deepspeed

edit: uploaded

@stas00
Contributor

stas00 commented Aug 11, 2022

@asafkar, so it looks like I created the new repo for nothing; https://huggingface.co/bigscience/bloom-optimizer-states was already the full checkpoint. Why did you say it only had optimizer state files and not the rest of the files? They should all be there.

The listing is limited at the moment to just 50 entries, so one can't see the remainder of the files on the hub; I made a request to ameliorate that.

Please confirm that https://huggingface.co/bigscience/bloom-optimizer-states has all you need and I will remove the second repo.

@mayank31398
Collaborator

Wait, @asafkar, does DS-inference support pipeline parallelism?
I thought it was only tensor parallelism for generation.

@stas00
Contributor

stas00 commented Aug 11, 2022

  • DS-Inference = TP
  • DS-ZeRO = TP-like
  • Accelerate = PP
  • Megatron-Deepspeed = TP+PP

(plus DP in all)
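
For concreteness, a minimal sketch of the DS-Inference (TP) route from the list above, applied to the HF checkpoint. The `mp_size`/dtype values and the naive per-rank loading are illustrative assumptions; this sketch materializes a full copy of the model on CPU in each process, which is very memory-hungry for 176B, so real deployment scripts handle weight loading more carefully.

```python
# Sketch of DS-Inference (tensor parallelism) over the HF BLOOM checkpoint.
# Launch with: deepspeed --num_gpus 8 this_script.py
import os

import deepspeed
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

local_rank = int(os.getenv("LOCAL_RANK", "0"))

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
# Naive loading for illustration only: each rank builds the model on CPU first.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype=torch.float16)

# init_inference shards the model across mp_size GPUs with tensor parallelism;
# replace_with_kernel_inject swaps in DeepSpeed's fused inference kernels.
engine = deepspeed.init_inference(
    model,
    mp_size=8,                      # illustrative: one shard per GPU on an 8xA100 node
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)
model = engine.module

inputs = tokenizer("hello", return_tensors="pt").to(local_rank)
output = model.generate(inputs.input_ids, max_length=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```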

@asafkar

asafkar commented Aug 11, 2022

@stas00 I haven't checked the https://huggingface.co/bigscience/bloom-optimizer-states repo yet; I was merely asking whether it will support what I'm trying to do - I'm sorry if I wasn't clear enough about it.
I will give it a shot and let you know.

@mayank31398 regarding DS-inference, https://arxiv.org/pdf/2207.00032.pdf clearly states that they support PP, so I thought it should be covered by the actual package as well. I don't recall actually seeing an example of PP with DS-inference, so I hope I'm not mistaken. Perhaps I will try that first with a toy model to make sure it is supported...
