From f10873b7ed77e3c59a4f2e4b73e8cd40711c9151 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 30 Apr 2021 11:15:46 -0700
Subject: [PATCH] [debug utils] activation/weights underflow/overflow detector
 (#11274)

* sync
* add activation overflow debug utility
* cleanup
* document detect_overflow
* import torch
* add deprecation warning
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* convert to rst, add note
* add class
* fix docs
* improve the doc
* rework to dump a lot more info about each frame
* complete expansion
* cleanup
* format
* cleanup
* doesn't have to be transformers
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* wrap long line
* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 docs/source/debugging.rst                   | 295 ++++++++++++++++++
 docs/source/index.rst                       |   1 +
 docs/source/internal/trainer_utils.rst      |   8 +-
 src/transformers/debug_utils.py             | 326 ++++++++++++++++++++
 src/transformers/tokenization_utils_base.py |   2 +-
 src/transformers/trainer.py                 |   8 +-
 src/transformers/training_args.py           |  36 ++-
 7 files changed, 668 insertions(+), 8 deletions(-)
 create mode 100644 docs/source/debugging.rst
 create mode 100644 src/transformers/debug_utils.py

diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst
new file mode 100644
index 00000000000000..b13dc1a5e77746
--- /dev/null
+++ b/docs/source/debugging.rst
@@ -0,0 +1,295 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.


Debugging
=======================================================================================================================

Underflow and Overflow Detection
-----------------------------------------------------------------------------------------------------------------------

.. note::

    This feature is currently available for PyTorch only.

.. note::

    This feature can be used with any ``nn.Module``-based model.

If you start getting ``loss=NaN`` or the model exhibits some other abnormal behavior due to ``inf`` or ``nan`` in
activations or weights, you need to discover where the first underflow or overflow happens and what led to it. Luckily
you can accomplish that easily by activating a special module that will do the detection automatically.

If you're using :class:`~transformers.Trainer`, you just need to add:

.. code-block:: bash

    --debug underflow_overflow

to the normal command line arguments, or pass ``debug="underflow_overflow"`` when creating the
:class:`~transformers.TrainingArguments` object.
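For example, a minimal sketch of setting this up programmatically (``output_dir`` here is just a placeholder):

.. code-block:: python

    from transformers import TrainingArguments

    # enable the underflow/overflow detector for a Trainer-based run
    args = TrainingArguments(
        output_dir="output",
        debug="underflow_overflow",
    )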
If you're using your own training loop or another Trainer you can accomplish the same with:

.. code-block:: python

    from transformers.debug_utils import DebugUnderflowOverflow

    debug_overflow = DebugUnderflowOverflow(model)

:class:`~transformers.debug_utils.DebugUnderflowOverflow` inserts hooks into the model that, immediately after each
forward call, test input and output variables and also the corresponding module's weights. As soon as ``inf`` or
``nan`` is detected in at least one element of the activations or weights, the program will raise an exception and
print a report like this (this one was caught with ``google/mt5-small`` under fp16 mixed precision):

.. code-block::

    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
                      encoder.block.1.layer.1.DenseReluDense.dropout Dropout
    0.00e+00 2.57e+02 input[0]
    0.00e+00 2.85e+02 output
    [...]
                      encoder.block.2.layer.0 T5LayerSelfAttention
    6.78e-04 3.15e+03 input[0]
    2.65e-04 3.42e+03 output[0]
                 None output[1]
    2.25e-01 1.00e+04 output[2]
                      encoder.block.2.layer.1.layer_norm T5LayerNorm
    8.69e-02 4.18e-01 weight
    2.65e-04 3.42e+03 input[0]
    1.79e-06 4.65e+00 output
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.dropout Dropout
    0.00e+00 8.76e+03 input[0]
    0.00e+00 9.74e+03 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00      inf output

The example output has been trimmed in the middle for brevity.

The second column shows the value of the absolute largest element, so if you have a closer look at the last few
frames, the inputs and outputs were in the range of ``1e4``. So when this training was done under fp16 mixed precision
the very last step overflowed (since under ``fp16`` the largest representable number before ``inf`` is ``65504``, just
under ``64K``). To avoid overflows under ``fp16`` the activations must remain way below ``1e4``, because ``1e4 * 1e4 =
1e8``, so any matrix multiplication with large activations is going to lead to a numerical overflow condition.
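To see these limits concretely, here is a small standalone snippet (not part of this patch) that reproduces the
``fp16`` overflow behavior described above:

.. code-block:: python

    import torch

    # the largest finite value fp16 can represent (just under 64K)
    print(torch.finfo(torch.float16).max)  # 65504.0

    x = torch.tensor(300.0, dtype=torch.float16)
    # 300 * 300 = 9e4 exceeds the fp16 range, so the product overflows
    print(x * x)  # tensor(inf, dtype=torch.float16)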
At the very start of the trace you can discover at which batch number the problem occurred (here ``Detected inf/nan
during batch_number=0`` means the problem occurred on the first batch).

Each reported frame starts by declaring the fully qualified name of the module this frame is reporting for. If we look
just at this frame:

.. code-block::

                      encoder.block.2.layer.1.layer_norm T5LayerNorm
    8.69e-02 4.18e-01 weight
    2.65e-04 3.42e+03 input[0]
    1.79e-06 4.65e+00 output

Here, ``encoder.block.2.layer.1.layer_norm`` indicates that it was the layer norm of the second layer (``layer.1``) of
the third block (``block.2``) of the encoder (the indices are 0-based), and the class whose ``forward`` was run is
``T5LayerNorm``.

Let's look at the last few frames of that report:

.. code-block::

    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
    [...]
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00      inf output

The last frame reports on the ``Dropout.forward`` function, with the first entry for the only input and the second for
the only output. You can see that it was called from a ``dropout`` attribute of the feed-forward layer
(``encoder.block.2.layer.1``), and that it happened during the second layer of the third block, during the very first
batch. Finally, the absolute largest input element was ``6.27e+04``, and the corresponding output was ``inf``.

You can see here that ``T5DenseGatedGeluDense.forward`` resulted in output activations whose absolute max value was
around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have ``Dropout``, which rescales
the remaining elements after zeroing some of them, pushing the absolute max value above 64K, and we get an overflow
(``inf``).

As you can see, it's the previous frames that we need to look into when the numbers start getting too large for fp16.

Let's match the report to the code from ``models/t5/modeling_t5.py``:

.. code-block:: python

    class T5DenseGatedGeluDense(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
            self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
            self.dropout = nn.Dropout(config.dropout_rate)
            self.gelu_act = ACT2FN["gelu_new"]

        def forward(self, hidden_states):
            hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
            hidden_linear = self.wi_1(hidden_states)
            hidden_states = hidden_gelu * hidden_linear
            hidden_states = self.dropout(hidden_states)
            hidden_states = self.wo(hidden_states)
            return hidden_states

Now it's easy to see the ``dropout`` call, and all the previous calls as well.

Since the detection is happening in a forward hook, these reports are printed immediately after each ``forward``
returns.
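As a side note, here is a small self-contained sketch (not part of this patch) of the forward-hook mechanism the
detector relies on, with the reporting reduced to a single print:

.. code-block:: python

    import torch
    from torch import nn

    # register_forward_hook fires right after each module's forward() returns
    def report_hook(module, inputs, output):
        if torch.is_tensor(output):
            print(f"{module.__class__.__name__}: abs max = {output.abs().max():.2e}")

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    for module in model.modules():
        module.register_forward_hook(report_hook)

    _ = model(torch.randn(2, 4))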
Going back to the full report, to act on it and to fix the problem, we need to go back a few frames to where the
numbers started to grow, and most likely switch to ``fp32`` mode there, so that the numbers don't overflow when
multiplied or summed up. Of course, there might be other solutions. For example, we could turn off ``amp`` temporarily
if it's enabled, after moving the original ``forward`` into a helper wrapper, like so:

.. code-block:: python

    import torch


    def _forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


    def forward(self, hidden_states):
        # run the original forward in fp32 by disabling autocast for this call
        if torch.is_autocast_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(hidden_states)
        else:
            return self._forward(hidden_states)

Since the automatic detector only reports on inputs and outputs of full frames, once you know where to look, you may
want to analyse the intermediary stages of any specific ``forward`` function as well. In such a case you can use the
``detect_overflow`` helper function to inject the detector where you want it, for example:

.. code-block:: python

    from transformers.debug_utils import detect_overflow


    class T5LayerFF(nn.Module):
        [...]

        def forward(self, hidden_states):
            forwarded_states = self.layer_norm(hidden_states)
            detect_overflow(forwarded_states, "after layer_norm")
            forwarded_states = self.DenseReluDense(forwarded_states)
            detect_overflow(forwarded_states, "after DenseReluDense")
            return hidden_states + self.dropout(forwarded_states)

You can see that we added 2 of these, and now we track whether ``inf`` or ``nan`` for ``forwarded_states`` was
detected somewhere in between.

Actually, the detector already reports these, because each of the calls in the example above is an ``nn.Module``, but
if you had some local direct calculations, this is how you'd track them.
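For instance, a hypothetical sketch of guarding an intermediate value computed with plain tensor math, where no module
hook fires (the attribute names mirror the T5 code above and are illustrative only):

.. code-block:: python

    from transformers.debug_utils import detect_overflow


    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        # this product is plain tensor math, so no forward hook covers it --
        # check it explicitly
        hidden_states = hidden_gelu * hidden_linear
        detect_overflow(hidden_states, "after hidden_gelu * hidden_linear")
        return self.wo(self.dropout(hidden_states))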
Additionally, if you're instantiating the debugger in your own code, you can adjust the number of frames printed from
its default, e.g.:

.. code-block:: python

    from transformers.debug_utils import DebugUnderflowOverflow

    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)

Specific batch absolute min and max value tracing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The same debugging class can be used for per-batch tracing with the underflow/overflow detection feature turned off.

Let's say you want to watch the absolute min and max values for all the ingredients of each ``forward`` call of a
given batch, and only do that for batches 1 and 3. Then you instantiate this class as:

.. code-block:: python

    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

And now full batches 1 and 3 will be traced using the same format as the underflow/overflow detector does.

Batches are 0-indexed.

This is helpful if you know that the program starts misbehaving after a certain batch number, so you can fast-forward
right to that area. Here is a sample truncated output for such a configuration:

.. code-block::

                      *** Starting batch number=1 ***
    abs min  abs max  metadata
                      shared Embedding
    1.01e-06 7.92e+02 weight
    0.00e+00 2.47e+04 input[0]
    5.36e-05 7.92e+02 output
    [...]
                      decoder.dropout Dropout
    1.60e-07 2.27e+01 input[0]
    0.00e+00 2.52e+01 output
                      decoder T5Stack
         not a tensor output
                      lm_head Linear
    1.01e-06 7.92e+02 weight
    0.00e+00 1.11e+00 input[0]
    6.06e-02 8.39e+01 output
                      T5ForConditionalGeneration
         not a tensor output

                      *** Starting batch number=3 ***
    abs min  abs max  metadata
                      shared Embedding
    1.01e-06 7.92e+02 weight
    0.00e+00 2.78e+04 input[0]
    5.36e-05 7.92e+02 output
    [...]

Here you will get a huge number of frames dumped - as many as there were forward calls in your model - so it may or
may not be what you want, but sometimes it can be easier to use for debugging purposes than a normal debugger. For
example, if a problem starts happening at batch number 150, you can dump traces for batches 149 and 150 and compare
where the numbers started to diverge.

You can also specify the batch number after which to stop the training, with:

.. code-block:: python

    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index c6c9afbfd7e7e9..083b50ea2677c4 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -405,6 +405,7 @@ Flax), PyTorch, and/or TensorFlow.
     add_new_model
     fast_tokenizers
     testing
+    debugging
     serialization
 
 .. toctree::
diff --git a/docs/source/internal/trainer_utils.rst b/docs/source/internal/trainer_utils.rst
index c649eb3ab4e4ff..65720d15bafcc4 100644
--- a/docs/source/internal/trainer_utils.rst
+++ b/docs/source/internal/trainer_utils.rst
@@ -1,4 +1,4 @@
-.. 
+..
     Copyright 2020 The HuggingFace Team. All rights reserved.
 
     Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
@@ -46,3 +46,9 @@ Distributed Evaluation
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.HfArgumentParser
+
+
+Debug Utilities
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.debug_utils.DebugUnderflowOverflow
diff --git a/src/transformers/debug_utils.py b/src/transformers/debug_utils.py
new file mode 100644
index 00000000000000..45384a80134ba1
--- /dev/null
+++ b/src/transformers/debug_utils.py
@@ -0,0 +1,326 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from .file_utils import ExplicitEnum, is_torch_available
from .utils import logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class DebugUnderflowOverflow:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly ``nan`` or ``inf`` weight and activation elements.

    There are 2 working modes:

    1. Underflow/overflow detection (default)
    2. Specific batch absolute min/max tracing without detection

    Mode 1: Underflow/overflow detection

    To activate the underflow/overflow detection, initialize the object with the model ::

        debug_overflow = DebugUnderflowOverflow(model)

    then run the training as normal and if ``nan`` or ``inf`` gets detected in at least one of the weight, input or
    output elements this module will throw an exception and will print ``max_frames_to_save`` frames that led to this
    event, each frame reporting
    1. the fully qualified module name plus the class name whose ``forward`` was run
    2. the absolute min and max values of all elements for each of the module's weights, inputs and output

    For example, here are the header and the last few frames of the detection report for ``google/mt5-small`` run in
    fp16 mixed precision ::

        Detected inf/nan during batch_number=0
        Last 21 forward frames:
        abs min  abs max  metadata
        [...]
                          encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
        2.17e-07 4.50e+00 weight
        1.79e-06 4.65e+00 input[0]
        2.68e-06 3.70e+01 output
                          encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
        8.08e-07 2.66e+01 weight
        1.79e-06 4.65e+00 input[0]
        1.27e-04 2.37e+02 output
                          encoder.block.2.layer.1.DenseReluDense.wo Linear
        1.01e-06 6.44e+00 weight
        0.00e+00 9.74e+03 input[0]
        3.18e-04 6.27e+04 output
                          encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
        1.79e-06 4.65e+00 input[0]
        3.18e-04 6.27e+04 output
                          encoder.block.2.layer.1.dropout Dropout
        3.18e-04 6.27e+04 input[0]
        0.00e+00      inf output

    You can see here that ``T5DenseGatedGeluDense.forward`` resulted in output activations whose absolute max value
    was around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have ``Dropout``, which
    rescales the remaining elements after zeroing some of them, pushing the absolute max value above 64K, and we get
    an overflow.

    As you can see, it's the previous frames that we need to look into when the numbers start getting too large for
    fp16.

    The tracking is done in a forward hook, which gets invoked immediately after ``forward`` has completed.

    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example ::

        debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)

    Mode 2. Specific batch absolute min/max tracing without detection

    The second working mode is per-batch tracing with the underflow/overflow detection feature turned off.

    Let's say you want to watch the absolute min and max values for all the ingredients of each ``forward`` call of a
    given batch, and only do that for batches 1 and 3. Then you instantiate this class as ::

        debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])

    And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.

    This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
    fast-forward right to that area.

    You can also specify the batch number after which to stop the training, with ::

        debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)

    This feature is mainly useful in the tracing mode, but you can use it for any mode.

    Args:
        model (:obj:`nn.Module`):
            The model to debug.
        max_frames_to_save (:obj:`int`, `optional`, defaults to 21):
            How many frames back to record.
        trace_batch_nums (:obj:`List[int]`, `optional`, defaults to ``[]``):
            Which batch numbers to trace (turns the detection off).
        abort_after_batch_num (:obj:`int`, `optional`, defaults to :obj:`None`):
            If set, stop the training after this batch number has finished.
    """

    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
        self.model = model
        self.trace_batch_nums = trace_batch_nums
        self.abort_after_batch_num = abort_after_batch_num

        # keep a bounded buffer of the last frames to dump as soon as inf/nan is encountered, to give
        # context to the problem's emergence (the deque discards the oldest frames)
        self.frames = collections.deque([], max_frames_to_save)
        self.frame = []
        self.batch_number = 0
        self.total_calls = 0
        self.detected_overflow = False
        self.prefix = "                 "  # pads module lines to align with the `abs min`/`abs max` columns

        self.analyse_model()

        self.register_forward_hook()

    def save_frame(self, frame=None):
        if frame is not None:
            self.expand_frame(frame)
        self.frames.append("\n".join(self.frame))
        self.frame = []  # start a new frame

    def expand_frame(self, line):
        self.frame.append(line)

    def trace_frames(self):
        print("\n".join(self.frames))
        self.frames = []

    def reset_saved_frames(self):
        self.frames = []

    def dump_saved_frames(self):
        print(f"\nDetected inf/nan during batch_number={self.batch_number}")
        print(f"Last {len(self.frames)} forward frames:")
        print(f"{'abs min':8} {'abs max':8} metadata")
        print("\n".join(self.frames))
        print("\n\n")
        self.frames = []

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}
        # self.longest_module_name = max(len(v) for v in self.module_names.values())

    def analyse_variable(self, var, ctx):
        if torch.is_tensor(var):
            self.expand_frame(get_abs_min_max(var, ctx))
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        elif var is None:
            self.expand_frame(f"{'None':>17} {ctx}")
        else:
            self.expand_frame(f"{'not a tensor':>17} {ctx}")

    def batch_start_frame(self):
        self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
        self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")

    def batch_end_frame(self):
        self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")

    def create_frame(self, module, input, output):
        self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")

        # params
        for name, p in module.named_parameters(recurse=False):
            self.analyse_variable(p, name)

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                self.analyse_variable(x, f"input[{i}]")
        else:
            self.analyse_variable(input, "input")

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        self.analyse_variable(y, f"output[{i}][{j}]")
                else:
                    self.analyse_variable(x, f"output[{i}]")
        else:
            self.analyse_variable(output, "output")

        self.save_frame()

    def register_forward_hook(self):
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)
    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        last_frame_of_batch = False

        trace_mode = self.batch_number in self.trace_batch_nums
        if trace_mode:
            self.reset_saved_frames()

        if self.total_calls == 0:
            self.batch_start_frame()
        self.total_calls += 1

        # count batch numbers - the hook on the top-level module fires last, once its forward has
        # completed, so when it runs we know this batch has finished
        if module == self.model:
            self.batch_number += 1
            last_frame_of_batch = True

        self.create_frame(module, input, output)

        # if last_frame_of_batch:
        #     self.batch_end_frame()

        if trace_mode:
            self.trace_frames()

        if last_frame_of_batch:
            self.batch_start_frame()

        if self.detected_overflow and not trace_mode:
            self.dump_saved_frames()

            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after a certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to "
                f"`abort_after_batch_num={self.abort_after_batch_num}` arg"
            )


def get_abs_min_max(var, ctx):
    abs_var = var.abs()
    return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"


def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any ``nan`` or ``inf`` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math
    that modified the variable in question.

    The function contains a few other helper features that you can enable and tweak directly if you want to track
    various other things.
    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        :obj:`True` if ``inf`` or ``nan`` was detected, :obj:`False` otherwise
    """
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")

    # if you need to monitor large elements, you can enable the following
    if 0:  # and detected:
        n100 = var[torch.ge(var.abs(), 100)]
        if n100.numel() > 0:
            print(f"{ctx}: n100={n100.numel()}")
        n1000 = var[torch.ge(var.abs(), 1000)]
        if n1000.numel() > 0:
            print(f"{ctx}: n1000={n1000.numel()}")
        n10000 = var[torch.ge(var.abs(), 10000)]
        if n10000.numel() > 0:
            print(f"{ctx}: n10000={n10000.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return detected


class DebugOption(ExplicitEnum):
    UNDERFLOW_OVERFLOW = "underflow_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index eed034256617e9..abb62a9bf598b1 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -3154,7 +3154,7 @@ def clean_up_tokenization(out_string: str) -> str:
 
     def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
         """
-        Depending on the input and internal state we might trigger a warning about a sequence that is too long for it's
+        Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
         corresponding model
 
         Args:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5565bdb2eab4fb..eebea8b4a2dd72 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -59,6 +59,7 @@
 from . import __version__
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .dependency_versions_check import dep_version_check
 from .file_utils import (
     CONFIG_NAME,
@@ -1078,6 +1079,9 @@ def train(
             num_train_epochs = int(args.num_train_epochs)
             num_update_steps_per_epoch = max_steps
 
+        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
+
         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
         if args.deepspeed:
             deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
@@ -1301,7 +1305,7 @@ def train(
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)
 
-            if args.tpu_metrics_debug or args.debug:
+            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_tpu_available():
                     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                     xm.master_print(met.metrics_report())
@@ -1905,7 +1909,7 @@ def evaluate(
 
         self.log(output.metrics)
 
-        if self.args.tpu_metrics_debug or self.args.debug:
+        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 37572c8705f408..6f1794315080ab 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -19,6 +19,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
+from .debug_utils import DebugOption
 from .file_utils import (
     cached_property,
     is_sagemaker_dp_enabled,
@@ -191,8 +192,6 @@ class TrainingArguments:
             Rank of the process during distributed training.
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the number of TPU cores (automatically passed by launcher script).
-        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            When training on TPU, whether to print debug metrics or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch
             size) or not.
@@ -274,6 +273,16 @@ class TrainingArguments:
             The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
             labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
             label_smoothing_factor + label_smoothing_factor/num_labels` respectively.
+        debug (:obj:`str` or list of :class:`~transformers.debug_utils.DebugOption`, `optional`, defaults to :obj:`""`):
+            Enable one or more debug features. This is an experimental feature.
+
+            Possible options are:
+
+            - :obj:`"underflow_overflow"`: detects overflow in the model's input/outputs and reports the last frames
+              that led to the event
+            - :obj:`"tpu_metrics_debug"`: print debug metrics on TPU
+
+            The options should be separated by whitespace.
         adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
             :class:`~transformers.AdamW`.
@@ -437,9 +446,18 @@ class TrainingArguments:
     )
     tpu_metrics_debug: bool = field(
         default=False,
-        metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"},
+        metadata={
+            "help": "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+        },
+    )
+    debug: str = field(
+        default="",
+        metadata={
+            "help": "Whether or not to enable debug mode. Current options: "
+            "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+            "`tpu_metrics_debug` (print debug metrics on TPU)."
+        },
     )
-    debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
 
     dataloader_drop_last: bool = field(
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
@@ -631,6 +649,16 @@ def __post_init__(self):
         elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
             raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
 
+        if self.tpu_metrics_debug:
+            warnings.warn(
+                "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--debug tpu_metrics_debug` instead",
+                FutureWarning,
+            )
+            self.debug += " tpu_metrics_debug"
+            self.tpu_metrics_debug = False
+        if isinstance(self.debug, str):
+            self.debug = [DebugOption(s) for s in self.debug.split()]
+
         if self.deepspeed:
             # - must be run very last in arg parsing, since it will use a lot of these settings.
             # - must be run before the model is created.
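To close, a quick hypothetical sanity check of the new ``debug`` argument parsing added in ``training_args.py`` above
(assuming the patch is installed; ``output_dir`` is a placeholder):

.. code-block:: python

    from transformers import TrainingArguments
    from transformers.debug_utils import DebugOption

    args = TrainingArguments(output_dir="output", debug="underflow_overflow")
    # __post_init__ splits the string into a list of DebugOption members
    assert DebugOption.UNDERFLOW_OVERFLOW in args.debug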