comments and .md file modification
susnato committed Nov 16, 2023
1 parent 10ab9b3 commit 7d4c688
Showing 2 changed files with 79 additions and 33 deletions.
49 changes: 49 additions & 0 deletions docs/source/en/model_doc/opt.md
@@ -62,6 +62,55 @@ The resource should ideally demonstrate something new instead of duplicating an

- A blog post on [How 🤗 Accelerate runs very large models thanks to PyTorch](https://huggingface.co/blog/accelerate-large-models) with OPT.


## Combining OPT and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2; see the official documentation of the [flash-attn](https://github.com/Dao-AILab/flash-attention) repository for the supported hardware. Finally, make sure to load your model in half-precision (e.g. `torch.float16`).
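As a quick sanity check (a minimal sketch, not the authoritative compatibility list — consult the flash-attn documentation for that), you can inspect the compute capability of your GPU; Flash Attention 2 generally targets Ampere-or-newer NVIDIA GPUs (compute capability 8.0+):

```python
import torch

# Minimal sketch: report the CUDA compute capability of the current GPU.
# Flash Attention 2 generally requires compute capability >= 8.0 (Ampere or newer);
# check the flash-attn documentation for the exact supported-hardware list.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability {major}.{minor} - likely supported: {major >= 8}")
else:
    print("No CUDA device detected; Flash Attention 2 requires a supported GPU.")
```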

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
... "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
... "there?")

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False)
>>> tokenizer.batch_decode(generated_ids)[0]
'</s>A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: I have lived here for about a year.\nHuman: What is your favorite place to eat?\nStatue: I love'
```

### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `facebook/opt-2.7b` checkpoint and the Flash Attention 2 version of the model, for two different sequence lengths.

<div style="text-align: center">
<img src="https://user-images.githubusercontent.com/49240599/281101546-d2fca6d2-ee44-48f3-9534-ba8d5bee4531.png">
</div>

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `facebook/opt-350m` checkpoint and the Flash Attention 2 version of the model, for two different sequence lengths.

<div style="text-align: center">
<img src="https://user-images.githubusercontent.com/49240599/281101682-d1144e90-0dbc-46f4-8fc8-c6206cb793c9.png">
</div>
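
If you want to reproduce a comparison like this on your own hardware, the following is a minimal sketch of how such a timing could be done (an illustrative assumption, not the script used to produce the diagrams above; it assumes a CUDA GPU, reuses the `facebook/opt-350m` checkpoint, and the helper name `time_generation` is hypothetical):

```python
import time

import torch
from transformers import GPT2Tokenizer, OPTForCausalLM


def time_generation(use_fa2: bool, prompt: str, n_runs: int = 5) -> float:
    """Return the average time (in seconds) to generate 30 new tokens."""
    model = OPTForCausalLM.from_pretrained(
        "facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=use_fa2
    ).to("cuda")
    tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    # Warm-up run so CUDA kernel compilation and caching are not timed
    model.generate(**inputs, max_new_tokens=30, do_sample=False)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_runs):
        model.generate(**inputs, max_new_tokens=30, do_sample=False)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs


prompt = "A chat between a curious human and the Statue of Liberty."
print(f"eager: {time_generation(False, prompt):.3f}s per run")
print(f"flash-attn-2: {time_generation(True, prompt):.3f}s per run")
```

Measured speedups depend heavily on sequence length and batch size, which is why the diagrams above report two different sequence lengths.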



## OPTConfig

[[autodoc]] OPTConfig
63 changes: 30 additions & 33 deletions src/transformers/models/opt/modeling_opt.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import warnings
from typing import List, Optional, Tuple, Union

import torch
@@ -115,15 +114,34 @@ class OPTAttention(nn.Module):
    def __init__(
        self,
        config: OPTConfig,
        is_decoder: bool,
        is_decoder: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.config = config

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        self.head_dim = config.hidden_size // config.num_attention_heads
        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
            """
            If the deprecated argument `fn_arg_name` is passed, raise a deprecation warning and return its value;
            otherwise return the equivalent `config.config_arg_name`.
            """
            val = None
            if fn_arg_name in kwargs:
                logger.warning(
                    f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from"
                    " v4.38. Please set it in the config instead."
                )
                val = kwargs.pop(fn_arg_name)
            else:
                val = getattr(config, config_arg_name)
            return val

        self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
        self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
        self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
        self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)

        self.head_dim = self.embed_dim // self.num_heads
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.embed_dim:
@@ -134,10 +152,10 @@ def __init__(
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.enable_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.enable_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.enable_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.enable_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -150,15 +168,9 @@ def forward(
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
@@ -284,18 +296,9 @@ def forward(
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
@@ -478,15 +481,9 @@ def __init__(self, config: OPTConfig):
        self.embed_dim = config.hidden_size

        if not getattr(config, "_flash_attn_2_enabled", False):
            self.self_attn = OPTAttention(
                config=config,
                is_decoder=True,
            )
            self.self_attn = OPTAttention(config=config, is_decoder=True)
        else:
            self.self_attn = OptFlashAttention2(
                config=config,
                is_decoder=True,
            )
            self.self_attn = OptFlashAttention2(config=config, is_decoder=True)

        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
