[debug utils] activation/weights underflow/overflow detector #11274

Merged · 26 commits · Apr 30, 2021

Changes from 8 commits
123 changes: 123 additions & 0 deletions docs/source/debugging.md
@@ -0,0 +1,123 @@
<!---
Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Debugging


## Activations Overflow

If you start getting `loss=NaN` or the model exhibits some other abnormal behavior due to `inf`s or `nan`s, you need to discover where the first overflow happens and what led to it. Luckily you can accomplish that easily by activating a special module that will do the detection automatically.

If you're using the HuggingFace `Trainer`, you just need to add:

```bash
--debug activation_overflow
```
to the normal command line arguments, or pass `debug="activation_overflow"` in the `TrainingArguments` when creating the `Trainer` object, as in the sketch below.
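
A minimal sketch: everything except `debug="activation_overflow"` is ordinary `Trainer`/`TrainingArguments` usage, and `model`/`train_dataset` are assumed to already exist:

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    fp16=True,  # mixed precision is where overflows typically show up
    debug="activation_overflow",  # turns on DebugActivationOverflow for this run
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```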

If you're using your own trainer you can just do:

```python
from transformers.debug_utils import DebugActivationOverflow
debug_overflow = DebugActivationOverflow(model)
```
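
In a hand-written training loop this amounts to creating the detector once, before iterating over batches. A minimal sketch, where `model`, `dataloader` and `optimizer` are assumed to already exist:

```python
from transformers.debug_utils import DebugActivationOverflow

debug_overflow = DebugActivationOverflow(model)  # registers a forward hook on every submodule

for batch in dataloader:
    loss = model(**batch).loss  # a ValueError is raised on the first inf/nan activation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```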

`DebugActivationOverflow` inserts hooks into the model that test each input and output; as soon as an `inf` or `nan` is detected in at least one element, the program aborts and prints a report like this:

```
< [0] encoder.block.2.layer.1.DenseReluDense.wo: Linear: output has infs


last 40 frames:
abs_max= 5.96e+02 < [0] encoder.block.1.layer.1.DenseReluDense.dropout: Dropout: output
abs_max= 5.96e+02 > [0] encoder.block.1.layer.1.DenseReluDense.wo: Linear: input[0]
abs_max= 3.17e+03 < [0] encoder.block.1.layer.1.DenseReluDense.wo: Linear: output
abs_max= 2.57e+00 > [0] encoder.block.1.layer.1.DenseReluDense: T5DenseGatedGeluDense: input[0]
abs_max= 3.17e+03 < [0] encoder.block.1.layer.1.DenseReluDense: T5DenseGatedGeluDense: output
abs_max= 3.17e+03 > [0] encoder.block.1.layer.1.dropout: Dropout: input[0]
abs_max= 3.52e+03 < [0] encoder.block.1.layer.1.dropout: Dropout: output
abs_max= 1.58e+03 > [0] encoder.block.1.layer.1: T5LayerFF: input[0]
abs_max= 4.04e+03 < [0] encoder.block.1.layer.1: T5LayerFF: output
abs_max= 1.51e+03 > [0] encoder.block.1: T5Block: input[0]
abs_max= 4.04e+03 < [0] encoder.block.1: T5Block: output[0]
abs_max= 1.00e+04 < [0] encoder.block.1: T5Block: output[2]
abs_max= 4.04e+03 > [0] encoder.block.2.layer.0.layer_norm: T5LayerNorm: input[0]
abs_max= 2.69e+00 < [0] encoder.block.2.layer.0.layer_norm: T5LayerNorm: output
abs_max= 2.69e+00 > [0] encoder.block.2.layer.0.SelfAttention.q: Linear: input[0]
abs_max= 1.13e+00 < [0] encoder.block.2.layer.0.SelfAttention.q: Linear: output
abs_max= 2.69e+00 > [0] encoder.block.2.layer.0.SelfAttention.k: Linear: input[0]
abs_max= 1.69e+01 < [0] encoder.block.2.layer.0.SelfAttention.k: Linear: output
abs_max= 2.69e+00 > [0] encoder.block.2.layer.0.SelfAttention.v: Linear: input[0]
abs_max= 8.92e+00 < [0] encoder.block.2.layer.0.SelfAttention.v: Linear: output
abs_max= 7.59e+00 > [0] encoder.block.2.layer.0.SelfAttention.o: Linear: input[0]
abs_max= 2.83e+02 < [0] encoder.block.2.layer.0.SelfAttention.o: Linear: output
abs_max= 2.69e+00 > [0] encoder.block.2.layer.0.SelfAttention: T5Attention: input[0]
abs_max= 2.83e+02 < [0] encoder.block.2.layer.0.SelfAttention: T5Attention: output[0]
abs_max= 1.00e+04 < [0] encoder.block.2.layer.0.SelfAttention: T5Attention: output[2]
abs_max= 2.83e+02 > [0] encoder.block.2.layer.0.dropout: Dropout: input[0]
abs_max= 3.14e+02 < [0] encoder.block.2.layer.0.dropout: Dropout: output
abs_max= 4.04e+03 > [0] encoder.block.2.layer.0: T5LayerSelfAttention: input[0]
abs_max= 4.06e+03 < [0] encoder.block.2.layer.0: T5LayerSelfAttention: output[0]
abs_max= 1.00e+04 < [0] encoder.block.2.layer.0: T5LayerSelfAttention: output[2]
abs_max= 4.06e+03 > [0] encoder.block.2.layer.1.layer_norm: T5LayerNorm: input[0]
abs_max= 6.00e+00 < [0] encoder.block.2.layer.1.layer_norm: T5LayerNorm: output
abs_max= 6.00e+00 > [0] encoder.block.2.layer.1.DenseReluDense.wi_0: Linear: input[0]
abs_max= 5.18e+01 < [0] encoder.block.2.layer.1.DenseReluDense.wi_0: Linear: output
abs_max= 6.00e+00 > [0] encoder.block.2.layer.1.DenseReluDense.wi_1: Linear: input[0]
abs_max= 3.14e+02 < [0] encoder.block.2.layer.1.DenseReluDense.wi_1: Linear: output
abs_max= 1.62e+04 > [0] encoder.block.2.layer.1.DenseReluDense.dropout: Dropout: input[0]
abs_max= 1.80e+04 < [0] encoder.block.2.layer.1.DenseReluDense.dropout: Dropout: output
abs_max= 1.80e+04 > [0] encoder.block.2.layer.1.DenseReluDense.wo: Linear: input[0]
abs_max= inf < [0] encoder.block.2.layer.1.DenseReluDense.wo: Linear: output
```

The left column shows the value of the absolute largest element, so if you take a closer look at the last few frames, the inputs and outputs were in the range of `1e4`. So when this training was done under `fp16` mixed precision, the very last step overflowed (since under `fp16` the largest number before `inf` is `64e3`, or 65504 to be exact). To avoid overflows under `fp16` the activations must remain well below `1e4`, because `1e4 * 1e4 = 1e8`, so any matrix multiplication with large activations is going to lead to a numerical overflow.
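
The `fp16` ceiling is easy to verify directly (an illustrative check, not part of this PR):

```python
import torch

print(torch.finfo(torch.float16).max)  # 65504.0, i.e. roughly 6.4e4
x = torch.tensor([1e4], dtype=torch.float16)
print(x * x)  # tensor([inf], dtype=torch.float16) -- 1e8 does not fit into fp16
```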

The trace then prints the batch number (here `[0]` means the problem occurred on the first batch).

Then comes the fully qualified module name, e.g. `encoder.block.2.layer.0.layer_norm`, so you can easily see where the problem happens and what was happening just before it.

The second to last entry shows the name of the class the `forward` belongs to, and the last entry indicates whether the report is for an input or an output, along with its index if either is a tuple. Only tensor variables are reported.

One more shorthand in the first column: `>` marks an input variable and `<` an output.

Let's look at:

```
abs_max= 1.62e+04 > [0] encoder.block.2.layer.1.DenseReluDense.dropout: Dropout: input[0]
abs_max= 1.80e+04 < [0] encoder.block.2.layer.1.DenseReluDense.dropout: Dropout: output
```

This is a report for the `Dropout.forward` function, with the first entry for its only input and the second for its only output. You can see that it was called from the `dropout` attribute of the `DenseReluDense` class, inside `encoder.block.2.layer.1`, during the very first batch. Finally, the absolute largest input element was `1.62e+04` and the corresponding output value was `1.80e+04`.

Going back to the full report, to act on it and fix the problem we need to go up a few frames to where the numbers started to climb, and most likely switch to `fp32` mode there, so that the numbers don't overflow when multiplied or summed up. Of course, there might be other solutions.
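
For example, if the run uses native AMP (`torch.cuda.amp`), which is an assumption here, one way to run a suspect region in full precision is to disable autocast locally. This is only a sketch of the idea, not a change made by this PR:

```python
import torch
from torch import nn


class FP32Linear(nn.Linear):
    """A drop-in replacement for a suspect Linear that does its matmul in fp32 under AMP (sketch only)."""

    def forward(self, x):
        with torch.cuda.amp.autocast(enabled=False):
            # compute in fp32, then cast back to whatever dtype the caller was using
            return super().forward(x.float()).to(x.dtype)
```

In the trace above the candidate would be the `wo` projection of `DenseReluDense`; swapping it for such a wrapper (or wrapping just that matmul in `autocast(enabled=False)`) keeps the rest of the model in `fp16`.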

Since the automatic detector only reports inputs and outputs of whole `forward` calls, once you know where to look you may also want to analyse the intermediary stages of a specific `forward` function. In that case you can use the `detect_overflow` helper function to inject the detector where you want it, for example:

```python
from debug_utils import detect_overflow


class T5LayerFF(nn.Module):
    [...]

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        detect_overflow(forwarded_states, "after layer_norm")
        forwarded_states = self.DenseReluDense(forwarded_states)
        detect_overflow(forwarded_states, "after DenseReluDense")
        return hidden_states + self.dropout(forwarded_states)
```

You can see that we added two of these calls, so now we can track the absolute largest value of `forwarded_states` at two different stages.
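
`detect_overflow` works on any tensor, so it can also be used for quick standalone checks outside of a model (assuming the import path added by this PR, `transformers.debug_utils`):

```python
import torch
from transformers.debug_utils import detect_overflow

t = torch.tensor([1.0, float("inf"), 3.0])
if detect_overflow(t, "my tensor"):  # prints "my tensor has infs" and returns True
    print("found a problem")
```
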
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -405,6 +405,7 @@ TensorFlow and/or Flax.
    add_new_model
    fast_tokenizers
    testing
    debugging
    serialization

.. toctree::
182 changes: 182 additions & 0 deletions src/transformers/debug_utils.py
@@ -0,0 +1,182 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from .file_utils import ExplicitEnum, is_torch_available
from .utils import logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class DebugActivationOverflow:
    """
    This debug class helps detect and understand where the model starts getting ``nan``s or ``inf``s in activation
    elements.

    To activate, initialize the object with the model ::

        debug_overflow = DebugActivationOverflow(model)

    then run the training as normal. If any ``nan`` or ``inf`` gets detected, this module will throw an exception and
    will print the last several dozen frames that led to this event, each line reporting:

    1. the absolute largest element of either the input or output variable
    2. the batch number
    3. the fully qualified name of the module it was run for
    4. the class name whose ``forward`` was run
    5. and finally whether it was an input or an output and its index if it was a tuple.

    Args:
        model (:obj:`nn.Module`):
            the model that fails to train due to ``nan``s or ``inf``s
        max_frames_to_save (:obj:`int`, `optional`, defaults to 40):
            how many variables and their frames to record back - a few dozen is a good number
    """

    def __init__(self, model, max_frames_to_save=40):
        self.model = model

        # keep a buffer of the last frames to dump as soon as inf/nan is encountered to give context to the problem emergence
        self.frames = collections.deque([], max_frames_to_save)
        self.save_frames = True
        self.step = 0

        self.analyse_model()

        self.register_forward_hook()

    def save_frame(self, frame):
        self.frames.append(frame)

    def dump_saved_frames_once(self):
        # dump the previous frames only once (to help debug)
        if self.save_frames:
            print(f"\n\nlast {len(self.frames)} frames:")
            print("\n".join(self.frames))
            print("\n\n")
            self.save_frames = False

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}

    def analyse_variable(self, var, ctx):
        if torch.is_tensor(var):
            if self.save_frames:
                self.save_frame(get_abs_max(var, ctx))

            if detect_overflow(var, ctx):
                self.dump_saved_frames_once()

                # now we can die, as it's pointless to continue running
                raise ValueError(
                    "DebugActivationOverflow: inf/nan detected, aborting as there is no point running further. "
                    "Please scroll up above this traceback to see the activation values prior to this event."
                )

    def register_forward_hook(self):
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        # count at which step we are (batch number)
        if module == self.model:
            self.step += 1

        ctx = f"[{self.step}] {self.module_names[module]}: {module.__class__.__name__}"

        for i, x in enumerate(input):
            self.analyse_variable(x, f"> {ctx}: input[{i}]")

        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        self.analyse_variable(y, f"< {ctx}: output[{i}][{j}]")
                else:
                    self.analyse_variable(x, f"< {ctx}: output[{i}]")
        else:
            self.analyse_variable(output, f"< {ctx}: output")


def get_abs_max(var, ctx):
    abs_max = max(abs(var.min()), abs(var.max()))
    return f"abs_max={abs_max:9.2e} {ctx}"


def get_min_max(var, ctx):
    return f"min={var.min():9.2e} max={var.max():9.2e} {ctx}"


def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any ``nan`` or ``inf`` entries.

    This is useful for detecting overflows/underflows and it's best to call it right after the function that did some
    math that modified the variable in question.

    Args:
        var: tensor variable to check
        ctx: the message to print as a context

    Return:
        True if ``inf`` or ``nan`` was detected, False otherwise
    """
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")

    # if needed to monitor large elements, can enable the following
    if 0:  # and detected:
        n100 = var[torch.ge(var.abs(), 100)]
        if n100.numel() > 0:
            print(f"{ctx}: n100={n100.numel()}")
        n1000 = var[torch.ge(var.abs(), 1000)]
        if n1000.numel() > 0:
            print(f"{ctx}: n1000={n1000.numel()}")
        n10000 = var[torch.ge(var.abs(), 10000)]
        if n10000.numel() > 0:
            print(f"{ctx}: n10000={n10000.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return detected


class DebugOption(ExplicitEnum):
    ACTIVATION_OVERFLOW = "activation_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"
2 changes: 1 addition & 1 deletion src/transformers/tokenization_utils_base.py
@@ -3139,7 +3139,7 @@ def clean_up_tokenization(out_string: str) -> str:

    def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
        """
-        Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's
+        Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
        corresponding model

        Args:
8 changes: 6 additions & 2 deletions src/transformers/trainer.py
@@ -54,6 +54,7 @@
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugActivationOverflow, DebugOption
from .dependency_versions_check import dep_version_check
from .file_utils import (
    WEIGHTS_NAME,
@@ -986,6 +987,9 @@ def train(
            num_train_epochs = 1
            num_update_steps_per_epoch = max_steps

        if DebugOption.ACTIVATION_OVERFLOW in self.args.debug:
            debug_overflow = DebugActivationOverflow(self.model)  # noqa

        delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
        if self.args.deepspeed:
            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
@@ -1206,7 +1210,7 @@
            self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

-            if self.args.tpu_metrics_debug or self.args.debug:
+            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
@@ -1814,7 +1818,7 @@ def evaluate(
        output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
        self.log(output.metrics)

-        if self.args.tpu_metrics_debug or self.args.debug:
+        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
