
Enable AMP for BetterTransformer at torch >= 2.0.0 #953

Conversation

@ViktorooReps commented Apr 4, 2023

What does this PR do?

Fixes #952

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

[Resolved review threads (outdated): optimum/bettertransformer/models/base.py (3), tests/bettertransformer/testing_utils.py (1)]
@ViktorooReps (Author)

I launched the tests for the bettertransformer folder, but I am not sure whether my GPU was picked up by optimum.
[screenshot of the local test run output]

@ViktorooReps (Author)

Should be good now. @fxmarty can you have a look?

@fxmarty (Contributor) commented Apr 4, 2023

Apparently some tests do not pass. You can try the following to reproduce locally:

    pytest tests/bettertransformer/test_*.py -k "test_raise_autocast" -s

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ViktorooReps (Author)

___________________________________________________________________________________ BetterTransformersEncoderTest.test_raise_autocast_16_xlm_roberta ___________________________________________________________________________________

a = (<test_encoder.BetterTransformersEncoderTest testMethod=test_raise_autocast_16_xlm_roberta>,), kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

venv/lib/python3.8/site-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/bettertransformer/test_encoder.py:208: in test_raise_autocast
    self._test_raise_autocast(model_id, model_type)
tests/bettertransformer/testing_utils.py:227: in _test_raise_autocast
    _ = bt_model(**inputs)
venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.8/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py:854: in forward
    encoder_outputs = self.encoder(
venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.8/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py:528: in forward
    layer_outputs = layer_module(
venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = BertLayerBetterTransformer()
hidden_states = nested_tensor([
  tensor([[ 1.1092, -0.0496, -0.6536,  ...,  1.0547,  0.0991,  1.5331],
          [ 1.1092, -0.0496, -...00,  1.0547e+00,  9.9146e-02,
            1.5331e+00]], requires_grad=True)
], grad_fn=<NestedTensorFromMaskBackward0>)
attention_mask = None, _ = (None, None, None, None, False)

    def forward(self, hidden_states, attention_mask, *_):
        r"""
        This is just a wrapper around the forward function proposed in:
        https://github.com/huggingface/transformers/pull/19553
        """
        super().forward_checker()
    
        if hidden_states.is_nested:
            attention_mask = None
    
        if attention_mask is not None:
            # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
            # 0->false->keep this token -inf->true->mask this token
            attention_mask = attention_mask.bool()
            attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
            hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
            attention_mask = None
    
>       hidden_states = torch._transformer_encoder_layer_fwd(
            hidden_states,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj_weight,
            self.out_proj_bias,
            self.use_gelu,
            self.norm_first,
            self.norm1_eps,
            self.norm1_weight,
            self.norm1_bias,
            self.norm2_weight,
            self.norm2_bias,
            self.linear1_weight,
            self.linear1_bias,
            self.linear2_weight,
            self.linear2_bias,
            attention_mask,
        )
E       RuntimeError: expected scalar type BFloat16 but found Float

optimum/bettertransformer/models/encoder_models.py:250: RuntimeError

This is very strange. The very same model works fine if I launch it outside the optimum tests.

Could this be related to #542?
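
For context, a standalone check along these lines is roughly what I mean; the model id, dtype and inputs below are illustrative, not the exact script I used, and on an unpatched optimum this is exactly the configuration that test_raise_autocast guards against:

    import torch
    from optimum.bettertransformer import BetterTransformer
    from transformers import AutoModel, AutoTokenizer

    # Convert a model with BetterTransformer and run one forward pass under
    # CUDA autocast (requires a GPU; the model id and inputs are placeholders).
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    model = AutoModel.from_pretrained("xlm-roberta-base").eval().to("cuda")
    bt_model = BetterTransformer.transform(model)

    inputs = tokenizer(
        ["hello world", "a somewhat longer second sentence"],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = bt_model(**inputs)

    print(outputs.last_hidden_state.dtype, outputs.last_hidden_state.shape)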

@ViktorooReps (Author) commented Apr 5, 2023

Well, this is embarrassing...

Here is the source code of the TransformerEncoderLayer.forward method in torch (https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer):

    ...
    elif torch.is_autocast_enabled():
        why_not_sparsity_fast_path = "autocast is enabled"
    ...

The fast path is not taken when autocast is enabled! By patching torch.is_autocast_enabled to lambda: False, I effectively forced the fast-path computation to run on input data of the wrong dtype.
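
To spell the guard out, here is a minimal sketch (it assumes a CUDA device, since torch.is_autocast_enabled() is the CUDA-side flag in torch 2.0):

    import torch

    # torch.is_autocast_enabled() is the very flag that TransformerEncoderLayer.forward
    # checks before taking the fused fast path, so the fast path is skipped on purpose
    # whenever AMP is active.
    if torch.cuda.is_available():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            print(torch.is_autocast_enabled())   # True  -> fused fast path rejected
        print(torch.is_autocast_enabled())       # False -> fast path allowed again

    # Patching torch.is_autocast_enabled to `lambda: False` hides that state from the
    # check, which is presumably how the fused kernel ended up seeing mixed
    # float32/bfloat16 tensors and raised the RuntimeError above.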

@ViktorooReps (Author)

And here I was wondering why torch._transformer_encoder_layer_fwd had suddenly become GPU-blocking...

I am closing this pull request and the associated issue. Sorry for any confusion!
