
[FA2] Add flash attention for opt #26414

Merged (16 commits) on Nov 23, 2023

Conversation

@susnato (Contributor) commented Sep 26, 2023

What does this PR do?

Adds Flash Attention 2 for OPT, as discussed in issue #26350.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc: @younesbelkada

@susnato (Contributor, Author) commented Sep 26, 2023

The tests were run on an RTX 3060 (Ampere), which supports Flash Attention 2.
[Screenshot from 2023-09-26 16-23-55: test results]

@younesbelkada (Contributor) left a comment

The changes look great to me! Thanks a lot for working on this @susnato!
I can confirm the tests pass on an A100, just tested it.
[Screenshot from 2023-09-26 at 13:00: A100 test results]

Before merging, please have a look at my comment below. Also, can you add OPT to the list of officially supported models? The changes would go here: https://github.com/susnato/transformers/blob/flash_attn_opt/docs/source/en/perf_infer_gpu_one.md

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada (Contributor) left a comment

@susnato I have realised there were some issues with FA2 + OPT + use_cache; some minor changes were needed. Please have a look at susnato#3. I will later add tests for the case where users use caching, to catch this in the future.

To reproduce:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    use_flash_attention_2=True,
    torch_dtype=torch.float16
).to(0)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to(0)

out = model.generate(**inputs, max_new_tokens=30, use_cache=True)
print(tokenizer.batch_decode(out, skip_special_tokens=True))

@susnato (Contributor, Author) commented Sep 26, 2023

Hi @younesbelkada, I have added the model to the list. I have also merged susnato#3; please let me know if any more changes are needed.

@younesbelkada (Contributor) left a comment

Thanks a lot for iterating! Can you also run the styling checks?

make fixup

@susnato (Contributor, Author) commented Sep 26, 2023

Hi @younesbelkada, done.

@younesbelkada (Contributor) left a comment

Thanks for iterating! Left one comment.

@@ -216,7 +216,6 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
else {}
)
is_encoder_decoder = False
fx_compatible = True
Contributor:

I think this should not be removed

@susnato (Contributor, Author) commented Sep 26, 2023

It seems to raise an error for the fx tests - test_torch_fx and test_torch_fx_output_loss.
The error happens at this line.

The full error for test_torch_fx

tests/models/opt/test_modeling_opt.py F [100%]

========================================================= FAILURES =========================================================
________________________________________________ OPTModelTest.test_torch_fx ________________________________________________

self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>
config = OPTConfig {
"_remove_final_layer_norm": false,
"activation_function": "relu",
"attention_dropout": 0.1,
"bos_t...d": 1,
"transformers_version": "4.34.0.dev0",
"use_cache": true,
"vocab_size": 99,
"word_embed_proj_dim": 16
}

inputs_dict = {'attention_mask': tensor([[True, True, True, True, True, True, True],
[True, True, True, True, True, True, Tr...91, 54, 98, 57, 55, 2],
[74, 79, 56, 51, 93, 26, 2],
[62, 18, 55, 3, 73, 74, 2]], device='cuda:0')}
output_loss = False

def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
    if not is_torch_fx_available() or not self.fx_compatible:
        return

    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
    configs_no_init.return_dict = False

    for model_class in self.all_model_classes:
        model = model_class(config=configs_no_init)
        model.to(torch_device)
        model.eval()
        inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

        try:
            if model.config.is_encoder_decoder:
                model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
                labels = inputs.get("labels", None)
                input_names = [
                    "attention_mask",
                    "decoder_attention_mask",
                    "decoder_input_ids",
                    "input_features",
                    "input_ids",
                    "input_values",
                ]
                if labels is not None:
                    input_names.append("labels")

                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                input_names = list(filtered_inputs.keys())

                model_output = model(**filtered_inputs)

                traced_model = symbolic_trace(model, input_names)
                traced_output = traced_model(**filtered_inputs)
            else:
                input_names = [
                    "attention_mask",
                    "bbox",
                    "input_features",
                    "input_ids",
                    "input_values",
                    "pixel_values",
                    "token_type_ids",
                    "visual_feats",
                    "visual_pos",
                ]

                labels = inputs.get("labels", None)
                start_positions = inputs.get("start_positions", None)
                end_positions = inputs.get("end_positions", None)
                if labels is not None:
                    input_names.append("labels")
                if start_positions is not None:
                    input_names.append("start_positions")
                if end_positions is not None:
                    input_names.append("end_positions")

                filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                input_names = list(filtered_inputs.keys())

                if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
                    not hasattr(model.config, "problem_type") or model.config.problem_type is None
                ):
                    model.config.problem_type = "single_label_classification"
              traced_model = symbolic_trace(model, input_names)

tests/test_modeling_common.py:877:


model = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_names = ['input_ids', 'attention_mask'], disable_check = False, tracer_cls = <class 'transformers.utils.fx.HFTracer'>

def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
    tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly usesul for debugging purposes.
        tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
            The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    input_names = list(input_names)
    concrete_args = get_concrete_args(model, input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    # Tracing.
    tracer = tracer_cls()
  traced_graph = tracer.trace(model, concrete_args=concrete_args)

src/transformers/utils/fx.py:1250:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}
dummy_inputs = None, complete_concrete_args_with_inputs_not_in_dummy_inputs = True

def trace(
    self,
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    dummy_inputs: Optional[Dict[str, Any]] = None,
    complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
    """
    Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
    `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
    the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
    `torch.nn.Module` instance to use as the root and add embedded constants to.

    Args:
        root (`torch.nn.Module` or  `Callable`):
            Either a `torch.nn.Module`` or a function to be traced through. If root is not a
            [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
        concrete_args (`Dict[str, Any], *optional*):
            Concrete arguments that should not be treated as Proxies
        dummy_inputs (`Dict[str, Any]`, *optional*):
            The dummy inputs needed to handle data-dependent control-flow if `root` is not a
            [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
            [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
        complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
            If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
            `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

    Returns:
        `torch.fx.Graph`:
            A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

    """
    sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

    if concrete_args is None:
        concrete_args = {}

    if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
        for param in sig.parameters.values():
            if param.name in dummy_inputs:
                continue
            if param.default is inspect.Parameter.empty:
                raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
        concrete_args.update(
            {
                p.name: p.default
                for p in sig.parameters.values()
                if (p.name not in dummy_inputs and p.name not in concrete_args)
            }
        )

    input_names = sig.parameters.keys() - concrete_args.keys()

    # Creating a random input shape to generate dummy inputs.
    batch_size = _generate_random_int()
    sequence_length = _generate_random_int()
    shape = [batch_size, sequence_length]

    if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
        num_choices = _generate_random_int(low=2, high=5)
        shape.insert(1, num_choices)

    inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
    for input_name in input_names:
        if input_name in inputs:
            continue
        # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
        # be able to use HFTracer._generate_dummy_input.
        if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
            ("_deserialize_graph_module", "_CodeOnlyModule")
        ):
            inputs.update(self._generate_dummy_input(root, input_name, shape))
        else:
            raise RuntimeError(
                f"Could not generate input named {input_name} for because root is not a"
                " transformers.PreTrainedModel."
            )

    concrete_metas = {
        input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
        for input_name, input_ in inputs.items()
    }
    for param in sig.parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
            concrete_metas[f"**{param.name}"] = {}
    self.meta_args = concrete_metas
    self.patched_torch_methods = {
        target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
    }
    self.orig_fns = set()

    for name, (wrapper, orig) in self.patched_torch_methods.items():
        setattr(torch, name, wrapper)
        self.orig_fns.add(orig)

    try:
      self.graph = super().trace(root, concrete_args=concrete_args)

src/transformers/utils/fx.py:1088:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
root = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
concrete_args = {'head_mask': None, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, ...}

@compatibility(is_backward_compatible=True)
def trace(
    self,
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> Graph:
    """
    Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
    can either be an ``nn.Module`` instance or a Python callable.

    Note that after this call, ``self.root`` may be different from the ``root`` passed
    in here. For example, when a free function is passed to ``trace()``, we will
    create an ``nn.Module`` instance to use as the root and add embedded constants
    to.


    Args:

        root (Union[Module, Callable]): Either a ``Module`` or a function to be
            traced through. Backwards-compatibility for this parameter is
            guaranteed.
        concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
            not be treated as Proxies. This parameter is experimental and
            its backwards-compatibility is *NOT* guaranteed.

    Returns:

        A ``Graph`` representing the semantics of the passed-in ``root``.
    """
    global _is_fx_tracing_flag
    old_is_fx_tracing_flag = _is_fx_tracing_flag
    _is_fx_tracing_flag = True
    try:
        if isinstance(root, torch.nn.Module):
            self.root = root

            assert hasattr(
                type(root), self.traced_func_name
            ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"

            fn = getattr(type(root), self.traced_func_name)
            self.submodule_paths = {mod: name for name, mod in root.named_modules()}
        else:
            self.root = torch.nn.Module()
            fn = root

        tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)
        self.graph = Graph(tracer_cls=tracer_cls)

        # When we encounter a Tensor value that's not a parameter, we look if it
        # is some other attribute on the model. Construct a dict mapping Tensor
        # values to the qualified name here for efficiency. This is used downstream
        # in create_arg
        self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}

        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
            for k, v in m.__dict__.items():
                if isinstance(v, (torch.Tensor, ScriptObject)):
                    self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
            for k, v in m.named_children():
                collect_tensor_attrs(v, prefix_atoms + [k])

        collect_tensor_attrs(self.root, [])

        assert isinstance(fn, FunctionType)

        fn_globals = fn.__globals__  # run before it gets patched
        fn, args = self.create_args_for_root(
            fn, isinstance(root, torch.nn.Module), concrete_args
        )

        parameter_proxy_cache: Dict[
            str, Proxy
        ] = {}  # Reduce number of get_attr calls

        # Method dispatch on parameters is not recorded unless it's directly used.
        # Thus, we need to insert a proxy when __getattr__ requests a parameter.
        @functools.wraps(_orig_module_getattr)
        def module_getattr_wrapper(mod, attr):
            attr_val = _orig_module_getattr(mod, attr)
            return self.getattr(attr, attr_val, parameter_proxy_cache)

        @functools.wraps(_orig_module_call)
        def module_call_wrapper(mod, *args, **kwargs):
            def forward(*args, **kwargs):
                return _orig_module_call(mod, *args, **kwargs)

            _autowrap_check(
                patcher,
                getattr(getattr(mod, "forward", mod), "__globals__", {}),
                self._autowrap_function_ids,
            )
            return self.call_module(mod, forward, args, kwargs)

        with _Patcher() as patcher:
            # allow duplicate patches to support the case of nested calls
            patcher.patch_method(
                torch.nn.Module,
                "__getattr__",
                module_getattr_wrapper,
                deduplicate=False,
            )
            patcher.patch_method(
                torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
            )
            _patch_wrapped_functions(patcher)
            _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
            for module in self._autowrap_search:
                _autowrap_check(
                    patcher, module.__dict__, self._autowrap_function_ids
                )
            self.create_node(
                "output",
                "output",
              (self.create_arg(fn(*args)),),
                {},
                type_expr=fn.__annotations__.get("return", None),
            )

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:739:


self = OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLear...res=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
)
input_ids = Proxy(input_ids), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = None, use_cache = True, output_attentions = False, output_hidden_states = False, return_dict = False

@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=BaseModelOutputWithPast,
    config_class=_CONFIG_FOR_DOC,
    expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
  decoder_outputs = self.decoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        head_mask=head_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

src/transformers/models/opt/modeling_opt.py:1023:


mod = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>

@functools.wraps(_orig_module_call)
def module_call_wrapper(mod, *args, **kwargs):
    def forward(*args, **kwargs):
        return _orig_module_call(mod, *args, **kwargs)

    _autowrap_check(
        patcher,
        getattr(getattr(mod, "forward", mod), "__globals__", {}),
        self._autowrap_function_ids,
    )
  return self.call_module(mod, forward, args, kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:717:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}

def call_module(self, m, forward, args, kwargs):
    self.orig_forward = forward
  return super().call_module(m, forward, args, kwargs)

src/transformers/utils/fx.py:987:


self = <transformers.utils.fx.HFTracer object at 0x7f49bab9e220>
m = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
forward = <function Tracer.trace..module_call_wrapper..forward at 0x7f49bab3a4c0>, args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}

@compatibility(is_backward_compatible=True)
def call_module(
    self,
    m: torch.nn.Module,
    forward: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Method that specifies the behavior of this ``Tracer`` when it encounters
    a call to an ``nn.Module`` instance.

    By default, the behavior is to check if the called module is a leaf module
    via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
    ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
    the operations in its ``forward`` function.

    This method can be overridden to--for example--create nested traced
    GraphModules, or any other behavior you would want while tracing across
    ``Module`` boundaries.

    Args:

        m (Module): The module for which a call is being emitted
        forward (Callable): The forward() method of the ``Module`` to be invoked
        args (Tuple): args of the module callsite
        kwargs (Dict): kwargs of the module callsite

    Return:

        The return value from the Module call. In the case that a ``call_module``
        node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
        value was returned from the ``Module`` invocation.
    """
    module_qualified_name = self.path_of_module(m)
    if not self.is_leaf_module(m, module_qualified_name):
      return forward(*args, **kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:434:


args = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}

def forward(*args, **kwargs):
  return _orig_module_call(mod, *args, **kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:710:


self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input = ()
kwargs = {'attention_mask': Proxy(attention_mask), 'head_mask': None, 'input_ids': Proxy(input_ids), 'inputs_embeds': None, ...}
forward_call = <bound method OPTDecoder.forward of OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions)...out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)>

def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
      return forward_call(*input, **kwargs)

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194:


self = OPTDecoder(
(embed_tokens): Embedding(99, 16, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(22, ... out_features=16, bias=True)
(final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
)
)
input_ids = Proxy(view), attention_mask = Proxy(attention_mask), head_mask = None, past_key_values = None
inputs_embeds = Proxy(decoder_embed_tokens), use_cache = True, output_attentions = False, output_hidden_states = False
return_dict = False

def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
            all `decoder_input_ids` of shape `(batch_size, sequence_length)`.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    batch_size, seq_length = input_shape
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
    # required mask seq length can be calculated via length of past
    mask_seq_length = past_key_values_length + seq_length

    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        padding_mask = None
    elif attention_mask.shape[1] != mask_seq_length:
        raise ValueError(
            f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
            f"{mask_seq_length} (sum of the lengths of current and past inputs)"
        )
    else:
      if 0 in attention_mask:

src/transformers/models/opt/modeling_opt.py:872:


self = Proxy(attention_mask), key = 0

def __contains__(self, key):
    if hasattr(self, "_metadata") and self._metadata is not None:
      return key in self._metadata

src/transformers/utils/fx.py:646:


self = tensor(..., device='meta', size=(14, 11), dtype=torch.int64), element = 0

def __contains__(self, element):
    r"""Check if `element` is present in tensor

    Args:
        element (Tensor or scalar): element to be checked
            for presence in current tensor"
    """
    if has_torch_function_unary(self):
        return handle_torch_function(Tensor.__contains__, (self,), self, element)
    if isinstance(element, (torch.Tensor, Number)):
        # type hint doesn't understand the __contains__ result array
      return (element == self).any().item()  # type: ignore[union-attr]

E NotImplementedError: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]

../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/_tensor.py:983: NotImplementedError

During handling of the above exception, another exception occurred:

self = <tests.models.opt.test_modeling_opt.OPTModelTest testMethod=test_torch_fx>

def test_torch_fx(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
  self._create_and_check_torch_fx_tracing(config, inputs_dict)

tests/test_modeling_common.py:805:


tests/test_modeling_common.py:882: in _create_and_check_torch_fx_tracing
self.fail(f"Couldn't trace module: {e}")
E AssertionError: Couldn't trace module: Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_local_scalar_dense' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
E
E CPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCPU.cpp:30798 [kernel]
E CUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/build/aten/src/ATen/RegisterCUDA.cpp:43635 [kernel]
E BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E Python: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:140 [backend fallback]
E FuncTorchDynamicLayerBackMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:488 [backend fallback]
E Functionalize: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:291 [backend fallback]
E Named: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
E Conjugate: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
E Negative: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
E ZeroTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
E ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
E AutogradOther: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHIP: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMPS: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradIPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradXPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradVE: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradLazy: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradMeta: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/VariableType_2.cpp:16903 [autograd kernel]
E Tracer: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/autograd/generated/TraceType_2.cpp:16890 [kernel]
E AutocastCPU: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:482 [backend fallback]
E AutocastCUDA: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
E FuncTorchBatched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/BatchRulesDynamic.cpp:64 [kernel]
E FuncTorchVmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
E Batched: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
E VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
E FuncTorchGradWrapper: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/TensorWrapper.cpp:189 [backend fallback]
E PythonTLSSnapshot: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]
E FuncTorchDynamicLayerFrontMode: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/functorch/DynamicLayer.cpp:484 [backend fallback]
E PythonDispatcher: registered at /opt/conda/conda-bld/pytorch_1670525539683/work/aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]

EDIT: The errors are fixed now.

Contributor:

Thanks, I see. In that line, can you use torch.isin instead? https://pytorch.org/docs/stable/generated/torch.isin.html - suggestion from @michaelbenayoun

elif attention_mask.shape[1] != mask_seq_length:
    raise ValueError(
        f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
        f"{mask_seq_length} (sum of the lengths of current and past inputs)"
    )
else:
    padding_mask = attention_mask
Contributor:

Can you try to add a conditional check for whether there is any 0 in the attention mask? Sometimes users pass an attention mask of all ones, in which case we need to set padding_mask to None to avoid entering the padding code path in the OPTFlashAttention2 module.

If the test fails with torch.fx, can you try torch.isin? Otherwise you can also do a check based on the sum of the attention mask.

Contributor Author:

Ok, sorry for removing it directly without discussing.

Contributor:

No problem at all @susnato! 🙏

@susnato (Contributor, Author) commented Sep 26, 2023

It fails for:

  • if torch.isin(attention_mask.int(), 0).sum().item() != 0 (AssertionError: Couldn't trace module: symbolically traced variables cannot be used as inputs to control flow)
  • if torch.isin(attention_mask.int(), 0).sum() != 0 (AssertionError: Couldn't trace module: bool should return bool, returned Tensor)
  • if attention_mask.sum().item() != attention_mask.numel() (AssertionError: Couldn't trace module: symbolically traced variables cannot be used as inputs to control flow)
  • if not torch.equal(attention_mask, torch.ones_like(attention_mask)) (AssertionError: Couldn't trace module: symbolically traced variables cannot be used as inputs to control flow)
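
(Editorial note, not part of the PR: a minimal sketch of the torch.fx limitation all of these checks hit. During symbolic tracing the attention mask is a Proxy, and any data-dependent Python control flow forces the Proxy to be evaluated as a bool, which tracing cannot do.)

import torch
from torch import fx

def has_padding(attention_mask):
    # Data-dependent control flow: during symbolic tracing `attention_mask`
    # is a Proxy, and using a Proxy in an `if` raises a TraceError.
    if (attention_mask == 0).any():
        return attention_mask
    return None

try:
    fx.symbolic_trace(has_padding)
except Exception as e:
    # e.g. "symbolically traced variables cannot be used as inputs to control flow"
    print(type(e).__name__, e)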

@susnato (Contributor, Author) left a comment

Hi @younesbelkada, I have solved the error, but it requires changing the _flash_attention_forward function a little; otherwise, if we add the if statement in OPTDecoder.forward, we get the error.

Comment on lines 413 to 418
"""
# we check if padding_mask contains all ones, in that case we don't use it. This is to make sure the torch.fx
# tests pass for the relevant models
if padding_mask.sum().item() != padding_mask.numel():
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
@susnato (Contributor, Author) commented Sep 26, 2023

It seems that if we check whether attention_mask is all ones here, the torch.fx tests pass.

Comment on lines 870 to 873

padding_mask = attention_mask

causal_attention_mask = self._prepare_decoder_attention_mask(
Contributor Author:

In that case we need to pass padding_mask = attention_mask here and check for the all-ones mask in _flash_attention_forward. If we use any if-else condition here, it raises an error in the torch.fx tests.

@younesbelkada (Contributor) left a comment

Thanks very much for iterating @susnato. I have also tried out different approaches on my end, for example:

            # Avoids torch FX issues
            if torch.any(~attention_mask.bool()):
                padding_mask = attention_mask
            else:
                padding_mask = None

and the FX tests still fail.

The issue with the current approach is that the padding_mask.sum().item() != padding_mask.numel() check adds significant overhead, since it is computed in every attention layer; especially for long sequence lengths, this leads to a slowdown.

Let me dig into the failing FX tests a bit with @michaelbenayoun and get back to you.

@susnato (Contributor, Author) commented Sep 26, 2023

Thanks for the reply @younesbelkada. In the meantime, I would like to work on adding Flash Attention for another model.

@younesbelkada (Contributor):

Thanks @susnato, sure. Let me know on #26350 which model you would take; perhaps you could try your hand at Starcoder if you are interested. Let me know!

@susnato (Contributor, Author) commented Sep 26, 2023

Yes, I am interested in that! @younesbelkada

@younesbelkada (Contributor):

Thanks very much! Looking forward to your PR!

@susnato (Contributor, Author) commented Sep 26, 2023

Also, please assign me on the list :) @younesbelkada

@younesbelkada (Contributor):

Yes, just assigned you on the list!

@dathudeptrai:

@susnato @younesbelkada why not use F.scaled_dot_product_attention instead, since it supports Flash-Attention-2 and can be used with torch.compile? I used to use flash-attention from the official repo, as in this PR, but I have since moved to F.scaled_dot_product_attention.
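
(Editorial note, not part of the PR: a minimal sketch of the torch.nn.functional.scaled_dot_product_attention call mentioned above, available since PyTorch 2.0. Whether it actually dispatches to a flash kernel depends on the hardware, dtypes, and mask; this PR integrates the flash-attn library directly instead.)

import torch
import torch.nn.functional as F

# Shapes: (batch, num_heads, seq_len, head_dim); half precision on a recent GPU
# is needed for the flash backend to be eligible.
q = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)

# PyTorch selects the fastest available backend (flash, memory-efficient, or math).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 12, 128, 64])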

@younesbelkada (Contributor):

Hi @susnato,
Thanks for your patience.
I had a look at the issue with @michaelbenayoun and the solution is the following:

1- Define a method at the top level of the OPT modeling file that checks whether there is any padding token inside the attention mask:

if is_torch_fx_available():
    @torch.fx.wrap
    def check_padding_in_attention_mask(attention_mask):
        if 0 in attention_mask:
            return attention_mask
        return None
else:
    def check_padding_in_attention_mask(attention_mask):
        if 0 in attention_mask:
            return attention_mask
        return None

And add the @torch.fx.wrap decorator to the method to make sure it is compatible with the PT versions we support.
2- Revert to commit 689f599 and replace the simple logic with a call to that method.

That way the FX tests should hopefully pass.

I will later take care of moving check_padding_in_attention_mask into the pytorch_utils file so that models that support FX tracing can use it for future integrations.
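
(Editorial note, not part of the PR: a minimal, self-contained sketch of why the wrapped helper unblocks symbolic tracing. TinyDecoder below is a hypothetical stand-in for OPTDecoder, not the merged implementation.)

import torch
from torch import fx

@torch.fx.wrap
def check_padding_in_attention_mask(attention_mask):
    # Because of @torch.fx.wrap, tracing records a single call_function node
    # instead of tracing through the data-dependent `0 in attention_mask` check.
    if 0 in attention_mask:
        return attention_mask
    return None

class TinyDecoder(torch.nn.Module):
    def forward(self, hidden_states, attention_mask):
        # A real decoder would forward `padding_mask` to its attention layers;
        # here we only show that tracing no longer fails.
        padding_mask = check_padding_in_attention_mask(attention_mask)
        return hidden_states, padding_mask

traced = fx.symbolic_trace(TinyDecoder())
print(traced.graph)  # contains a call to check_padding_in_attention_mask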

@susnato (Contributor, Author) commented Sep 28, 2023

Hi @younesbelkada,

Full traceback for Llama:

self = <tests.models.llama.test_modeling_llama.LlamaModelTest testMethod=test_flash_attn_2_generate_use_cache>

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
    import torch

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_flash_attn_2:
            return

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
            dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)

            model = model_class.from_pretrained(
                tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
            ).to(torch_device)

            # Just test that a large cache works as expected
          _ = model.generate(
                dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
            )

tests/test_modeling_common.py:2936:


../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
return func(*args, **kwargs)
src/transformers/generation/utils.py:1606: in generate
return self.greedy_search(
src/transformers/generation/utils.py:2454: in greedy_search
outputs = self(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/llama/modeling_llama.py:1034: in forward
outputs = self.model(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/llama/modeling_llama.py:921: in forward
layer_outputs = decoder_layer(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/llama/modeling_llama.py:631: in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
../../anaconda3/envs/transformers/lib/python3.9/site-packages/torch/nn/modules/module.py:1194: in _call_impl
return forward_call(*input, **kwargs)
src/transformers/models/llama/modeling_llama.py:489: in forward
attn_output = self._flash_attention_forward(
src/transformers/models/llama/modeling_llama.py:546: in _flash_attention_forward
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
../../anaconda3/envs/transformers/lib/python3.9/site-packages/flash_attn/bert_padding.py:208: in pad_input
output = index_put_first_axis(hidden_states, indices, batch * seqlen)


ctx = <torch.autograd.function.IndexPutFirstAxisBackward object at 0x7f1fb51b9040>
values = tensor([[[-2.3666e-02, 2.1423e-02, -7.4646e-02, 3.6530e-02, -7.8979e-02,
-3.4546e-02, 1.0236e-01, 4.3549...4504e-03, -9.3384e-02,
-4.5532e-02, -5.5847e-02, 4.0253e-02]]], device='cuda:0',
dtype=torch.float16)
indices = tensor([0, 1], device='cuda:0', dtype=torch.int32), first_axis_dim = 2

@staticmethod
def forward(ctx, values, indices, first_axis_dim):
    ctx.save_for_backward(indices)
    assert indices.ndim == 1
    assert values.ndim >= 2
    output = torch.zeros(
        first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
    )
    # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
  output[indices] = values

E IndexError: tensors used as indices must be long, byte or bool tensors

../../anaconda3/envs/transformers/lib/python3.9/site-packages/flash_attn/bert_padding.py:51: IndexError

BTW, I updated the flash-attention library and merged in the latest main, but the error is still there.
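
For what it's worth, the failing operation can be isolated from flash_attn's pad_input: on the PyTorch build shown in the traceback, index assignment with int32 indices is rejected. A minimal repro (whether it raises depends on the PyTorch version):

import torch

values = torch.randn(2, 3)
output = torch.zeros(4, 3)

indices_int32 = torch.tensor([0, 1], dtype=torch.int32)  # the dtype seen in the traceback

try:
    output[indices_int32] = values
except IndexError as e:
    print(e)  # "tensors used as indices must be long, byte or bool tensors"

output[indices_int32.long()] = values  # int64 (long) indices are accepted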

@susnato
Copy link
Contributor Author

susnato commented Sep 28, 2023

I have pushed the docstring change.

If you don't mind @younesbelkada, could you please check out this branch, run the tests on your end, and let me know the results?
If they pass on your machine, then the failure could be due to some unknown error on my end.

@younesbelkada
Copy link
Contributor

Sure @susnato no problem, will run the tests locally tomorrow and let you know how it goes

@susnato
Copy link
Contributor Author

susnato commented Oct 5, 2023

Hi @younesbelkada, did you manage to check if the tests run successfully on your local machine?

@huggingface huggingface deleted a comment from github-actions bot Oct 30, 2023
@younesbelkada
Copy link
Contributor

Yes, the tests seemed to pass when I ran them; however, since #26792 has been merged, you need to remove the support for padding_mask and follow what has been done in that PR. Let me know if you need help with this!

@susnato
Copy link
Contributor Author

susnato commented Oct 30, 2023

After I finish updating gpt_bigcode I will take this up.

@younesbelkada
Copy link
Contributor

OK awesome, thanks @susnato !

@susnato
Copy link
Contributor Author

susnato commented Oct 30, 2023

Hi @younesbelkada, I have pushed the changes.

All the flash attention tests are passing except test_flash_attn_2_generate_use_cache, but I believe that is due to some problem with my local installation, since the model generates outputs just fine when I run it separately.

@susnato
Copy link
Contributor Author

susnato commented Oct 31, 2023

Hello @younesbelkada, are you going to check the speed-up for this model too?

I am asking out of curiosity - are you going to check speed-ups for every model FlashAttention is added to, or is this benchmark only needed for the popular ones?

@susnato susnato changed the title Add flash attention for opt [FA2] Add flash attention for opt Nov 2, 2023
Copy link
Contributor

@younesbelkada younesbelkada left a comment


Nice, thanks a lot for your huge work!
I can confirm interesting speedups for opt-350m and opt-2.7b:

opt-2.7b + large seqlen:

Screenshot 2023-11-07 at 16 55 16

opt-350m + small seqlen:

Screenshot 2023-11-07 at 16 55 38

I just left a single question about changing the attention class's init signature, otherwise LGTM!
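
For anyone curious how such numbers are typically obtained, here is a rough sketch of a latency comparison using the same flag exercised elsewhere in this PR (the model id, prompt and generation lengths are placeholders, not the maintainer's exact benchmark):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def benchmark(use_fa2: bool, model_id: str = "facebook/opt-350m") -> float:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, use_flash_attention_2=use_fa2
    ).to(0)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)

    # Warm up once, then time a fixed-length greedy generation.
    model.generate(**inputs, max_new_tokens=16, do_sample=False)
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("eager:", benchmark(use_fa2=False))
print("FA2  :", benchmark(use_fa2=True))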

dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
config: OPTConfig,
Copy link
Contributor

@younesbelkada younesbelkada Nov 7, 2023


Are we OK with modifying the attention class's signature? We've had issues in the past when changing the signature of such classes (e.g. for Llama when we added padding_mask to the forward signature). cc @amyeroberts

Copy link
Collaborator


It should be OK, as this isn't a class importable from the top layer of transformers or in the documentation. However, it definitely can still cause issues! What I'd suggest is something like this (happy for you to modify as desired):

    def __init__(
        self,
        config,
        is_decoder: bool = False,
        **kwargs,
    ):
        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
            """
            If the deprecated argument `fn_arg_name` is passed, raise a deprecation
            warning and return its value; otherwise take the equivalent config.config_arg_name.
            """
            val = None
            if fn_arg_name in kwargs:
                logging.warning(
                    f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported"
                    " from v4.38. Please set it in the config instead"
                )
                val = kwargs.pop(fn_arg_name)
            else:
                val = getattr(config, config_arg_name)
            return val

        embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
        num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
        dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
        enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)
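
A short usage sketch of what this pattern would give callers, assuming the suggested config-based signature is adopted (the import path and exact deprecation window are assumptions, not part of this PR's merged diff):

from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTAttention

config = OPTConfig(hidden_size=768, num_attention_heads=12, attention_dropout=0.0, enable_bias=True)

# New-style call: everything is read from the config.
attn = OPTAttention(config=config, is_decoder=True)

# Old-style call: still accepted, but would log a deprecation warning under the
# suggested __init__, with the explicit kwarg taking precedence over the config value.
attn_legacy = OPTAttention(config=config, is_decoder=True, dropout=0.1)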

Copy link
Collaborator

@amyeroberts amyeroberts left a comment


Thanks for adding!

Just some comments on handling old arguments + deprecation warnings

Comment on lines 481 to 484
self.self_attn = OPTAttention(
config=config,
is_decoder=True,
)
Copy link
Collaborator


nit - can be put on one line

Suggested change
self.self_attn = OPTAttention(
config=config,
is_decoder=True,
)
self.self_attn = OPTAttention(config=config, is_decoder=True)

Comment on lines 486 to 489
self.self_attn = OptFlashAttention2(
config=config,
is_decoder=True,
)
Copy link
Collaborator


Suggested change
self.self_attn = OptFlashAttention2(
config=config,
is_decoder=True,
)
self.self_attn = OptFlashAttention2(config=config, is_decoder=True)

Comment on lines 290 to 297

if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
Copy link
Collaborator


It doesn't look like padding_mask was a previous argument

Suggested change
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")

Comment on lines 157 to 161
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

Copy link
Collaborator


Suggested change
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

@@ -129,9 +150,15 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs,
Copy link
Collaborator


Suggested change
**kwargs,

is_decoder: bool = False,
bias: bool = True,
config: OPTConfig,
is_decoder: bool,
Copy link
Collaborator


This needs to keep the default value

Suggested change
is_decoder: bool,
is_decoder: bool = False,

@susnato
Copy link
Contributor Author

susnato commented Nov 16, 2023

Hi @amyeroberts, I have pushed the changes you suggested; in addition, I have also updated the model_doc file and added speedup graphs.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment


@susnato Thanks for iterating! All looks good to me.

Let's get a second review from @younesbelkada before merging as he's more familiar with the Flash Attention code

Copy link
Contributor

@younesbelkada younesbelkada left a comment


Thanks for your great work! Can you just confirm that the slow integration tests for OPT pass before merging? 🙏

@susnato
Copy link
Contributor Author

susnato commented Nov 22, 2023

Yes, I just checked and all tests are passing, @younesbelkada!

Flash Attention tests -

Screenshot from 2023-11-23 01-04-43

I found only one integration test [pt] related to OPT -

Screenshot from 2023-11-23 01-09-20

Copy link
Contributor

@younesbelkada younesbelkada left a comment


Thanks! I'll let @amyeroberts merge the PR

@amyeroberts amyeroberts merged commit 3bc50d8 into huggingface:main Nov 23, 2023
19 checks passed