
Good Second Issue: T5 FP16 in Pytorch #9295

Open · patrickvonplaten opened this issue Dec 24, 2020 · 11 comments
Labels: Good Second Issue (issues that are more difficult to do than "Good First" issues - give it a try if you want!)

@patrickvonplaten (Contributor) commented Dec 24, 2020

🚀 Feature request

This "Good second issue" should revisit some of the problems we were having with FP16 for T5ForConditionalGeneration: #4586 and help to make T5 compatible with fp16.

Requirements:

  • use transformers master
  • use the newest PyTorch version
  • have access to a GPU

Context:

To better explain the context, let's define the three different pre-trained T5 model types we have:

  • T5v1 (original T5): => this corresponds to all those checkpoints: t5-small, t5-base, t5-large, t5-3b, t5-11b
  • T5v1_1 (improved T5): => this corresponds to all those checkpoints: google/t5-v1_1-small, google/t5-v1_1-base, google/t5-v1_1-large, google/t5-v1_1-xl, google/t5-v1_1-xxl. T5v1_1 has a slightly different architecture than T5v1. More info on differences can be found here: 🌟 T5 V1.1 #6285
  • MT5 (multi-lingual T5): => this model is identical in architecture to T5v1_1 but has different pre-trained weights and a much larger word embedding matrix.

As shown in issue #4586, training T5v1 in fp16 mode has in the past led to numerical overflow in the forward pass of T5LayerFF (see modeling_t5.py).
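To illustrate the failure mode with made-up numbers (not actual T5 activations; T5LayerFF adds the feed-forward output back onto the residual stream, and it is this kind of addition that overflows in fp16):

import torch

# Hypothetical magnitudes, chosen only to show the mechanism: each value fits into
# fp16 on its own, but the residual add exceeds fp16's maximum (~65504).
hidden_states = torch.full((2, 4), 40000.0, dtype=torch.float16)
ff_output = torch.full((2, 4), 30000.0, dtype=torch.float16)

print(hidden_states + ff_output)                  # all inf: the sum overflows fp16
print(hidden_states.float() + ff_output.float())  # 70000.0 everywhere: fine in fp32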

At the time of issue #4586, T5v1 had a small bug that led to slightly wrong outputs; it was only fixed by PR #8518.

Also, there are now new T5 checkpoints, notably the T5v1_1 and MT5 checkpoints, and it would be very interesting to see whether fp16 works with those as well.

Feature Request

So for this feature request, we should look at two scenarios:

  1. Inference:

For each T5 model type, we should test when the models break during inference. This can be as easy as running the following script for a bunch of different checkpoints and different input_str values:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

checkpoint = "t5-small"  # "google/mt5-small", "google/t5-v1_1-small"

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

input_str = "Hello there. This is the input."  # here it would be better to test much larger inputs

input_ids = tokenizer(input_str, return_tensors="pt").input_ids.to('cuda')

# FP32
output_fp32 = model.generate(input_ids)

# FP16
model.half()
output_fp16 = model.generate(input_ids)

if output_fp32.tolist() == output_fp16.tolist():
    print("SUCCESS: Output is equal!")
else:
    print("Output is different!")
    print("FP32", output_fp32)
    print("FP16", output_fp16)
  2. Training (the more interesting part):

This is probably more important and will require more time and skill. In order to check how T5 does in fp16 training, I'd recommend using the newly added Seq2SeqTrainer.

I would recommend training on a summarization task, such as CNN/DailyMail. One could closely follow this notebook: https://colab.research.google.com/drive/1Ekd5pUeCX7VOrMx94_czTkwNtLN32Uyu?usp=sharing, but replace Bert2Bert with the different T5 models. Ideally, the different fp16 backends should be tested (the fp16_backend training argument, a string that defaults to "auto"), and one should try to see whether hacks such as the one proposed in #4586 (comment) can solve the problem. It would be very interesting to see whether the error happens only for T5v1 or also for T5v1_1 and MT5, and at what point. For each type it would be great to test "small", "base", and if possible even "large". Ideally, one should first create a short summarization fine-tuning script (happy to help here) and then run a bunch of experiments with different fp16 backends and different models; a minimal sketch of such a setup is shown below.
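A minimal sketch of such a training setup (a toy dataset stands in for CNN/DailyMail just to exercise the loop, and Seq2SeqTrainer/Seq2SeqTrainingArguments are assumed to be importable from transformers, which is the case in recent versions):

# Minimal fp16 fine-tuning sketch with Seq2SeqTrainer; watch the loss for nan/inf.
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

checkpoint = "google/t5-v1_1-small"  # also try t5-base, google/mt5-small, ...
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Toy summarization pairs, only to exercise the training loop.
articles = ["summarize: " + "The quick brown fox jumps over the lazy dog. " * 20] * 64
summaries = ["A fox jumps over a dog."] * 64

def encode(article, summary):
    features = tokenizer(article, max_length=512, truncation=True)
    features["labels"] = tokenizer(summary, max_length=56, truncation=True)["input_ids"]
    return dict(features)

train_dataset = [encode(a, s) for a, s in zip(articles, summaries)]

args = Seq2SeqTrainingArguments(
    output_dir="t5-fp16-debug",
    per_device_train_batch_size=4,
    learning_rate=1e-4,
    max_steps=20,
    logging_steps=5,
    fp16=True,  # mixed-precision training; requires a CUDA GPU
    # depending on the transformers version, the backend is chosen via `fp16_backend`
    # (older releases) or `half_precision_backend` (newer), e.g. "amp" or "apex"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()  # nan/inf loss here is the fp16 overflow this issue is about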

Possible Outcome

The results of those experiments should be documented here or even better on https://discuss.huggingface.co/. Ideally, a solution to the problem is found and one could publish a nice blog post explaining how to effectively train T5.

Motivation

T5 is one of the most widely used models in Transformers at the moment, so more results on this issue would be extremely useful for the community. In addition, this issue can be a great opportunity to learn more about the limits of fp16 and why some models still require full fp32 support (at least until bfloat16 is better supported in PyTorch). This is not an easy issue to tackle, but an extremely important one.

Your contribution

I'm happy to help along the way, starting with making a nice T5 summarization training pipeline that lets one easily test different models and fp16 backends.

patrickvonplaten self-assigned this Dec 24, 2020
patrickvonplaten added the Good Second Issue label Dec 24, 2020
patil-suraj self-assigned this Dec 24, 2020
@patil-suraj (Contributor) commented Jan 8, 2021

Here's what I found:

t5-small is the only T5 model that works in fp16 at the moment. The rest of the models produce nan loss/logits.

For all the models and versions (v1, v1.1, mT5), at some point we get inf values in hidden_states after applying the final linear layer (wo) in T5DenseReluDense and T5DenseGatedGeluDense:

# Excerpt from modeling_t5.py: the feed-forward sub-layers where the overflow occurs.
import torch.nn as nn
import torch.nn.functional as F

from transformers.activations import ACT2FN


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)  # inf values appear after this projection
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)  # inf values appear after this projection
        return hidden_states

This results in nan values in T5LayerNorm.

Also, for t5-large, t5-v1_1-base, and t5-v1_1-large, there are inf values in the output of T5LayerSelfAttention and T5LayerCrossAttention, specifically where we add the attention output to the hidden_states:

hidden_states = hidden_states + self.dropout(attention_output[0])

layer_output = hidden_states + self.dropout(attention_output[0])
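One quick way to locate which sub-module first produces inf/nan values is to register forward hooks on every module. A minimal sketch (the helper name is just for illustration):

import torch

def register_overflow_hooks(model):
    """Print the name of any sub-module whose output contains inf or nan."""
    def make_hook(name):
        def hook(module, inputs, output):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for tensor in outputs:
                if torch.is_tensor(tensor) and not torch.isfinite(tensor).all():
                    print(f"non-finite values in output of {name} ({module.__class__.__name__})")
                    break
        return hook
    return [module.register_forward_hook(make_hook(name)) for name, module in model.named_modules()]

# usage: handles = register_overflow_hooks(model); run a forward pass;
# then remove the hooks with: for h in handles: h.remove()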

This happens during both training and inference. To reproduce:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("t5-base").to("cuda:0").eval()
model.half()

tokenizer = T5Tokenizer.from_pretrained("t5-base")

ARTICLE = """summarize: Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin's comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of 'My God' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object.  Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn't been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they'd recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren't revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and ultimately put him in the cockpit. 
Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor's office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot's license. Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot's license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it's only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren't going to keep doing their job and they're upset about that and so they're suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person's problems." Germanwings crash compensation: What we know. Who was the captain of Germanwings Flight 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. 
CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel, and Anna-Maja Rappard contributed to this report."""

inputs = tokenizer(ARTICLE, max_length=512, truncation=True, return_tensors="pt").to("cuda:0")
out = model(**inputs, decoder_input_ids=torch.tensor([[tokenizer.pad_token_id]]).to("cuda:0"))
torch.isnan(out.logits).any()
# => True

Proposed fix

To avoid inf values, we could clamp the hidden_states to the maximum value of the current data type whenever inf values appear, i.e.:

if torch.isinf(hidden_states).any():
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

We need to add this after self-attention, cross-attention, and the feed-forward layer, which is where the inf values occur. This works for both apex and amp; a simplified sketch of the placement is shown below.
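A simplified sketch of where the clamp would go (names follow modeling_t5.py; the real T5Block.forward also handles attention masks, position bias, caching, and cross-attention wiring):

import torch

def clamp_fp16_inf(hidden_states):
    # Only touch fp16 activations; fp32 runs are left unchanged.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states

# inside T5Block.forward (simplified):
#     hidden_states = clamp_fp16_inf(hidden_states)   # after self-attention
#     hidden_states = clamp_fp16_inf(hidden_states)   # after cross-attention (decoder only)
#     hidden_states = clamp_fp16_inf(hidden_states)   # after the feed-forward layer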

To verify this fix, I trained t5-base, t5-v1_1-base, and t5-v1_1-small on cnn/dm for 10k steps (1.11 epochs).
Here's the training command. To run it, clone this fork and check out the fix-t5-fp16 branch, navigate to the examples/seq2seq dir, follow the instructions in the README to download the cnn_dm dataset, and then run the following command:

export M=google/t5-v1_1-base
export OUT_DIR=t5-v1_1-base-cnn-fp16
export DATA_DIR=cnn_dm

python finetune_trainer.py \
    --model_name_or_path $M \
    --data_dir $DATA_DIR \
    --output_dir $OUT_DIR --overwrite_output_dir \
    --max_steps=10000 \
    --gradient_accumulation_steps=8 \
    --learning_rate=1e-4 \
    --per_device_train_batch_size=4 \
    --n_val 500 \
    --max_target_length=56 --val_max_target_length=128 \
    --fp16 --fp16_backend apex \
    --do_train --do_eval --evaluation_strategy steps \
    --logging_steps=100 --logging_first_step --eval_steps=2500 --save_steps=2500 --save_total_limit=2 \
    --sortish_sampler

For evaluation:

python run_eval.py \
    t5-v1_1-base-cnn-fp16  cnn_dm/test.source hypothesis.txt \
    --reference_path cnn_dm/test.target \
    --score_path metrics.json \
    --device cuda:0 \
    --prefix summarize: \
    --bs 16 \
    --fp16

and got the following metrics (ROUGE-2):

  1. for t5-base: 19.2804
  2. for t5-v1.1-base: 18.4316
    (note that the score for t5-base is higher because it was already pre-trained on cnn/dm)

For comparison, I evaluated the pre-trained t5-base in both fp32 and fp16, which gave the following results:

  1. fp16: 18.3681
  2. fp32: 18.394

So the results are close enough.

To verify the fix for t5-large, I evaluated the pre-trained t5-large in fp32 and fp16 (use the same command above to evaluate t5-large) and got the following results

  1. fp16: 19.2734
  2. fp32: 19.2342

Surprisingly, ROUGE-2 is slightly better in fp16.

So with the above fix, the following model types now work in fp16 (apex opt level O1) and give a decent speed-up :)

  • T5v1: t5-small, t5-base, t5-large
  • T5v1_1: google/t5-v1_1-small, google/t5-v1_1-base
  • MT5: google/mt5-small, google/mt5-base

google/t5-v1_1-large and google/mt5-large should also work; I will confirm after running a few experiments.

One interesting observation: for inference, the t5-base fine-tuned with fp16 and evaluated in fp32 is faster than the pre-trained t5-base evaluated in fp16. See this colab.

Update:
google/t5-v1_1-large still gives nan loss after about 200 steps

@patrickvonplaten (Contributor, Author) commented

Great work! We should also share those results on the forum: https://discuss.huggingface.co/ :-)

@patil-suraj (Contributor) commented

Hi @exelents

To answer your question: as mentioned above, these changes will enable fp16 for all small and base versions with apex O1 and native amp.
For large models, I only tested it for inference, and it works. Right now I'm training large models and will report the results here.

DeepSpeed handles its own fp16, and I don't know all the details about it, so I won't be able to help there at the moment. @stas00 might have some ideas, as he's working with DeepSpeed.

To sum up, this fix works with apex O1 and native amp with Seq2SeqTrainer for training, and with .half() for inference.

@stas00 (Contributor) commented Jan 11, 2021

DeepSpeed handles its own fp16 and I don't know all the details about it, so won't be able to help there at the moment. @stas00 might have some ideas as he's working with deepspeed.

I would like the DeepSpeed integration to be merged first; then anybody can start experimenting and see what else might need to be tweaked. To start with, I've been primarily focusing on getting training/eval to just work. The next stage would be actually using it and tuning it up.

@mxa4646 commented Feb 1, 2021

Hi @patil-suraj,
It seems like Hugging Face still hasn't fixed the FP16 problem in MT5-large or MT5-xl. Do you or anyone else have any plans for it?

@patrickvonplaten (Contributor, Author) commented

Hey @mxa4646,

T5 was never made to be fully compatible with fp16; it was trained using bfloat16, which has a different range than PyTorch's fp16. There is a good chance, though, that training T5 with DeepSpeed and fp16 will work!
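To make the range difference concrete, a quick check with torch.finfo:

import torch

print(torch.finfo(torch.bfloat16).max)  # ~3.39e38: bfloat16 keeps fp32's exponent range
print(torch.finfo(torch.float16).max)   # 65504.0: fp16 overflows much earlier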

@dorost1234 commented
Hi,
I am training mt5-small with DeepSpeed and fp16, and I always get nan. So far I have not managed to make it work. Would you mind sharing how you set the parameters to make it work? I am having a hard time with this and would kindly appreciate your help, @patrickvonplaten.

@stas00 (Contributor) commented Mar 20, 2021

T5 was never made to be fully compatible with FP16, it was trained using bfloat16,

Thank you for this insight, @patrickvonplaten - I didn't know that!

I was reading up on bfloat16 for a related issue (#10816), and it looks like the main issue is that whenever one does an aggregate operation on big numbers in bfloat16 or fp16, the accumulation needs to happen in fp32, as in the fix applied in #10815. So perhaps it is possible to identify such operations and change them to some_torch_operator(..., dtype=torch.float32), so that most of the math will still be in fp16 but there will be no overflow. It won't impact the normal fp32 logic either, as the values would already be of the same type, and it doesn't take much extra memory (other than doubling the size of the resulting variable).
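A minimal illustration of that pattern (some_torch_operator above is a placeholder; here Tensor.sum stands in for it, with made-up values):

import torch

x = torch.full((4096,), 100.0, dtype=torch.float16)

print(x.sum())                     # inf: the result does not fit into fp16 (max ~65504)
print(x.sum(dtype=torch.float32))  # tensor(409600.): keeping the accumulation/result in fp32 avoids the overflow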

But here it sounds like the problem is different: bfloat16 values may not convert to the same values in fp16. I wonder if someone has tried to convert the weights and compare the difference.

Perhaps it's enough to take the models and fine-tune them on the same data but in mixed precision; that might rectify the precision.

@Oxi84 commented Nov 21, 2021

I tried T5-large in fp16 and it is slower, which is really strange. With everything else the same, on the same test data I get 5.62 sec with fp32 and 6.95 sec with fp16. However, fp16 uses almost 50% less memory.
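A minimal timing sketch for this kind of fp32 vs. fp16 comparison (assumes a CUDA GPU; batch size, sequence length, and GPU model all influence whether fp16 is actually faster):

import time

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def avg_generate_seconds(model, input_ids, n=10):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n):
        model.generate(input_ids, max_length=64)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n

tokenizer = AutoTokenizer.from_pretrained("t5-large")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-large").to("cuda").eval()
text = "summarize: " + "some long article text " * 100
input_ids = tokenizer(text, truncation=True, max_length=512, return_tensors="pt").input_ids.to("cuda")

with torch.no_grad():
    fp32_seconds = avg_generate_seconds(model, input_ids)
    model.half()
    fp16_seconds = avg_generate_seconds(model, input_ids)

print(f"fp32: {fp32_seconds:.2f} s/generate  fp16: {fp16_seconds:.2f} s/generate")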

@pmollerus23 (Contributor) commented
Has this model been implemented for PyTorch yet?

@samedii commented Apr 30, 2024

This becomes an issue in PixArt Sigma now too.

Edit: It is not an issue for PixArt Sigma; I just made a mistake by not providing masked attention.
