How to use finetuner.py to train t5-large model #17534

Closed
ZeyiLiao opened this issue Jun 3, 2022 · 18 comments

ZeyiLiao commented Jun 3, 2022

System Info

- `transformers` version: 4.3.0.dev0
- Platform: Linux-4.15.0-177-generic-x86_64-with-glibc2.17
- Python version: 3.8.13
- PyTorch version (GPU?): 1.8.1+cu111 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: <Yes>
- Using distributed or parallel set-up in script?: <Yes>

Who can help?

@stas00

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Follow the steps here

git clone https://github.com/huggingface/transformers
cd transformers
git checkout 7e662e6
cd examples/seq2seq
wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz

pip install -r requirements.txt

cd ../..
pip install .

cd examples/seq2seq

pip install fairscale deepspeed==0.3.10

#run script 1
export BS=16; rm -r output_dir; PYTHONPATH=../../src USE_TF=0 python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000 --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500

Error trace1

Traceback (most recent call last):
  File "./finetune_trainer.py", line 367, in <module>
    main()
  File "./finetune_trainer.py", line 152, in main
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
  File "/home/zeyi/transformers/src/transformers/hf_argparser.py", line 52, in __init__
    self._add_dataclass_arguments(dtype)
  File "/home/zeyi/transformers/src/transformers/hf_argparser.py", line 93, in _add_dataclass_arguments
    elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
  File "/home/zeyi/.conda/envs/test/lib/python3.8/typing.py", line 774, in __subclasscheck__
    return issubclass(cls, self.__origin__)
TypeError: issubclass() arg 1 must be a class
Traceback (most recent call last):
  File "./finetune_trainer.py", line 367, in <module>
    main()
  File "./finetune_trainer.py", line 152, in main
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
  File "/home/zeyi/transformers/src/transformers/hf_argparser.py", line 52, in __init__
    self._add_dataclass_arguments(dtype)
  File "/home/zeyi/transformers/src/transformers/hf_argparser.py", line 93, in _add_dataclass_arguments
    elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
  File "/home/zeyi/.conda/envs/test/lib/python3.8/typing.py", line 774, in __subclasscheck__
    return issubclass(cls, self.__origin__)
TypeError: issubclass() arg 1 must be a class
Killing subprocess 69967
Killing subprocess 69968
Traceback (most recent call last):
  File "/home/zeyi/.conda/envs/test/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/zeyi/.conda/envs/test/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/launch.py", line 340, in <module>
    main()
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/launch.py", line 326, in main
    sigkill_handler(signal.SIGTERM, None)  # not coming back
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
    raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/zeyi/.conda/envs/test/bin/python', '-u', './finetune_trainer.py', '--local_rank=1', '--model_name_or_path', 't5-large', '--output_dir', 'output_dir', '--adam_eps', '1e-06', '--data_dir', 'wmt_en_ro', '--do_eval', '--do_train', '--evaluation_strategy=steps', '--freeze_embeds', '--label_smoothing', '0.1', '--learning_rate', '3e-5', '--logging_first_step', '--logging_steps', '1000', '--max_source_length', '128', '--max_target_length', '128', '--num_train_epochs', '1', '--overwrite_output_dir', '--per_device_eval_batch_size', '16', '--per_device_train_batch_size', '16', '--predict_with_generate', '--eval_steps', '25000', '--sortish_sampler', '--task', 'translation_en_to_ro', '--test_max_target_length', '128', '--val_max_target_length', '128', '--warmup_steps', '500', '--n_train', '2000', '--n_val', '500']' returned non-zero exit status 1.

#run script 2
export BS=16; rm -r output_dir; PYTHONPATH=../../src USE_TF=0 python -m torch.distributed.launch --nproc_per_node=2 ./run_seq2seq.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --dataset_name wmt16 --dataset_config "ro-en" --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000 --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500

#Error trace2

Traceback (most recent call last):
  File "./run_seq2seq.py", line 499, in <module>
    main()
  File "./run_seq2seq.py", line 212, in main
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
  File "/home/zeyi/transformers/src/transformers/hf_argparser.py", line 166, in parse_args_into_dataclasses
    raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
ValueError: Some specified arguments are not used by the HfArgumentParser: ['--freeze_embeds', '--test_max_target_length', '128', '--n_train', '2000', '--n_val', '500']
Traceback (most recent call last):
  File "./run_seq2seq.py", line 499, in <module>
    main()
  File "./run_seq2seq.py", line 212, in main
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
  File "/home/zeyi/transformers/src/transformers/hf_argparser.py", line 166, in parse_args_into_dataclasses
    raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
ValueError: Some specified arguments are not used by the HfArgumentParser: ['--freeze_embeds', '--test_max_target_length', '128', '--n_train', '2000', '--n_val', '500']
Killing subprocess 72522
Killing subprocess 72523
Traceback (most recent call last):
  File "/home/zeyi/.conda/envs/test/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/zeyi/.conda/envs/test/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/launch.py", line 340, in <module>
    main()
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/launch.py", line 326, in main
    sigkill_handler(signal.SIGTERM, None)  # not coming back
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
    raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/zeyi/.conda/envs/test/bin/python', '-u', './run_seq2seq.py', '--local_rank=1', '--model_name_or_path', 't5-large', '--output_dir', 'output_dir', '--adam_eps', '1e-06', '--dataset_name', 'wmt16', '--dataset_config', 'ro-en', '--do_eval', '--do_train', '--evaluation_strategy=steps', '--freeze_embeds', '--label_smoothing', '0.1', '--learning_rate', '3e-5', '--logging_first_step', '--logging_steps', '1000', '--max_source_length', '128', '--max_target_length', '128', '--num_train_epochs', '1', '--overwrite_output_dir', '--per_device_eval_batch_size', '16', '--per_device_train_batch_size', '16', '--predict_with_generate', '--eval_steps', '25000', '--sortish_sampler', '--task', 'translation_en_to_ro', '--test_max_target_length', '128', '--val_max_target_length', '128', '--warmup_steps', '500', '--n_train', '2000', '--n_val', '500']' returned non-zero exit status 1.

Expected behavior

I hope that it will run the model with deepspeed or sharded techniques. Actually I want to train the t5-11b model and change the dataset dir to my own dataset, but I cannot even reproduce what @stas00 shared before.

ZeyiLiao added the bug label Jun 3, 2022
stas00 (Contributor) commented Jun 3, 2022

This approach you tried is very old and is not supported any longer.

Please switch to modern tools and it should just work.

Here are a few current examples:

straight DDP:

rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0,1 python \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --adam_eps 1e-06 --do_train --label_smoothing 0.1 \
--learning_rate 3e-5 --logging_first_step --logging_steps 500 \
--max_source_length 128 --max_target_length 128 --val_max_target_length 128 \
--num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size 2 \
--predict_with_generate --sortish_sampler --source_lang en --target_lang ro \
--dataset_name wmt16 --dataset_config ro-en --source_prefix \
'translate English to Romanian: ' --warmup_steps 50 --max_train_samples 50 

same with deepspeed

rm -r output_dir; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus 2 \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --overwrite_output_dir --max_source_length 128 \
--max_target_length 128 --val_max_target_length 128 --do_train \
--num_train_epochs 1 --per_device_train_batch_size 2 --learning_rate 3e-3 \
--dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro \
--source_prefix 'translate English to Romanian: ' --max_train_samples 50 \
--deepspeed tests/deepspeed/ds_config_zero3.json --save_steps 1 

make sure it works, adapt to your data, and then replace with the large model size.

Please let me know if this unblocked you and please share the link where you found the old info so that we could update that thread with the new information.

Thank you

ZeyiLiao (Author) commented Jun 3, 2022

Hi @stas00, the old info comes from here.

I ran the following scripts to install the required packages:

pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

git clone https://github.com/huggingface/transformers
cd transformers
pip install .

pip install fairscale deepspeed

pip install -r examples/pytorch/translation/requirements.txt

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

I tried both the straight DDP and deepspeed scripts, and both give the following error even though I added "--per_device_train_batch_size 2":

run_translation.py: error: argument --per_device_train_batch_size: expected one argument.

What's more, I want to run a language inference task with the t5 model. Do you have any recommendation for which example script I should use?

stas00 (Contributor) commented Jun 4, 2022

run_translation.py: error: argument --per_device_train_batch_size: expected one argument.

oops, my bad - I fixed the examples in my reply #17534 (comment)

What's more, I want to run a language inference task with the t5 model. Do you have any recommendation for which example script I should use?

Same script, you just tell it to eval instead of train. Here are a few ways for one gpu:


# non-distributed 1-gpu fp32 eval only

rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --do_eval --evaluation_strategy=steps --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 500 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 2 --predict_with_generate --eval_steps 2500 --sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" --source_prefix "translate English to Romanian: " --val_max_target_length 128 --warmup_steps 50 --max_eval_samples 50 

# non-distributed 1-gpu --fp16_full_eval eval only

rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 python examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --do_eval --evaluation_strategy=steps --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 500 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size 2 --predict_with_generate --eval_steps 2500 --sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" --source_prefix "translate English to Romanian: " --val_max_target_length 128 --warmup_steps 50 --max_eval_samples 50 --fp16_full_eval

and you can adapt those to multi-gpu and/or deepspeed based on the first examples I shared.

but basically I removed the training args and replaced those with eval-only args.

The 2nd (last) example shows how to do it in half-precision which may not work well (depending on the model), so start with the normal fp32 eval (i.e. w/o --fp16_full_eval)

Of course, play with the values of the args to fit your environment.

I just wonder how to download the dataset used by the script:

You don't download it directly: load_dataset does it automatically for you at runtime (you need Internet access).
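
For reference, this is roughly what run_translation.py does under the hood when you pass --dataset_name wmt16 --dataset_config "ro-en" (a minimal sketch using the datasets library; the split names and fields below are what the wmt16 ro-en config provides, and the data is cached locally after the first download):

from datasets import load_dataset

# downloads and caches the dataset on first use, then reuses the local cache
raw_datasets = load_dataset("wmt16", "ro-en")
print(raw_datasets)               # DatasetDict with "train", "validation" and "test" splits
print(raw_datasets["train"][0])   # {"translation": {"en": "...", "ro": "..."}}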

ZeyiLiao (Author) commented Jun 4, 2022

Thanks for your detailed reply @stas00! I tried the t5-small model and it works, so I changed it to t5-11b and have 3 questions here.

In my case, I could not use straight DDP, otherwise CUDA runs out of memory.

When I use the deepspeed script

export MASTER_PORT=9999; rm -r output_dir; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus 4 \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-11b --output_dir output_dir \
--overwrite_output_dir --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --do_train \
--num_train_epochs 4 --per_device_train_batch_size 8 --learning_rate 1e-4 --source_lang prompt --target_lang completion \
--train_file=/home/zeyi/lr_dataset/data/processed/logic_comp1_nt_v0_infer1.0_balance_seed42_filtered/csv_file/train/train.json \
--test_file=/home/zeyi/lr_dataset/data/processed/logic_comp1_nt_v0_infer1.0_balance_seed42_filtered/csv_file/test/test.json \
--validation_file=/home/zeyi/lr_dataset/data/processed/logic_comp1_nt_v0_infer1.0_balance_seed42_filtered/csv_file/dev/dev.json \
--max_train_samples 50 --deepspeed tests/deepspeed/ds_config_zero3.json --save_strategy epoch

It said that

Traceback (most recent call last):
  File "examples/pytorch/translation/run_translation.py", line 652, in <module>
    main()
  File "examples/pytorch/translation/run_translation.py", line 261, in main
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
  File "/home/zeyi/transformers/src/transformers/hf_argparser.py", line 214, in parse_args_into_dataclasses
    obj = dtype(**inputs)
  File "<string>", line 102, in __init__
  File "/home/zeyi/transformers/src/transformers/training_args.py", line 1012, in __post_init__
    and (self.device.type != "cuda")
  File "/home/zeyi/transformers/src/transformers/utils/import_utils.py", line 802, in wrapper
    return func(*args, **kwargs)
  File "/home/zeyi/transformers/src/transformers/training_args.py", line 1264, in device
    return self._setup_devices
  File "/home/zeyi/transformers/src/transformers/utils/generic.py", line 49, in __get__
    cached = self.fget(obj)
  File "/home/zeyi/transformers/src/transformers/utils/import_utils.py", line 802, in wrapper
    return func(*args, **kwargs)
  File "/home/zeyi/transformers/src/transformers/training_args.py", line 1225, in _setup_devices
    deepspeed.init_distributed()
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/deepspeed/utils/distributed.py", line 51, in init_distributed
    torch.distributed.init_process_group(backend=dist_backend,
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 500, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/home/zeyi/.conda/envs/test/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 190, in _env_rendezvous_handler
    store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
RuntimeError: Address already in use

But I found some info and did add export MASTER_PORT=9999 at the beginning of the script.
I also used netstat -nltp but cannot find which job is the zombie task.
What should I do to kill those zombie processes?

And can I add a parameter here, i.e. --sharded_ddp, to use sharded DDP instead of straight DDP? (I am not sure I totally understand the difference between straight DDP and sharded DDP.)

In my previous code, I passed some generator options to the t5 model:

self.generator_options = {'min_length': 1, 'max_length': 128, 'num_beams': 1, 'num_return_sequences': 1,
                          'do_sample': False, 'top_k': 50, 'top_p': 1.0, 'temperature': 1.0,
                          'length_penalty': 1.0, 'repetition_penalty': 1.0}

output_ids = self.reasoner.generate(batch['all_inps'], **self.generator_options)

So how can I do the same thing here?

stas00 (Contributor) commented Jun 5, 2022

RuntimeError: Address already in use

But I found some info and did add export MASTER_PORT=9999 at the beginning of the script. I also used netstat -nltp but cannot find which job is the zombie task. What should I do to kill those zombie processes?

Normally you just kill them manually. Upgrade your deepspeed, the zombies should get killed automatically.

You should pass an explicit argument to deepspeed with the desired setting if you don't want the default port.

  --master_port MASTER_PORT
                        (optional) Port used by PyTorch distributed for communication during training.
  --master_addr MASTER_ADDR
                        (optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.

And can I add a parameter here, i.e. --sharded_ddp, to use sharded DDP instead of straight DDP? (I am not sure I totally understand the difference between straight DDP and sharded DDP.)

That's another implementation of the ZeRO protocol. You don't need it.

In my previous code, I passed some generator options to the t5 model:

self.generator_options = {'min_length': 1, 'max_length': 128, 'num_beams': 1, 'num_return_sequences': 1,
                          'do_sample': False, 'top_k': 50, 'top_p': 1.0, 'temperature': 1.0,
                          'length_penalty': 1.0, 'repetition_penalty': 1.0}

output_ids = self.reasoner.generate(batch['all_inps'], **self.generator_options)

So how can I do the same thing here?

Please run:

python examples/pytorch/translation/run_translation.py --help 

you will see the existing options there (e.g. --num_beams)

If you want to customize the example script, these generate args are passed here (e.g. num_beams):

num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
if training_args.do_eval:
    logger.info("*** Evaluate ***")
    metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")

all the generate options are here:

@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[Iterable[int]] = None,
    force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    encoder_no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    max_time: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    num_beam_groups: Optional[int] = None,
    diversity_penalty: Optional[float] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
    renormalize_logits: Optional[bool] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
    constraints: Optional[List[Constraint]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    forced_bos_token_id: Optional[int] = None,
    forced_eos_token_id: Optional[int] = None,
    remove_invalid_values: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
    **model_kwargs,
):
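
So in your own code you can keep passing the same dict of options straight into generate(). A minimal sketch (t5-small and the input sentence are just placeholders, swap in your own checkpoint and data):

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# same options as in your previous code
generator_options = {"min_length": 1, "max_length": 128, "num_beams": 1, "num_return_sequences": 1,
                     "do_sample": False, "top_k": 50, "top_p": 1.0, "temperature": 1.0,
                     "length_penalty": 1.0, "repetition_penalty": 1.0}

inputs = tokenizer("translate English to Romanian: The house is small.", return_tensors="pt")
output_ids = model.generate(inputs.input_ids, **generator_options)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))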

ZeyiLiao (Author) commented Jun 9, 2022

Hi @stas00, thank you for your reply! I trained it with your updated scripts but the job was stopped accidentally.

So I tried to resume from the checkpoints with this script (without --overwrite_output_dir; output_dir_1 is the folder with the checkpoints):

deepspeed examples/pytorch/translation/run_translation.py --model_name_or_path t5-11b --output_dir output_dir_1 \
--max_source_length 128 --max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 4 \
--per_device_train_batch_size 16 --learning_rate 1e-4 --source_lang prompt --target_lang completion \
--train_file=/home/zeyi/lr_dataset/data/processed/logic_comp1_nt_v0_infer1.0_balance_seed42_trim_filtered/json_file_t5_11b/train/train.json \
--test_file=/home/zeyi/lr_dataset/data/processed/logic_comp1_nt_v0_infer1.0_balance_seed42_trim_filtered/json_file_t5_11b/test/test.json \
--validation_file=/home/zeyi/lr_dataset/data/processed/logic_comp1_nt_v0_infer1.0_balance_seed42_trim_filtered/json_file_t5_11b/dev/dev.json \
--deepspeed tests/deepspeed/ds_config_zero3.json --save_strategy epoch --evaluation_strategy epoch --load_best_model_at_end

But it said that

Using /home/zeyi/.cache/torch_extensions as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.0007061958312988281 seconds
[INFO|deepspeed.py:444] 2022-06-09 15:28:40,179 >> Attempting to resume from output_dir_1/checkpoint-3126
[2022-06-09 15:31:04,178] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 40204
[2022-06-09 15:31:04,178] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 40205
[2022-06-09 15:31:04,178] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 40206
[2022-06-09 15:31:04,178] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 4020

stas00 (Contributor) commented Jun 9, 2022

this usually means that you didn't have enough cpu memory to resume

Unfortunately it's a bug in deepspeed, where instead of loading the checkpoint directly to gpu it first loads it to cpu.
I filed a bug report here microsoft/DeepSpeed#1971
Please voice your need in this issue so that it's seen that it needs higher priority.

I can offer you a hack that may help. Basically you need to stagger the checkpoint loading so that not all 4 processes try to load it to cpu memory at once.

stas00 (Contributor) commented Jun 9, 2022

something like this should work to stagger the checkpoint loading:

diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
index 9fa22d462..ce2f39cc5 100644
--- a/src/transformers/deepspeed.py
+++ b/src/transformers/deepspeed.py
@@ -447,6 +447,12 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
         deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))

         if len(deepspeed_checkpoint_dirs) > 0:
+
+            # hack to stagger checkpoint loading so that they don't all try to use cpu at the same time
+            rank = trainer.args.local_rank
+            from time import sleep
+            sleep(rank*20)
+
             logger.info(f"Attempting to resume from {resume_from_checkpoint}")
             # this magically updates self.optimizer and self.lr_scheduler
             load_path, _ = deepspeed_engine.load_checkpoint(

Adjust 20 to a shorter or longer wait in secs as needed.

so here the following happens:

process 0 sleeps for 0 secs, process 1 for 20 secs, 2 for 40 secs, etc. so each process gets full use of CPU memory alone.

you can apply the patch manually or with:

git clone https://github.com/huggingface/transformers
cd transformers
git apply patch.txt
pip install -e .

assuming you saved my code as patch.txt (I attached it to this comment as well so you can just download it).

patch.txt

ZeyiLiao (Author) commented:

@stas00, thank you! I have successfully trained the t5-11b.

And now I want to do the inference in my own setup code. Since it's hard to load t5-11b on one GPU, I use model.parallelize to do the inference part.

model = T5ForConditionalGeneration.from_pretrained('./checkpoint')
device_map = {
    0: [0, 1, 2],
    1: [3, 4, 5, 6, 7, 8, 9],
    2: [10, 11, 12, 13, 14, 15, 16],
    3: [17, 18, 19, 20, 21, 22, 23],
}
model.parallelize(device_map)
model.predict()

But the error said:

Traceback (most recent call last):
  File "/home/zeyi/lr_dataset/src/main.py", line 294, in <module>
    trainer.test(model=model_ckpt, test_dataloaders=loader)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 907, in test
    return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 683, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 950, in _test_impl
    results = self._run(model, ckpt_path=self.tested_ckpt_path)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1271, in _dispatch
    self.training_type_plugin.start_evaluating(self)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 178, in start_evaluating
    self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 201, in spawn
    mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 136, in join
    signal_name=name
torch.multiprocessing.spawn.ProcessExitedException: process 2 terminated with signal SIGABRT
wandb: Waiting for W&B process to finish... (failed 1). Press Control-C to abort syncing.
wandb:                                                                                
wandb: Synced lrgenerative_logic_comp1_v7_1.0_new_seed42_trim_filtered_t5_11b_13_06_2022_45964ce7: https://wandb.ai/soumya_research/lr_dataset/runs/snh11aqq
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20220613_132827-snh11aqq/logs
[W CudaIPCTypes.cpp:21] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

I found some solutions that said to set num_workers = 0, but it still doesn't work.

stas00 (Contributor) commented Jun 13, 2022

@stas00, thank you! I have successfully trained the t5-11b.

Super!

And now I want to do the inference in my own setup code. Since it's hard to load t5-11b on one GPU, I use model.parallelize to do the inference part.

parallelize is about to be deprecated and as such is no longer supported. Please use deepspeed instead; it's far superior to the naive parallelization.
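
If you want ZeRO-3 inference outside of the example scripts / HF Trainer, the transformers DeepSpeed docs describe a non-Trainer integration. A rough sketch adapted from those docs (the config values, checkpoint path and prompt are placeholders; with more than one GPU you would also pass synced_gpus=True to generate):

# launch with: deepspeed --num_gpus 1 this_script.py  (the launcher sets up the distributed env)
import torch
import deepspeed
from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers.deepspeed import HfDeepSpeedConfig

ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
    "train_micro_batch_size_per_gpu": 1,
}

# must be created *before* from_pretrained so the weights are loaded sharded via zero.Init
dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

tokenizer = AutoTokenizer.from_pretrained("./checkpoint")
model = T5ForConditionalGeneration.from_pretrained("./checkpoint")

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()

inputs = tokenizer("translate English to Romanian: The house is small.", return_tensors="pt")
with torch.no_grad():
    output_ids = ds_engine.module.generate(inputs.input_ids.to(torch.cuda.current_device()), max_length=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))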

ZeyiLiao (Author) commented:

@stas00, thanks a lot!

In my case, we use pytorch-lightning and what I want to do is
model = T5ForConditionalGeneration.from_pretrained('./checkpoint')
And follow the doc here to set

trainer = Trainer(accelerator="gpu", devices=4, strategy="deepspeed_stage_3_offload")
trainer.predict()

But although I am just doing prediction, why does it still call the def configure_optimizers(self) function?

In addition to that, it gave an error although I do have the ninja package:

[2022-06-13 16:55:48,399] [WARNING] [engine.py:1122:_configure_optimizer] **** You are using ZeRO with an untested optimizer, proceed with caution *****
[2022-06-13 16:55:48,405] [WARNING] [coalesced_collectives.py:26:<module>] unable to find torch.distributed._reduce_scatter_base. will fall back to torch.distributed.reduce_scatter which will result in suboptimal performance. please consider upgrading your pytorch installation.
Using /home/zeyi/.cache/torch_extensions as PyTorch extensions root...
Traceback (most recent call last):
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 683, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 950, in _test_impl
    results = self._run(model, ckpt_path=self.tested_ckpt_path)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1184, in _run
    self._pre_dispatch()
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1219, in _pre_dispatch
    self.accelerator.pre_dispatch(self)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 136, in pre_dispatch
    self.training_type_plugin.pre_dispatch()
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 389, in pre_dispatch
    self.init_deepspeed()
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 461, in init_deepspeed
    self._initialize_deepspeed_inference(model)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 563, in _initialize_deepspeed_inference
    dist_init_required=False,
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/deepspeed/__init__.py", line 130, in initialize
    config_params=config_params)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 294, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1124, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1439, in _configure_zero_optimizer
    communication_data_type=self.communication_data_type)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/deepspeed/runtime/zero/stage3.py", line 292, in __init__
    util_ops = UtilsBuilder().load()
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/deepspeed/ops/op_builder/builder.py", line 463, in load
    return self.jit_load(verbose)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/deepspeed/ops/op_builder/builder.py", line 512, in jit_load
    verbose=verbose)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1091, in load
    keep_intermediates=keep_intermediates)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1302, in _jit_compile
    is_standalone=is_standalone)
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1373, in _write_ninja_file_and_build_library
    verify_ninja_availability()
  File "/home/zeyi/.conda/envs/lr_dataset/lib/python3.7/site-packages/torch/utils/cpp_extension.py", line 1429, in verify_ninja_availability
    raise RuntimeError("Ninja is required to load C++ extensions")
RuntimeError: Ninja is required to load C++ extensions
python-BaseException

I am just wondering whether it is reasonable to work like this:

  1. Train the t5-11b with transformers.Trainer.
  2. Load the checkpoint saved before and use Pytorch-lightning to do the prediction.
  3. Since t5-11b cannot be loaded on one GPU, set the strategy to deepspeed_stage_3_offload for the trainer.

stas00 (Contributor) commented Jun 14, 2022

Wrt the traceback you shared, pip install ninja should do the trick, even though it should have already been installed. Perhaps your $PATH env var is missing the bin dir where pip installs to; check with:

which ninja

It should give you the path to the binary. Don't try to run deepspeed again until the above returns a path; if it returns nothing it means that your python env's bin dir is not in your $PATH env var.
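
If you'd rather check from inside Python (in the exact environment the job runs in), a quick equivalent sanity check:

import shutil

# should print a path such as ~/.conda/envs/test/bin/ninja; None means ninja is not on $PATH
print(shutil.which("ninja"))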

Wrt PL-specific issues, please ask in the PL Issues, as I'm not a PL user.

stas00 (Contributor) commented Jun 14, 2022

There is another workaround that requires no ninja: prebuild deepspeed, see https://huggingface.co/docs/transformers/main/main_classes/deepspeed#installation (local install where you clone deepspeed and then build it).

github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

kmizrak-nci commented:

The git apply patch.txt throws an error of
error: corrupt patch at line 17

Am I missing something in the application of it, or missing an argument?

stas00 (Contributor) commented Aug 3, 2022

bad copy-n-paste? Just insert it manually - it's just a few lines of code and you can tell where to insert by the context around it.

ZeyiLiao (Author) commented:

Hi @stas00, hope all is good! Would deepspeed be compatible with an auto-regressive model here? I need to fine-tune a large OPT model. (BTW: I tried hard with the PL trainer but it always misses some layer weights.) Thanks!

stas00 (Contributor) commented Nov 14, 2022

I haven't tried it, but I don't see any reason why it shouldn't work. OPT has been out for quite a few months now so surely if it didn't work we would have heard by now and fixed it. Give it a try and if you run into problems please start a new Issue. Thank you.
