Enable AMP for BetterTransformer #952

Closed
ViktorooReps opened this issue Apr 4, 2023 · 6 comments

@ViktorooReps

Feature request

Allow BetterTransformer models to run inference with AMP.

Motivation

Models transformed with BetterTransformer raise an error when used with AMP:

optimum.bettertransformer.models.base

    ...
    def forward_checker(self, *args, **kwargs):
        if torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled():
            raise ValueError("Autocast is not supported for `BetterTransformer` integration.")

        if self.training and not self.is_decoder:
            raise ValueError(
                "Training is not supported for `BetterTransformer` integration.",
                " Please use `model.eval()` before running the model.",
            )
    ...

Why is that? I tried setting torch.is_autocast_enabled to lambda: False and everything works just fine, at least for XLMRobertaModel:

>>> import torch
>>> from transformers import AutoModel
>>> from optimum.bettertransformer import BetterTransformer
>>> m = AutoModel.from_pretrained('xlm-roberta-base')
>>> BetterTransformer.transform(m, keep_original_model=False)
XLMRobertaModel(
  (embeddings): XLMRobertaEmbeddings(
    (word_embeddings): Embedding(250002, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): XLMRobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayerBetterTransformer()
    )
  )
  (pooler): XLMRobertaPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
>>> with torch.amp.autocast('cuda'):
...     m(**{name: t.to('cuda') for name, t in m.dummy_inputs.items()})
... 
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <stdin>:2 in <module>                                                                            │
│                                                                                                  │
│ /home/viktor-sch/Clones/talisman-ie/venv/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1501 in _call_impl                                                                              │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/viktor-sch/Clones/talisman-ie/venv/lib/python3.10/site-packages/transformers/models/xlm_ro │
│ berta/modeling_xlm_roberta.py:854 in forward                                                     │
│                                                                                                  │
│    851 │   │   │   inputs_embeds=inputs_embeds,                                                  │
│    852 │   │   │   past_key_values_length=past_key_values_length,                                │
│    853 │   │   )                                                                                 │
│ ❱  854 │   │   encoder_outputs = self.encoder(                                                   │
│    855 │   │   │   embedding_output,                                                             │
│    856 │   │   │   attention_mask=extended_attention_mask,                                       │
│    857 │   │   │   head_mask=head_mask,                                                          │
│                                                                                                  │
│ /home/viktor-sch/Clones/talisman-ie/venv/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1501 in _call_impl                                                                              │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/viktor-sch/Clones/talisman-ie/venv/lib/python3.10/site-packages/transformers/models/xlm_ro │
│ berta/modeling_xlm_roberta.py:528 in forward                                                     │
│                                                                                                  │
│    525 │   │   │   │   │   encoder_attention_mask,                                               │
│    526 │   │   │   │   )                                                                         │
│    527 │   │   │   else:                                                                         │
│ ❱  528 │   │   │   │   layer_outputs = layer_module(                                             │
│    529 │   │   │   │   │   hidden_states,                                                        │
│    530 │   │   │   │   │   attention_mask,                                                       │
│    531 │   │   │   │   │   layer_head_mask,                                                      │
│                                                                                                  │
│ /home/viktor-sch/Clones/talisman-ie/venv/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1501 in _call_impl                                                                              │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/viktor-sch/Clones/talisman-ie/venv/lib/python3.10/site-packages/optimum/bettertransformer/ │
│ models/encoder_models.py:235 in forward                                                          │
│                                                                                                  │
│    232 │   │   This is just a wrapper around the forward function proposed in:                   │
│    233 │   │   https://github.com/huggingface/transformers/pull/19553                            │
│    234 │   │   """                                                                               │
│ ❱  235 │   │   super().forward_checker()                                                         │
│    236 │   │                                                                                     │
│    237 │   │   if hidden_states.is_nested:                                                       │
│    238 │   │   │   attention_mask = None                                                         │
│                                                                                                  │
│ /home/viktor-sch/Clones/talisman-ie/venv/lib/python3.10/site-packages/optimum/bettertransformer/ │
│ models/base.py:134 in forward_checker                                                            │
│                                                                                                  │
│   131 │                                                                                          │
│   132 │   def forward_checker(self, *args, **kwargs):                                            │
│   133 │   │   if torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled():                 │
│ ❱ 134 │   │   │   raise ValueError("Autocast is not supported for `BetterTransformer` integrat   │
│   135 │   │                                                                                      │
│   136 │   │   if self.training and not self.is_decoder:                                          │
│   137 │   │   │   raise ValueError(                                                              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Autocast is not supported for `BetterTransformer` integration.
>>> torch.is_autocast_enabled = lambda: False
>>> torch.is_autocast_enabled()
False
>>> with torch.amp.autocast('cuda'):
...     m(**{name: t.to('cuda') for name, t in m.dummy_inputs.items()})
... 
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.0797,  0.1111,  0.0501,  ..., -0.1485,  0.0307,  0.0300],
         [-0.0471, -0.0122,  0.0330,  ..., -0.0804, -0.0655,  0.0632],
         [-0.0111,  0.0361,  0.0429,  ..., -0.1467, -0.0251,  0.0606],
         [ 0.0074,  0.0218, -0.1271,  ..., -0.3417, -0.1976,  0.1173],
         [ 0.0797,  0.1111,  0.0501,  ..., -0.1485,  0.0306,  0.0301]],

        [[ 0.0625,  0.1222, -0.0087,  ..., -0.2471, -0.0332,  0.0863],
         [ 0.0743,  0.1297,  0.0457,  ..., -0.1627,  0.0325,  0.0482],
         [-0.0586,  0.0403, -0.0375,  ..., -0.1715, -0.0187,  0.1560],
         [-0.0482,  0.0794, -0.0066,  ..., -0.1591, -0.0139,  0.0940],
         [-0.0421,  0.0817, -0.0104,  ..., -0.1484, -0.0112,  0.0803]],

        [[ 0.2380,  0.2632,  0.0340,  ..., -0.2380,  0.0531,  0.1011],
         [ 0.0011,  0.0960,  0.0071,  ...,  0.0116, -0.0556,  0.1343],
         [ 0.0035,  0.0674, -0.0191,  ...,  0.0397, -0.0757,  0.0947],
         [-0.0608,  0.0363, -0.0187,  ..., -0.2528, -0.0873,  0.3649],
         [ 0.0608,  0.1712, -0.0644,  ..., -0.2470, -0.0682,  0.2463]]],
       device='cuda:0', grad_fn=<ToPaddedTensorBackward0>), pooler_output=tensor([[-0.0355,  0.2510,  0.1194,  ..., -0.0778, -0.0937,  0.1245],
        [-0.0283,  0.2449,  0.1166,  ..., -0.0936, -0.1037,  0.1198],
        [-0.0764,  0.1265,  0.0545,  ...,  0.0739, -0.2233,  0.0834]],
       device='cuda:0', dtype=torch.float16, grad_fn=<TanhBackward0>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)

Your contribution

My guess is that this was originally disabled because NestedTensor had no fp16 backends. Since that is no longer the case (at least as of PyTorch 2.0.0), I can replace this AMP check with a torch version check.
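
For illustration, here is a minimal sketch of what that version gate could look like (the helper name and the 2.0.0 threshold are assumptions for this example, not the actual optimum implementation):

    import torch
    from packaging import version

    # Hypothetical helper: only reject autocast on torch versions that lack
    # fp16 support for NestedTensor (the 2.0.0 threshold is assumed here).
    def _autocast_supported() -> bool:
        return version.parse(torch.__version__) >= version.parse("2.0.0")

    def forward_checker(self, *args, **kwargs):
        autocast_active = torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled()
        if autocast_active and not _autocast_supported():
            raise ValueError("Autocast is not supported for `BetterTransformer` integration.")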

@fxmarty
Contributor

fxmarty commented Apr 4, 2023

@younesbelkada any idea?

@younesbelkada
Contributor

> My guess is that this was originally disabled because NestedTensor had no fp16 backends. Since that is no longer the case (at least as of PyTorch 2.0.0), I can replace this AMP check with a torch version check.

This solution sounds good to me! Do you mind opening a PR to add that fix? Otherwise I'm happy to do it.

@ViktorooReps
Author

> > My guess is that this was originally disabled because NestedTensor had no fp16 backends. Since that is no longer the case (at least as of PyTorch 2.0.0), I can replace this AMP check with a torch version check.

> This solution sounds good to me! Do you mind opening a PR to add that fix? Otherwise I'm happy to do it.

I will get to it in a moment. I will tag you for a review, if you don't mind.

@ViktorooReps
Author

It turns out the fast-path computation is indeed not supported with mixed precision in torch. By setting torch.is_autocast_enabled = lambda: False I forced the fast-path computation on possibly incorrect input data.

@younesbelkada
Contributor

Thanks a lot for digging into that!

@fxmarty
Contributor

fxmarty commented Jul 26, 2023

Hi, autocast is now supported with #1225, to the extent PyTorch supports it (dispatching to another compute path if autocast is enabled).
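
For anyone landing here later, a minimal usage sketch of inference under autocast with a transformed model (assuming an optimum version that includes #1225; the checkpoint, device and dtype are just illustrative choices):

    import torch
    from transformers import AutoModel
    from optimum.bettertransformer import BetterTransformer

    # Illustrative checkpoint; any encoder architecture supported by BetterTransformer should work.
    model = AutoModel.from_pretrained("xlm-roberta-base").to("cuda").eval()
    model = BetterTransformer.transform(model)

    inputs = {name: t.to("cuda") for name, t in model.dummy_inputs.items()}
    with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.float16):
        outputs = model(**inputs)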
