[FA2] Add flash attention for opt #26414
Conversation
The changes look great to me! Thanks a lot for working on this @susnato!
I can confirm the tests pass on an A100, just tested it.
Before merging please have a look at my comment below. Also, can you add OPT to the list of officially supported models? The changes would go here: https://github.com/susnato/transformers/blob/flash_attn_opt/docs/source/en/perf_infer_gpu_one.md
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@susnato I have realised there were some issues with FA 2 + OPT + use_cache; some minor changes were needed. Please have a look at susnato#3, and I will later add tests for the case where users use caching, to catch that in the future.
To reproduce:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    use_flash_attention_2=True,
    torch_dtype=torch.float16,
).to(0)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to(0)

out = model.generate(**inputs, max_new_tokens=30, use_cache=True)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```
Hi @younesbelkada, I have added the model to the list. I have also merged susnato#3; please let me know if any more changes are needed.
Thanks a lot for iterating! Can you also run the styling checks? (`make fixup`)

Hi @younesbelkada, done.
Thanks for iterating! Left one comment.
```diff
@@ -216,7 +216,6 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         else {}
     )
     is_encoder_decoder = False
-    fx_compatible = True
```
I think this should not be removed
It seems to give an error for the fx tests `test_torch_fx` and `test_torch_fx_output_loss`. The error is happening for this line. The full error for `test_torch_fx`:
tests/models/opt/test_modeling_opt.py F [100%]
========================================================= FAILURES =========================================================
________________________________________________ OPTModelTest.test_torch_fx ________________________________________________
self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>
config = OPTConfig {
"_remove_final_layer_norm": false,
"activation_function": "relu",
"attention_dropout": 0.1,
"bos_t...d": 1,
"transformers_version": "4.34.0.dev0",
"use_cache": true,
"vocab_size": 99,
"word_embed_proj_dim": 16
}
inputs_dict = {'attention_mask': tensor([[True, True, True, True, True, True, True],
[True, True, True, True, True, True, Tr...91, 54, 98, 57, 55, 2],
[74, 79, 56, 51, 93, 26, 2],
[62, 18, 55, 3, 73, 74, 2]], device='cuda:0')}
output_loss = False
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
input_names = [
"attention_mask",
"decoder_attention_mask",
"decoder_input_ids",
"input_features",
"input_ids",
"input_values",
]
if labels is not None:
input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
input_names = [
"attention_mask",
"bbox",
"input_features",
"input_ids",
"input_values",
"pixel_values",
"token_type_ids",
"visual_feats",
"visual_pos",
]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
if labels is not None:
input_names.append("labels")
if start_positions is not None:
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
not hasattr(model.config, "problem_type") or model.config.problem_type is None
):
model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
tests/test_modeling_common.py:877:
model = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_names = ['input_ids', 'attention_mask'], disable_check = False, tracer_cls = <class 'transformers.utils.fx.HFTracer'>
def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
disable_check: bool = False,
tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule:
"""
Performs symbolic tracing on the model.
Args:
model ([`PretrainedModel`]):
The model to trace.
input_names (`List[str]`, *optional*):
The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
disable_check (`bool`, *optional*, defaults to `False`):
If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.
Returns:
`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
Example:
```python
from transformers.utils.fx import symbolic_trace
traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```
"""
if input_names is None:
input_names = model.dummy_inputs.keys()
input_names = list(input_names)
concrete_args = get_concrete_args(model, input_names)
if not disable_check:
check_if_model_is_supported(model)
# Tracing.
tracer = tracer_cls()
traced_graph = tracer.trace(model, concrete_args=concrete_args)
src/transformers/utils/fx.py:1250:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}
dummy_inputs = None, complete_concrete_args_with_inputs_not_in_dummy_inputs = True
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
dummy_inputs: Optional[Dict[str, Any]] = None,
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
"""
Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
`torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
`torch.nn.Module` instance to use as the root and add embedded constants to.
Args:
root (`torch.nn.Module` or `Callable`):
Either a `torch.nn.Module`` or a function to be traced through. If root is not a
[`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
concrete_args (`Dict[str, Any], *optional*):
Concrete arguments that should not be treated as Proxies
dummy_inputs (`Dict[str, Any]`, *optional*):
The dummy inputs needed to handle data-dependent control-flow if `root` is not a
[`~transformers.PreTrainedModel`]. It can also be used when `root` is a
[`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
`dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.
Returns:
`torch.fx.Graph`:
A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.
"""
sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)
if concrete_args is None:
concrete_args = {}
if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
for param in sig.parameters.values():
if param.name in dummy_inputs:
continue
if param.default is inspect.Parameter.empty:
raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
concrete_args.update(
{
p.name: p.default
for p in sig.parameters.values()
if (p.name not in dummy_inputs and p.name not in concrete_args)
}
)
input_names = sig.parameters.keys() - concrete_args.keys()
# Creating a random input shape to generate dummy inputs.
batch_size = _generate_random_int()
sequence_length = _generate_random_int()
shape = [batch_size, sequence_length]
if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)
inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
for input_name in input_names:
if input_name in inputs:
continue
# We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
# be able to use HFTracer._generate_dummy_input.
if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
("_deserialize_graph_module", "_CodeOnlyModule")
):
inputs.update(self._generate_dummy_input(root, input_name, shape))
else:
raise RuntimeError(
f"Could not generate input named {input_name} for because root is not a"
" transformers.PreTrainedModel."
)
concrete_metas = {
input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
for input_name, input_ in inputs.items()
}
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
concrete_metas[f"**{param.name}"] = {}
self.meta_args = concrete_metas
self.patched_torch_methods = {
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
}
self.orig_fns = set()
for name, (wrapper, orig) in self.patched_torch_methods.items():
setattr(torch, name, wrapper)
self.orig_fns.add(orig)
try:
self.graph = super().trace(root, concrete_args=concrete_args)
src/transformers/utils/fx.py:1088:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}
@compatibility(is_backward_compatible=True)
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> Graph:
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
can either be an ``nn.Module`` instance or a Python callable.
Note that after this call, ``self.root`` may be different from the ``root`` passed
in here. For example, when a free function is passed to ``trace()``, we will
create an ``nn.Module`` instance to use as the root and add embedded constants
to.
Args:
root (Union[Module, Callable]): Either a ``Module`` or a function to be
traced through. Backwards-compatibility for this parameter is
guaranteed.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
not be treated as Proxies. This parameter is experimental and
its backwards-compatibility is *NOT* guaranteed.
Returns:
A ``Graph`` representing the semantics of the passed-in ``root``.
"""
global _is_fx_tracing_flag
old_is_fx_tracing_flag = _is_fx_tracing_flag
_is_fx_tracing_flag = True
try:
if isinstance(root, torch.nn.Module):
self.root = root
assert hasattr(
type(root), self.traced_func_name
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
fn = getattr(type(root), self.traced_func_name)
self.submodule_paths = {mod: name for name, mod in root.named_modules()}
else:
self.root = torch.nn.Module()
fn = root
tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)
self.graph = Graph(tracer_cls=tracer_cls)
# When we encounter a Tensor value that's not a parameter, we look if it
# is some other attribute on the model. Construct a dict mapping Tensor
# values to the qualified name here for efficiency. This is used downstream
# in create_arg
self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}
def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
for k, v in m.__dict__.items():
if isinstance(v, (torch.Tensor, ScriptObject)):
self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
for k, v in m.named_children():
collect_tensor_attrs(v, prefix_atoms + [k])
collect_tensor_attrs(self.root, [])
assert isinstance(fn, FunctionType)
fn_globals = fn.__globals__ # run before it gets patched
fn, args = self.create_args_for_root(
fn, isinstance(root, torch.nn.Module), concrete_args
)
parameter_proxy_cache: Dict[
str, Proxy
] = {} # Reduce number of get_attr calls
# Method dispatch on parameters is not recorded unless it's directly used.
# Thus, we need to insert a proxy when __getattr__ requests a parameter.
@functools.wraps(_orig_module_getattr)
def module_getattr_wrapper(mod, attr):
attr_val = _orig_module_getattr(mod, attr)
return self.getattr(attr, attr_val, parameter_proxy_cache)
@functools.wraps(_orig_module_call)
def module_call_wrapper(mod, *args, **kwargs):
def forward(*args, **kwargs):
return _orig_module_call(mod, *args, **kwargs)
_autowrap_check(
patcher,
getattr(getattr(mod, "forward", mod), "__globals__", {}),
self._autowrap_function_ids,
)
return self.call_module(mod, forward, args, kwargs)
with _Patcher() as patcher:
# allow duplicate patches to support the case of nested calls
patcher.patch_method(
torch.nn.Module,
"__getattr__",
module_getattr_wrapper,
deduplicate=False,
)
patcher.patch_method(
torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
)
_patch_wrapped_functions(patcher)
_autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
for module in self._autowrap_search:
_autowrap_check(
patcher, module.__dict__, self._autowrap_function_ids
)
self.create_node(
"output",
"output",
(self.create_arg(fn(*args)),),
{},
type_expr=fn.__annotations__.get("return", None),
)
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:739:
self = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_ids = Proxy(input_ids), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = None, use_cache = True, output_attentions = False, output_hidden_states = False, return_dict = False
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
src/transformers/models/opt/modeling_opt.py:1023:
mod = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>
@functools.wraps(_orig_module_call)
def module_call_wrapper(mod, *args, **kwargs):
def forward(*args, **kwargs):
return _orig_module_call(mod, *args, **kwargs)
_autowrap_check(
patcher,
getattr(getattr(mod, "forward", mod), "__globals__", {}),
self._autowrap_function_ids,
)
return self.call_module(mod, forward, args, kwargs)
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:717:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
src/transformers/utils/fx.py:987:
self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
@compatibility(is_backward_compatible=True)
def call_module(
self,
m: torch.nn.Module,
forward: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
"""
Method that specifies the behavior of this ``Tracer`` when it encounters
a call to an ``nn.Module`` instance.
By default, the behavior is to check if the called module is a leaf module
via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
the operations in its ``forward`` function.
This method can be overridden to--for example--create nested traced
GraphModules, or any other behavior you would want while tracing across
``Module`` boundaries.
Args:
m (Module): The module for which a call is being emitted
forward (Callable): The forward() method of the ``Module`` to be invoked
args (Tuple): args of the module callsite
kwargs (Dict): kwargs of the module callsite
Return:
The return value from the Module call. In the case that a ``call_module``
node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
value was returned from the ``Module`` invocation.
"""
module_qualified_name = self.path_of_module(m)
if not self.is_leaf_module(m, module_qualified_name):
return forward(*args, **kwargs)
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:434:
args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
def forward(*args, **kwargs):
return _orig_module_call(mod, *args, **kwargs)
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:710:
self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward_call = <bound method OPTDecoder.forward of OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions)...out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)>
def _call_impl(self, *input, **kwargs):
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
# If we don't have any hooks, we want to skip the rest of the logic in
# this function, and just call forward.
if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
or _global_forward_hooks or _global_forward_pre_hooks):
return forward_call(*input, **kwargs)
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194:
self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input_ids = Proxy(view), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = Proxy(decoder_embed_tokens), use_cache = True, output_attentions = False, output_hidden_states = False
return_dict = False
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
padding_mask = None
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
else:
if 0 in attention_mask:
src/transformers/models/opt/modeling_opt.py:872:
self = Proxy(attention_mask), key = 0
def __contains__(self, key):
if hasattr(self, "_metadata") and self._metadata is not None:
return key in self._metadata
src/transformers/utils/fx.py:646:
self = tensor(..., device='meta', size=(14, 11), dtype=torch.int64), element = 0
def __contains__(self, element):
r"""Check if `element` is present in tensor
Args:
element (Tensor or scalar): element to be checked
for presence in current tensor"
"""
if has_torch_function_unary(self):
return handle_torch_function(Tensor.__contains__, (self,), self, element)
if isinstance(element, (torch.Tensor, Number)):
# type hint doesn't understand the __contains__ result array
return (element == self).any().item() # type: ignore[union-attr]
E NotImplementedError: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/_tensor.py:983: NotImplementedError
During handling of the above exception, another exception occurred:
self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>
def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict)
tests/test_modeling_common.py:805:
tests/test_modeling_common.py:882: in _create_and_check_torch_fx_tracing
self.fail(f"Couldn't trace module: {e}")
E AssertionError: Couldn't trace module: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]
EDIT: The errors are fixed now.
Thanks, I see. In that line can you use `torch.isin` instead? https://pytorch.org/docs/stable/generated/torch.isin.html - suggestion from @michaelbenayoun
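For illustration, here is a minimal eager-mode sketch (not code from this PR) of what a `torch.isin`-based padding check could look like; as the comments below show, the data-dependent branch itself is what torch.fx struggles with:

```python
import torch

# Toy attention mask: the second row contains a padding position (0).
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

# True if any entry of the mask equals 0, i.e. the batch contains padding.
has_padding = torch.isin(attention_mask, torch.tensor(0)).any()

# Only forward the mask to the flash-attention path when padding is present.
padding_mask = attention_mask if has_padding else None
print(bool(has_padding), padding_mask is not None)  # True True
```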
```python
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )
        else:
            padding_mask = attention_mask
```
Can you try to add a conditional check for whether there is any 0 in the attention mask? Sometimes users pass an attention mask of all ones, in which case we'll need to set `padding_mask` to `None` to avoid entering the padding case in the `OPTFlashAttention2` module. If the test fails with torch.fx, can you try `torch.isin`? Otherwise you can also do a check based on the sum of the attention mask.
Ok, sorry for removing it directly without discussing.
no problem at all @susnato ! 🙏
It fails for:

- `if torch.isin(attention_mask.int(), 0).sum().item() != 0` (AssertionError: Couldn't trace module: symbolically traced variables cannot be used as inputs to control flow)
- `if torch.isin(attention_mask.int(), 0).sum() != 0` (AssertionError: Couldn't trace module: bool should return bool, returned Tensor)
- `if attention_mask.sum().item() != attention_mask.numel()` (AssertionError: Couldn't trace module: symbolically traced variables cannot be used as inputs to control flow)
- `if not torch.equal(attention_mask, torch.ones_like(attention_mask))` (AssertionError: Couldn't trace module: symbolically traced variables cannot be used as inputs to control flow)
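For context, a small self-contained illustration (not from the PR) of the underlying constraint: torch.fx traces with `Proxy` objects, so Python-level control flow that depends on traced tensor values cannot be resolved at trace time:

```python
import torch
import torch.fx


def forward(attention_mask):
    # Data-dependent control flow: the branch depends on the *values* in the tensor.
    if attention_mask.sum() != attention_mask.numel():
        return attention_mask
    return None


try:
    torch.fx.symbolic_trace(forward)
except Exception as e:
    # torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
    print(type(e).__name__, e)
```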
Hi @younesbelkada, I have solved the error, but it requires us to change the `_flash_attention_forward` function a little bit; otherwise, if we try to add the `if` statement in `OPTDecoder.forward`, we get the error.
""" | ||
# we check if padding_mask contains all ones, in that case we don't use it. This is to make sure the torch.fx | ||
# tests pass for the relevant models | ||
if padding_mask.sum().item() != padding_mask.numel(): | ||
batch_size = query_states.shape[0] | ||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( |
It seems that if we check whether `attention_mask` is all ones here, the torch.fx tests pass.
```python
            padding_mask = attention_mask

        causal_attention_mask = self._prepare_decoder_attention_mask(
```
In that case we need to pass `padding_mask = attention_mask` and check for the all-ones array in `_flash_attention_forward`. If we use any `if-else` condition here, it gives an error for the torch.fx tests.
Thanks very much for iterating @susnato. I have also tried out different approaches on my end, e.g.:
```python
# Avoids torch FX issues
if torch.any(~attention_mask.bool()):
    padding_mask = attention_mask
else:
    padding_mask = None
```
and the FX tests still fail.

The issue with the current approach is that the `padding_mask.sum().item() != padding_mask.numel()` check adds significant overhead, as it is computed by each attention layer; especially for long sequence lengths this leads to a slowdown.
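To illustrate the overhead concern (a rough standalone sketch, not code from the PR): the `.item()` call copies a scalar back to the host and forces the device to synchronize, and in the current approach this happens once per attention layer:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
padding_mask = torch.ones(8, 4096, device=device)

# Executed inside every layer's forward in the current approach: the comparison needs a
# Python bool, so `.item()` blocks until the device finishes and copies one scalar back.
contains_padding = padding_mask.sum().item() != padding_mask.numel()
print(contains_padding)  # False for an all-ones mask
```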
Let me dig a bit with @michaelbenayoun regarding failing FX tests and get back to you
Thanks for the reply @younesbelkada; in the meantime I would like to work on adding flash attention for another model.
Yes, I am interested in that! @younesbelkada
Thanks very much! Looking forward to your PR!
Also please assign me on the list :). @younesbelkada
Yes, just assigned you on the list!
@susnato @younesbelkada why not use [...]?
Hi @susnato

1- define a method on the top level of the modeling opt file that checks if there is any padding token inside the attention mask:

```python
if is_torch_fx_available():

    @torch.fx.wrap
    def check_padding_in_attention_mask(attention_mask):
        if 0 in attention_mask:
            return attention_mask
        return None

else:

    def check_padding_in_attention_mask(attention_mask):
        if 0 in attention_mask:
            return attention_mask
        return None
```

And add the [...]. That way FX tests should hopefully pass. I will later take care of moving [...]
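A self-contained sketch (assuming the helper above, not the merged implementation) of why the `torch.fx.wrap` route can work: the wrapped helper is recorded as a single leaf `call_function` node, so its data-dependent `if` never runs on a `Proxy` during tracing:

```python
import torch
import torch.fx


@torch.fx.wrap
def check_padding_in_attention_mask(attention_mask):
    # Data-dependent check, but hidden from the tracer because the function is a leaf.
    if 0 in attention_mask:
        return attention_mask
    return None


class ToyDecoder(torch.nn.Module):
    def forward(self, attention_mask):
        padding_mask = check_padding_in_attention_mask(attention_mask)
        return attention_mask.sum(), padding_mask


traced = torch.fx.symbolic_trace(ToyDecoder())  # traces successfully
print(traced.graph)  # the helper appears as a call_function node in the graph
```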
Hi @younesbelkada,

Full Traceback for Llama:

```
self = <tests.models.llama.test_modeling_llama.LlamaModelTest testMethod=test_flash_attn_2_generate_use_cache>
tests/test_modeling_common.py:2936:
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
ctx = <torch.autograd.function.IndexPutFirstAxisBackward object at 0x7f1fb51b9040>
E IndexError: tensors used as indices must be long, byte or bool tensors
../../anaconda3/envs/transformers/lib/python3.9/site-packages/flash_attn/bert_padding.py:51: IndexError
```

BTW I updated the flash attention library and merged to main but the error is still there.
I have pushed the docstring change. If you don't mind @younesbelkada, could you please check out this branch, run the tests on your end, and let me know the results?
Sure @susnato, no problem, will run the tests locally tomorrow and let you know how it goes.
Hi @younesbelkada, did you manage to check if the tests run successfully on your local machine?
Yes, the tests seemed to pass the time I ran them; however, since #26792 was merged you need to remove the support for [...]
After I finish updating the [...]
OK awesome, thanks @susnato!
Hi @younesbelkada, I have pushed the changes. All the flash attention tests are running fine except [...]
Hello @younesbelkada, are you going to check the speed-up for this model too? I am asking this out of curiosity - are you going to check speed-ups for every model FlashAttention is added to, or is this benchmark needed only for popular ones?
```python
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        config: OPTConfig,
```
Are we OK modifying the attention class's signature? We've had issues in the past when changing the signature of such classes (e.g. for Llama when we added `padding_mask` in the forward signature). cc @amyeroberts
It should be OK, as this isn't a class importable from the top layer of transformers or in the documentation. However, it definitely can still cause issues! What I'd suggest is something like this (happy for you to modify as desired):
```python
def __init__(
    self,
    config,
    is_decoder: bool = False,
    **kwargs
):
    def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
        """
        If the deprecated argument `fn_arg_name` is passed, raise a deprecation
        warning and return that value, otherwise take the equivalent config.config_arg_name
        """
        val = None
        if fn_arg_name in kwargs:
            logging.warning(
                f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
                " Please set it in the config instead"
            )
            val = kwargs.pop(fn_arg_name)
        else:
            val = getattr(config, config_arg_name)
        return val

    embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
    num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
    dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
    enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)
```
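As a usage illustration, a small self-contained toy (not the transformers code) of the pattern suggested above: prefer an explicitly passed keyword argument (with a warning), otherwise fall back to the config attribute:

```python
import logging


class ToyConfig:
    hidden_size = 16
    attention_dropout = 0.0


def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
    # Return the deprecated kwarg if it was passed (and warn), else read it from the config.
    if fn_arg_name in kwargs:
        logging.warning(f"Passing in {fn_arg_name} is deprecated; set it in the config instead")
        return kwargs.pop(fn_arg_name)
    return getattr(config, config_arg_name)


kwargs = {"dropout": 0.1}
embed_dim = _handle_deprecated_argument("hidden_size", ToyConfig(), "embed_dim", kwargs)   # -> 16 (from config)
dropout = _handle_deprecated_argument("attention_dropout", ToyConfig(), "dropout", kwargs)  # -> 0.1 (deprecated kwarg)
print(embed_dim, dropout)
```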
Thanks for adding!
Just some comments on handling old arguments + deprecation warnings
```python
        self.self_attn = OPTAttention(
            config=config,
            is_decoder=True,
        )
```
nit - can be put on one line:

```diff
-        self.self_attn = OPTAttention(
-            config=config,
-            is_decoder=True,
-        )
+        self.self_attn = OPTAttention(config=config, is_decoder=True)
```
```python
        self.self_attn = OptFlashAttention2(
            config=config,
            is_decoder=True,
        )
```
```diff
-        self.self_attn = OptFlashAttention2(
-            config=config,
-            is_decoder=True,
-        )
+        self.self_attn = OptFlashAttention2(config=config, is_decoder=True)
```
```python
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")
```
It doesn't look like `padding_mask` was a previous argument.

```diff
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-            # overwrite attention_mask with padding_mask
-            attention_mask = kwargs.pop("padding_mask")
```
if "padding_mask" in kwargs: | ||
warnings.warn( | ||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" | ||
) | ||
|
if "padding_mask" in kwargs: | |
warnings.warn( | |
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" | |
) |
```
@@ -129,9 +150,15 @@ def forward(
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs,
```
```diff
-        **kwargs,
```
```
        is_decoder: bool = False,
        bias: bool = True,
        config: OPTConfig,
        is_decoder: bool,
```
This needs to keep the default value:

```diff
-        is_decoder: bool,
+        is_decoder: bool = False,
```
Hi @amyeroberts, I have pushed the changes you suggested; in addition to that, I have also updated the model_doc file and added speedup graphs.
@susnato Thanks for iterating! All looks good to me.
Let's get a second review from @younesbelkada before merging as he's more familiar with the Flash Attention code
Thanks for your great work! Can you just confirm the slow integration tests for OPT pass before merging? 🙏
Yes, I just checked and all tests are passing, @younesbelkada! Flash Attention tests - I found only one integration test [...]
Thanks! I'll let @amyeroberts merge the PR
What does this PR do?
Adds Flash Attention 2 for `OPT`, as discussed in this issue - #26350.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc: @younesbelkada