diff --git a/_static/img/thumbnails/cropped/amp.png b/_static/img/thumbnails/cropped/amp.png
new file mode 100644
index 00000000000..a6916ce5605
Binary files /dev/null and b/_static/img/thumbnails/cropped/amp.png differ
diff --git a/advanced_source/dispatcher.rst b/advanced_source/dispatcher.rst
index 23ba0f96be1..4f3b52fea32 100644
--- a/advanced_source/dispatcher.rst
+++ b/advanced_source/dispatcher.rst
@@ -105,6 +105,8 @@ speaking, the structure of your registrations will look like this:
    that provides implementations for all basic operators on the XLA dispatch
    key.
 
+.. _autograd-support:
+
 Adding autograd support
 -----------------------
 
@@ -299,6 +301,28 @@ the safest choice for the execution type:
         at::autocast::cached_cast(exec_type, t1));
   }
 
+If your custom op is :ref:`autograd-enabled <autograd-support>`, you only need to write and register
+an autocast wrapper under the same name that the autograd wrapper is registered to.
+For example, if you wanted an autocast wrapper for the ``myadd`` function shown
+in the autograd section, all you'd need is:
+
+.. code-block:: cpp
+
+    Tensor myadd_autocast(const Tensor& self, const Tensor& other) {
+      c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+      return myadd(at::autocast::cached_cast(<desired dtype>, self),
+                   at::autocast::cached_cast(<desired dtype>, other));
+    }
+
+    TORCH_LIBRARY_IMPL(myops, Autocast, m) {
+      m.impl("myadd", myadd_autocast);
+    }
+
+There are no separate gymnastics needed to make the backward method autocast compatible.
+However, the backward method defined in your custom autograd function will run in the same
+dtype that autocast sets for the forward method, so you should choose a ``<desired dtype>``
+suitable for both your forward and backward methods.
+
 Batched
 ^^^^^^^
 
diff --git a/recipes_source/recipes/README.txt b/recipes_source/recipes/README.txt
index f93ee92c2c6..a182b0a11c5 100644
--- a/recipes_source/recipes/README.txt
+++ b/recipes_source/recipes/README.txt
@@ -56,3 +56,7 @@ PyTorch Recipes
 14. mobile_perf.py
      PyTorch Mobile Performance Recipes
      https://pytorch.org/tutorials/recipes/mobile_perf.html
+
+15. amp_recipe.py
+     Automatic Mixed Precision
+     https://pytorch.org/tutorials/recipes/amp_recipe.html
diff --git a/recipes_source/recipes/amp_recipe.py b/recipes_source/recipes/amp_recipe.py
new file mode 100644
index 00000000000..c1ec52a3883
--- /dev/null
+++ b/recipes_source/recipes/amp_recipe.py
@@ -0,0 +1,325 @@
+# -*- coding: utf-8 -*-
+"""
+Automatic Mixed Precision
+*************************
+**Author**: `Michael Carilli <https://github.com/mcarilli>`_
+
+`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ provides convenience methods for mixed precision,
+where some operations use the ``torch.float32`` (``float``) datatype and other operations
+use ``torch.float16`` (``half``). Some ops, like linear layers and convolutions,
+are much faster in ``float16``. Other ops, like reductions, often require the dynamic
+range of ``float32``. Mixed precision tries to match each op to its appropriate datatype,
+which can reduce your network's runtime and memory footprint.
+
+Ordinarily, "automatic mixed precision training" uses `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_ and
+`torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_ together.
+
+This recipe measures the performance of a simple network in default precision,
+then walks through adding ``autocast`` and ``GradScaler`` to run the same network in
+mixed precision with improved performance.
+
+You may download and run this recipe as a standalone Python script.
+The only requirements are PyTorch 1.6+ and a CUDA-capable GPU.
+
+Mixed precision primarily benefits Tensor Core-enabled architectures (Volta, Turing, Ampere).
+This recipe should show significant (2-3X) speedup on those architectures.
+On earlier architectures (Kepler, Maxwell, Pascal), you may observe a modest speedup.
+Run ``nvidia-smi`` to display your GPU's architecture.
+"""
+
+import torch, time, gc
+
+# Timing utilities
+start_time = None
+
+def start_timer():
+    global start_time
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_allocated()
+    torch.cuda.synchronize()
+    start_time = time.time()
+
+def end_timer_and_print(local_msg):
+    torch.cuda.synchronize()
+    end_time = time.time()
+    print("\n" + local_msg)
+    print("Total execution time = {:.3f} sec".format(end_time - start_time))
+    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))
+
+##########################################################
+# A simple network
+# ----------------
+# The following sequence of linear layers and ReLUs should show a speedup with mixed precision.
+
+def make_model(in_size, out_size, num_layers):
+    layers = []
+    for _ in range(num_layers - 1):
+        layers.append(torch.nn.Linear(in_size, in_size))
+        layers.append(torch.nn.ReLU())
+    layers.append(torch.nn.Linear(in_size, out_size))
+    return torch.nn.Sequential(*layers).cuda()
+
+##########################################################
+# ``batch_size``, ``in_size``, ``out_size``, and ``num_layers`` are chosen to be large enough to saturate the GPU with work.
+# Typically, mixed precision provides the greatest speedup when the GPU is saturated.
+# Small networks may be CPU bound, in which case mixed precision won't improve performance.
+# Sizes are also chosen such that linear layers' participating dimensions are multiples of 8,
+# to permit Tensor Core usage on Tensor Core-capable GPUs (see :ref:`Troubleshooting <troubleshooting>` below).
+#
+# Exercise: Vary participating sizes and see how the mixed precision speedup changes.
+
+batch_size = 512  # Try, for example, 128, 256, 513.
+in_size = 4096
+out_size = 4096
+num_layers = 3
+num_batches = 50
+epochs = 3
+
+# Creates data in default precision.
+# The same data is used for both default and mixed precision trials below.
+# You don't need to manually change inputs' dtype when enabling mixed precision.
+data = [torch.randn(batch_size, in_size, device="cuda") for _ in range(num_batches)]
+targets = [torch.randn(batch_size, out_size, device="cuda") for _ in range(num_batches)]
+
+loss_fn = torch.nn.MSELoss().cuda()
+
+##########################################################
+# Default Precision
+# -----------------
+# Without ``torch.cuda.amp``, the following simple network executes all ops in default precision (``torch.float32``):
+
+net = make_model(in_size, out_size, num_layers)
+opt = torch.optim.SGD(net.parameters(), lr=0.001)
+
+start_timer()
+for epoch in range(epochs):
+    for input, target in zip(data, targets):
+        output = net(input)
+        loss = loss_fn(output, target)
+        loss.backward()
+        opt.step()
+        opt.zero_grad()  # set_to_none=True here can modestly improve performance
+end_timer_and_print("Default precision:")
+
+##########################################################
+# Adding autocast
+# ---------------
+# Instances of `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_
+# serve as context managers that allow regions of your script to run in mixed precision.
+#
+# In these regions, CUDA ops run in a dtype chosen by autocast
+# to improve performance while maintaining accuracy.
+# See the `Autocast Op Reference <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
+# for details on what precision autocast chooses for each op, and under what circumstances.

+for epoch in range(0):  # 0 epochs, this section is for illustration only
+    for input, target in zip(data, targets):
+        # Runs the forward pass under autocast.
+        with torch.cuda.amp.autocast():
+            output = net(input)
+            # output is float16 because linear layers autocast to float16.
+            assert output.dtype is torch.float16
+
+            loss = loss_fn(output, target)
+            # loss is float32 because ``mse_loss`` autocasts to float32.
+            assert loss.dtype is torch.float32
+
+        # Exits autocast before backward().
+        # Backward passes under autocast are not recommended.
+        # Backward ops run in the same dtype autocast chose for the corresponding forward ops.
+        loss.backward()
+        opt.step()
+        opt.zero_grad()  # set_to_none=True here can modestly improve performance
+
+##########################################################
+# Adding GradScaler
+# -----------------
+# `Gradient scaling <https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_
+# helps prevent gradients with small magnitudes from flushing to zero
+# ("underflowing") when training with mixed precision.
+#
+# `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_
+# performs the steps of gradient scaling conveniently.
+
+# Constructs a scaler once, at the beginning of the convergence run, using default args.
+# If your network fails to converge with default GradScaler args, please file an issue.
+# The same GradScaler instance should be used for the entire convergence run.
+# If you perform multiple convergence runs in the same script, each run should use
+# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
+scaler = torch.cuda.amp.GradScaler()
+
+for epoch in range(0):  # 0 epochs, this section is for illustration only
+    for input, target in zip(data, targets):
+        with torch.cuda.amp.autocast():
+            output = net(input)
+            loss = loss_fn(output, target)
+
+        # Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
+        scaler.scale(loss).backward()
+
+        # scaler.step() first unscales the gradients of the optimizer's assigned params.
+        # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
+        # otherwise, optimizer.step() is skipped.
+        scaler.step(opt)
+
+        # Updates the scale for next iteration.
+        scaler.update()
+
+        opt.zero_grad()  # set_to_none=True here can modestly improve performance
+
+##########################################################
+# All together: "Automatic Mixed Precision"
+# ------------------------------------------
+# (The following also demonstrates ``enabled``, an optional convenience argument to ``autocast`` and ``GradScaler``.
+# If ``False``, ``autocast`` and ``GradScaler``\ 's calls become no-ops.
+# This allows switching between default precision and mixed precision without if/else statements.)
+
+use_amp = True
+
+net = make_model(in_size, out_size, num_layers)
+opt = torch.optim.SGD(net.parameters(), lr=0.001)
+scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
+
+start_timer()
+for epoch in range(epochs):
+    for input, target in zip(data, targets):
+        with torch.cuda.amp.autocast(enabled=use_amp):
+            output = net(input)
+            loss = loss_fn(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(opt)
+        scaler.update()
+        opt.zero_grad()  # set_to_none=True here can modestly improve performance
+end_timer_and_print("Mixed precision:")
+
+##########################################################
+# Inspecting/modifying gradients (e.g., clipping)
+# --------------------------------------------------------
+# All gradients produced by ``scaler.scale(loss).backward()`` are scaled. If you wish to modify or inspect
+# the parameters' ``.grad`` attributes between ``backward()`` and ``scaler.step(optimizer)``, you should
+# unscale them first using `scaler.unscale_(optimizer) <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.unscale_>`_.
+
+for epoch in range(0):  # 0 epochs, this section is for illustration only
+    for input, target in zip(data, targets):
+        with torch.cuda.amp.autocast():
+            output = net(input)
+            loss = loss_fn(output, target)
+        scaler.scale(loss).backward()
+
+        # Unscales the gradients of the optimizer's assigned params in-place.
+        scaler.unscale_(opt)
+
+        # Since the gradients of the optimizer's assigned params are now unscaled, clip as usual.
+        # You may use the same value for max_norm here as you would without gradient scaling.
+        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)
+
+        scaler.step(opt)
+        scaler.update()
+        opt.zero_grad()  # set_to_none=True here can modestly improve performance
+
+##########################################################
+# Saving/Resuming
+# ----------------
+# To save/resume Amp-enabled runs with bitwise accuracy, use
+# `scaler.state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.state_dict>`_ and
+# `scaler.load_state_dict <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.load_state_dict>`_.
+#
+# When saving, save the scaler state dict alongside the usual model and optimizer state dicts.
+# Do this either at the beginning of an iteration before any forward passes, or at the end of
+# an iteration after ``scaler.update()``.
+
+checkpoint = {"model": net.state_dict(),
+              "optimizer": opt.state_dict(),
+              "scaler": scaler.state_dict()}
+# Write checkpoint as desired, e.g.,
+# torch.save(checkpoint, "filename")
+
+##########################################################
+# When resuming, load the scaler state dict alongside the model and optimizer state dicts.

+# Read checkpoint as desired, e.g.,
+# dev = torch.cuda.current_device()
+# checkpoint = torch.load("filename",
+#                         map_location=lambda storage, loc: storage.cuda(dev))
+net.load_state_dict(checkpoint["model"])
+opt.load_state_dict(checkpoint["optimizer"])
+scaler.load_state_dict(checkpoint["scaler"])
+
+##########################################################
+# If a checkpoint was created from a run *without* Amp, and you want to resume training *with* Amp,
+# load model and optimizer states from the checkpoint as usual. The checkpoint won't contain a saved
+# scaler state, so use a fresh instance of ``GradScaler``.
+#
+# If a checkpoint was created from a run *with* Amp and you want to resume training *without* Amp,
+# load model and optimizer states from the checkpoint as usual, and ignore the saved scaler state.
+
+##########################################################
+# Inference/Evaluation
+# --------------------
+# ``autocast`` may be used by itself to wrap inference or evaluation forward passes. ``GradScaler`` is not necessary.
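+#
+# As a minimal sketch (this block is illustrative: it reuses the ``net``, ``data``, ``targets``, and
+# ``loss_fn`` defined above, and the accumulated ``val_loss`` is just for demonstration):
+
+net.eval()
+val_loss = 0.0
+with torch.no_grad():
+    for input, target in zip(data, targets):
+        # Forward passes run under autocast; no GradScaler calls are needed because there is no backward pass.
+        with torch.cuda.amp.autocast():
+            output = net(input)
+            val_loss += loss_fn(output, target).item()
+print("Eval loss under autocast = {:.3f}".format(val_loss / len(data)))
+net.train()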
+
+##########################################################
+# .. _advanced-topics:
+#
+# Advanced topics
+# ---------------
+# See the `Automatic Mixed Precision Examples <https://pytorch.org/docs/stable/notes/amp_examples.html>`_ for advanced use cases including:
+#
+# * Gradient accumulation
+# * Gradient penalty/double backward
+# * Networks with multiple models, optimizers, or losses
+# * Multiple GPUs (``torch.nn.DataParallel`` or ``torch.nn.parallel.DistributedDataParallel``)
+# * Custom autograd functions (subclasses of ``torch.autograd.Function``)
+#
+# If you perform multiple convergence runs in the same script, each run should use
+# a dedicated fresh GradScaler instance. GradScaler instances are lightweight.
+#
+# If you're registering a custom C++ op with the dispatcher, see the
+# `autocast section <https://pytorch.org/tutorials/advanced/dispatcher.html#autocast>`_
+# of the dispatcher tutorial.
+
+##########################################################
+# .. _troubleshooting:
+#
+# Troubleshooting
+# ---------------
+# Speedup with Amp is minor
+# ~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. Your network may fail to saturate the GPU(s) with work, and is therefore CPU bound; in that case,
+#    Amp's effect on GPU performance won't matter.
+#
+#    * A rough rule of thumb to saturate the GPU is to increase batch and/or network size(s)
+#      as much as you can without running OOM.
+#    * Try to avoid excessive CPU-GPU synchronization (``.item()`` calls, or printing values from CUDA tensors).
+#    * Try to avoid sequences of many small CUDA ops (coalesce these into a few large CUDA ops if you can).
+# 2. Your network may be GPU compute bound (lots of matmuls/convolutions) but your GPU does not have Tensor Cores.
+#    In this case a reduced speedup is expected.
+# 3. Matmul dimensions are not Tensor Core-friendly. Make sure matmuls' participating sizes are multiples of 8.
+#    (For NLP models with encoders/decoders, this can be subtle. Also, convolutions used to have similar size constraints
+#    for Tensor Core use, but for cuDNN versions 7.3 and later, no such constraints exist. See
+#    `here <https://docs.nvidia.com/deeplearning/performance/index.html>`_ for guidance.)
+#
+# Loss is inf/NaN
+# ~~~~~~~~~~~~~~~
+# First, check if your network fits an :ref:`advanced use case <advanced-topics>`.
+# See also `Prefer binary_cross_entropy_with_logits over binary_cross_entropy <https://pytorch.org/docs/stable/amp.html#prefer-binary-cross-entropy-with-logits-over-binary-cross-entropy>`_.
+#
+# If you're confident your Amp usage is correct, you may need to file an issue, but before doing so, it's helpful to gather the following information:
+#
+# 1. Disable ``autocast`` or ``GradScaler`` individually (by passing ``enabled=False`` to their constructor) and see if infs/NaNs persist.
+# 2. If you suspect part of your network (e.g., a complicated loss function) overflows, run that forward region in ``float32``
+#    and see if infs/NaNs persist.
+#    `The autocast docstring <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_'s last code snippet
+#    shows forcing a subregion to run in ``float32`` (by locally disabling autocast and casting the subregion's inputs);
+#    a sketch of this technique also appears at the end of this recipe.
+#
+# Type mismatch error (may manifest as CUDNN_STATUS_BAD_PARAM)
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Autocast tries to cover all ops that benefit from or require casting.
+# `Ops that receive explicit coverage <https://pytorch.org/docs/stable/amp.html#autocast-op-reference>`_
+# are chosen based on numerical properties, but also on experience.
+# If you see a type mismatch error in an autocast-enabled forward region or a backward pass following that region,
+# it's possible autocast missed an op.
+#
+# Please file an issue with the error backtrace. ``export TORCH_SHOW_CPP_STACKTRACES=1`` before running your script
+# to get fine-grained information on which backend op is failing.
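+
+##########################################################
+# For the ``float32`` subregion technique mentioned under "Loss is inf/NaN" above, a hypothetical
+# sketch follows. ``suspect_fn`` and the tensors ``a`` and ``b`` are placeholders, not part of the
+# recipe's network; they only stand in for a region you suspect of overflowing in ``float16``.
+
+def suspect_fn(a, b):
+    # Stand-in for a numerically sensitive computation.
+    return (a * b).sum()
+
+a = torch.randn(8, 8, device="cuda")
+b = torch.randn(8, 8, device="cuda")
+
+with torch.cuda.amp.autocast():
+    c = torch.mm(a, b)  # runs in float16 under autocast
+    # Locally disable autocast and cast the subregion's inputs to float32.
+    with torch.cuda.amp.autocast(enabled=False):
+        d = suspect_fn(c.float(), b.float())
+        assert d.dtype is torch.float32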
diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst index 86438135e1d..f8986363092 100644 --- a/recipes_source/recipes_index.rst +++ b/recipes_source/recipes_index.rst @@ -167,6 +167,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu :link: ../recipes/android_native_app_with_custom_op.html :tags: Mobile +.. Automatic Mixed Precision + +.. customcarditem:: + :header: Automatic Mixed Precision + :card_description: Use torch.cuda.amp to reduce runtime and save memory on NVIDIA GPUs. + :image: ../_static/img/thumbnails/cropped/amp.png + :link: ../recipes/recipes/amp_recipe.html + :tags: Model-Optimization + .. End of tutorial card section .. raw:: html @@ -199,6 +208,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu /recipes/recipes/Captum_Recipe /recipes/recipes/tensorboard_with_pytorch /recipes/recipes/dynamic_quantization + /recipes/recipes/amp_recipe /recipes/torchscript_inference /recipes/deployment_with_flask /recipes/distributed_rpc_profiling