
legacy finetune with t5 issues #12848

Closed
sacombs opened this issue Jul 22, 2021 · 9 comments

sacombs commented Jul 22, 2021

Hi @stas00

Splitting off from #8771 (comment)

There is a lot of great information in your post; thanks for being thorough!

I guess I don't understand what parameters I need to change within the DeepSpeed config file to properly offload into CPU memory. I have 473 GB of RAM available for offloading, which seems to be enough based on what you listed. I am also using the finetune script in the seq2seq legacy folder. The command is:

export BS=2; rm -rf output_dir; PYTHONPATH=../../src USE_TF=0 deepspeed --num_gpus=8 ./finetune_trainer.py --model_name_or_path "Rostlab/prot_t5_xl_uniref50" --output_dir output_dir --adam_eps 1e-06 --data_dir /mnt/data --do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 512 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000 --sortish_sampler --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 5 --n_train 60 --n_val 10 --n_test 10 --deepspeed ../../../tests/deepspeed/ds_config_zero3.json --fp16

I had to modify the finetune script to use T5Tokenizer, as AutoTokenizer wouldn't work.

For ZeRO-3 optimization, I am using lower values for the stage3_* parameters, since the documentation indicated using lower values to offload memory.

    "zero_optimization": {
        "stage": 3,
        "cpu_offload": true,
        "cpu_offload_params": true,
        "cpu_offload_use_pin_memory" : true,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "stage3_max_live_parameters": 1e3,
        "stage3_max_reuse_distance": 1e3,
        "stage3_prefetch_bucket_size": 2e3,
        "stage3_param_persitance_threshold": 1e3,
        "reduce_bucket_size": 3e3,
        "prefetch_bucket_size": 3e3,
        "sub_group_size": 1e3
    },
@stas00 stas00 self-assigned this Jul 22, 2021
stas00 (Contributor) commented Jul 22, 2021

First, any reason why you're not using the latest scripts? The legacy scripts are no longer being maintained and the up-to-date scripts have had a great many improvements, so if it's not too hard I highly recommend switching to those. Most likely you want
https://github.com/huggingface/transformers/blob/master/examples/pytorch/translation/run_translation.py. Albeit, this is orthogonal to the DeepSpeed issue you wanted to discuss.

For ZeRO-3 optimization, I am using lower values for the stage3_* parameters, since the documentation indicated using lower values to offload memory.

After this discussion is over, let's review where you found this information, because this is incorrect. The doc says which specific parameters you need to tweak, not all of them.

Have you considered using tuned-up-for-you auto values? https://huggingface.co/transformers/master/main_classes/deepspeed.html#zero-3-config

Ah, and you have a typo in at least one of the key names as well - there is no stage3_param_persitance_threshold - DeepSpeed is a bit troublesome as it doesn't validate keys and simply uses the default if you make a typo.

It dumps the final config when the program starts, so you can always review whether your settings "made it".

Your config is also "dated" - recent deepspeed moved to a newer config as you can see in the docs (albeit it's backward compatible).
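Since the keys aren't validated, a quick sanity check along these lines can catch typos like the one above before a run - just an illustrative sketch, where the expected-key list is partial and the config filename is an assumption:

    import json

    # partial list of valid ZeRO-3 keys, for illustration only
    EXPECTED = {
        "stage", "offload_optimizer", "offload_param", "overlap_comm",
        "contiguous_gradients", "sub_group_size", "reduce_bucket_size",
        "stage3_prefetch_bucket_size", "stage3_param_persistence_threshold",
        "stage3_max_live_parameters", "stage3_max_reuse_distance",
        "stage3_gather_fp16_weights_on_model_save",
    }

    with open("ds_config_zero3.json") as f:  # assumed config path
        cfg = json.load(f)

    unknown = set(cfg.get("zero_optimization", {})) - EXPECTED
    if unknown:
        print(f"possible typos in zero_optimization: {sorted(unknown)}")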

stas00 (Contributor) commented Jul 22, 2021

Perhaps you were referring to: "Smaller values use less memory"

stage3_param_persistence_threshold: [integer]

Description: Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages).

https://www.deepspeed.ai/docs/config-json/

sacombs (Author) commented Jul 26, 2021

@stas00,

Thanks for the pointers. I modified my ds_config.json with the following:

{
    "zero_optimization": {
        "stage": 3, 
        "offload_optimizer": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "offload_param": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "overlap_comm": true, 
        "contiguous_gradients": true, 
        "sub_group_size": 1.000000e+09, 
        "reduce_bucket_size": 1.048576e+06, 
        "stage3_prefetch_bucket_size": 9.437184e+05, 
        "stage3_param_persistence_threshold": 1.024000e+04, 
        "stage3_max_live_parameters": 10.0, 
        "stage3_max_reuse_distance": 10.0, 
        "stage3_gather_fp16_weights_on_model_save": true
    }, 
    "train_batch_size": 16, 
    "train_micro_batch_size_per_gpu": 2, 
    "zero_allow_untested_optimizer": true
}

I also switched to run_translation.py in the master branch.

Even with the

     "stage3_max_live_parameters": 10.0, 
     "stage3_max_reuse_distance": 10.0, 

I am unable to use a batch size of 2 per GPU without hitting GPU OOM. Any thoughts on optimizing this? My command line is:

rm -rf output_dir; USE_TF=0 deepspeed --num_gpus=8 ./run_translation.py --model_name_or_path "Rostlab/prot_t5_xl_uniref50" --output_dir output_dir --adam_eps 1e-06 --do_eval --do_predict --do_train --evaluation_strategy=steps --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 512 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --predict_with_generate --eval_steps 25000 --sortish_sampler --warmup_steps 5 --deepspeed deepsped.config --fp16 --train_file train.json --test_file train.json --validation_file train.json --source_lang a --target_lang b --overwrite_output_dir --predict_with_generate --per_device_train_batch_size=2 --per_device_eval_batch_size=2

stas00 (Contributor) commented Jul 27, 2021

I had no problem doing mostly the same with the current version of the examples on just 4x V100-16GB GPUs - I didn't change anything from the default ds config in the repo, and it took only 6GB/gpu for training and ~10GB/gpu for eval.

cd transformers
BS=4; PYTHONPATH=src USE_TF=0 /usr/bin/time -v deepspeed --num_gpus 4 \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-3b --output_dir output_dir \
--overwrite_output_dir --max_train_samples 10 --max_eval_samples 10 --max_source_length 512 \
--max_target_length 128 --val_max_target_length 128 --do_train --do_eval --num_train_epochs 1 \
--per_device_train_batch_size $BS --per_device_eval_batch_size $BS --learning_rate 3e-3 \
--warmup_steps 500 --predict_with_generate --save_steps 0 --eval_steps 1 --group_by_length \
--dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro --source_prefix \
"translate English to Romanian: " --deepspeed tests/deepspeed/ds_config_zero3.json

You can probably do a much larger BS on this one, and with 8 GPUs you definitely shouldn't have any problems.

I highly recommend using the default ds config and not changing anything there unless you really need to.

sacombs (Author) commented Jul 27, 2021

I was able to use your command and train using the ro-en dataset and t5-3b.

However, I am trying to use a custom model: "Rostlab/prot_t5_xl_uniref50". This is based on t5-3b, but without the denoising objective in t5. I looked at the model card and it also does not have the task-specific parameters in its config.json for translation/summarization. I think this means that I might need to change the Trainer, but I am not sure what is specifically needed.
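For reference, a quick way to compare what task presets each checkpoint ships with would be something like this (illustrative only, not part of my training code):

    from transformers import AutoConfig

    # compare the task presets stored in each checkpoint's config.json
    for name in ["t5-3b", "Rostlab/prot_t5_xl_uniref50"]:
        cfg = AutoConfig.from_pretrained(name)
        print(name, getattr(cfg, "task_specific_params", None))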

Before I started down the deepspeed path, I was using a training loop that I had created with model parallelization. The train step is below:

import torch
from transformers import T5ForConditionalGeneration

# model_name, train_dataloader and optimizer are defined elsewhere in my script
model = T5ForConditionalGeneration.from_pretrained(model_name)
# model = model.to(device)
device_map = {0: [0],
             1: [1, 2, 3 ],
             2: [4, 5, 6 ],
             3: [7, 8, 9, 10 ],
             4: [11, 12, 13, 14],
             5: [15, 16, 17],
             6: [18, 19, 20],
             7: [21, 22, 23]
             }

model.parallelize(device_map)


def run_a_train_epoch():
    print ("Training...")
    all_losses = []
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        if batch_idx > 0 and batch_idx % 20 == 0:
            print(f"Trained {batch_idx} batches...")
        #print ("Batch: ", batch_idx)
        #print (_, data)
                
        ids = batch['source_ids'].to('cuda:0', dtype = torch.long)
        mask = batch['source_mask'].to('cuda:0', dtype = torch.long)
        y = batch['target_ids'].to('cuda:0', dtype = torch.long)
        
        y_ids = y[:, :-1].contiguous()
        decoder_attention_mask = batch['target_mask'].to('cuda:0', dtype = torch.long)
        
        y_mask = decoder_attention_mask[:, :-1].contiguous()
        
        outputs = model(input_ids = ids, attention_mask = mask, labels=y_ids, decoder_attention_mask=y_mask)
        
        loss = outputs[0]
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        all_losses.append(loss)

    train_loss = sum(all_losses) / len(all_losses)
    return train_loss
 

Doing this, I was only able to train with a batch size of 2. Is it possible to use the Trainer with this model, or do you have any pointers on transferring this to DeepSpeed?

stas00 (Contributor) commented Jul 27, 2021

You don't need to transfer anything to DeepSpeed; DeepSpeed ZeRO simply provides a much simpler way of doing model parallelism w/o needing to change the model. That is, whatever model you use, it'll just work. DeepSpeed magically parallelizes whatever you throw at it (well, most of the time).

So your goal is to use a t5-3b model with a slightly different task. I don't see any reason why it won't just work out of the box.

I used run_translation.py as an example to test that everything works and scales. You can adapt it to your needs. run_translation.py is the same as the old legacy finetune_trainer.py, except it was massively cleaned up, improved and then split off to do just one task - translation. E.g., examples/pytorch/summarization is another split-off from finetune_trainer.py.

Perhaps you can follow this plan:

  1. study the existing example scripts and find the one that is the closest to your needs
  2. adapt it to your exact needs by porting over whatever extra code you wrote in your finetune_trainer.py
  3. test that it works with just python, perhaps on a small model (a rough sketch follows below)
  4. add deepspeed using the default settings of tests/deepspeed/ds_config_zero3.json to scale it up, this time on the full model.
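For step 3, a sanity check as small as this would do (just a sketch, using t5-small as a stand-in before you swap in your own checkpoint and data):

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    # tiny stand-in model so the test runs quickly on a single GPU or even CPU
    tok = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    batch = tok(["translate English to Romanian: The house is small."], return_tensors="pt")
    labels = tok(["Casa este mica."], return_tensors="pt").input_ids

    # one forward pass with labels - if this works, scaling up is a config issue, not a model issue
    loss = model(**batch, labels=labels).loss
    print(loss.item())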

sacombs (Author) commented Jul 27, 2021

I am not sure what is going on... I stepped through the code and made sure that I was not missing anything by printing out the tokens/masks and several other values. The only thing I can get to work with this model, dataset, and run_translation.py is a per-device batch size of 1. I am using tests/deepspeed/ds_config_zero3.json with the run_translation.py script. I have been able to use the original t5-3b model with the ro-en translation dataset and your configuration file with a per-device batch size of 8 just fine.

Not sure where to go from here.

Thanks!

stas00 (Contributor) commented Jul 28, 2021

a model is a model is a model is a model - it doesn't matter which t5-3b derivative you use - it will take the exact same amount of memory. What matters is your code - it's possible that you do something that leaks memory or allocates more than the example program does.

The next step is either to try to compare how your program is different, or to use a memory profiler and see where the bulk of memory is allocated. You can start with just enabling --skip_memory_metrics 0 (unskip, that is) with the current examples and it'll report the memory allocations on the first GPU, or you can use various other pytorch profilers.
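If you want to poke at it manually, a bare-bones check with pytorch's own counters looks roughly like this (a sketch only - wrap it around whichever step you suspect):

    import torch

    torch.cuda.reset_peak_memory_stats(0)

    # ... run the suspect step here, e.g. one forward/backward pass ...

    print(f"current: {torch.cuda.memory_allocated(0) / 2**30:.2f} GiB")
    print(f"peak:    {torch.cuda.max_memory_allocated(0) / 2**30:.2f} GiB")
    # torch.cuda.memory_summary(0) prints a much more detailed breakdown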

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
