[PT] handle inner functions (#3064)
### Changes

Add inner-function handling for `relu`, `multi_head_attention_forward`, and
`batch_norm`.
This is analogous to
https://github.com/openvinotoolkit/nncf/blob/5c1a029104bd89ba4e554f8db82af87f0aa6ec35/nncf/torch/dynamic_graph/patch_pytorch.py#L390-L399
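
For context, functions in `torch.nn.functional` dispatch to lower-level ops internally, and those inner calls are not visible to a `__torch_function__`-based tracer while the outer call is being handled (see the pytorch issue referenced in the diff below). A minimal sketch of the effect, using a toy logging mode (illustrative names, not NNCF code):

```python
import torch
from torch.overrides import TorchFunctionMode


class LogCalls(TorchFunctionMode):
    """Toy tracer that logs every call intercepted via __torch_function__."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("intercepted:", func.__name__)
        return func(*args, **kwargs)


bn = torch.nn.BatchNorm1d(1).eval()
x = torch.randn(1, 1, 1)
with LogCalls():
    # Prints only "intercepted: batch_norm" for the outer F.batch_norm call;
    # the inner torch.batch_norm it dispatches to is not re-intercepted,
    # so a tracer would miss it without the reimplementations added here.
    bn(x)
```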


### Related tickets

152996

### Tests

- test_compare_torch_function_with_handle_inner_function - compares the code of each
  reimplemented function with the torch implementation to detect upstream changes in torch
- test_inner_functions - checks the traced graph of a model that uses these functions
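
A rough sketch of the comparison idea behind the first test (illustrative only; the real test lives in the repository and normalizes the sources before comparing):

```python
import inspect

import torch
from nncf.experimental.torch2.function_hook import handle_inner_functions


def test_relu_source_is_tracked() -> None:
    # Hypothetical check: if torch changes its implementation of relu,
    # the copy kept in handle_inner_functions drifts and reviewers are alerted.
    nncf_src = inspect.getsource(handle_inner_functions.relu)
    torch_src = inspect.getsource(torch.nn.functional.relu)
    assert nncf_src and torch_src  # the real test diffs the normalized bodies
```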
AlexanderDokuchaev authored Nov 7, 2024
1 parent 80df08b commit 5893962
Showing 7 changed files with 581 additions and 2 deletions.
@@ -227,8 +227,6 @@ def execute_hooks_for_parameter(self, value: torch.Tensor) -> torch.Tensor:
        :param value: The tensor to which the post-hook will be applied.
        :return: The processed tensor with the applied post-hook, if applicable.
        """
        if not isinstance(value, torch.nn.Parameter):
            return value
        tensor_info = self.tensor_info.get(value)
        if (
            tensor_info is not None
302 changes: 302 additions & 0 deletions nncf/experimental/torch2/function_hook/handle_inner_functions.py
@@ -0,0 +1,302 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module implements selected functions from the `torch` module, excluding the `hand_function` mechanism.
It processes inner functions to handle exception hooks and graph analysis. The implementation is designed
to support custom handling of inner function exceptions for specific functions.
"""

import math
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch.nn.functional import _canonical_mask
from torch.nn.functional import _in_projection # type:ignore[attr-defined]
from torch.nn.functional import _in_projection_packed # type:ignore[attr-defined]
from torch.nn.functional import _mha_shape_check # type:ignore[attr-defined]
from torch.nn.functional import _none_or_dtype
from torch.nn.functional import _verify_batch_size # type:ignore[attr-defined]
from torch.nn.functional import dropout
from torch.nn.functional import linear
from torch.nn.functional import pad
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.functional import softmax

Tensor = torch.Tensor


def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:

    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

    if not is_batched:
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    if not use_separate_proj_weight:
        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    if attn_mask is not None:
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    if not training:
        dropout_p = 0.0

    if need_weights:
        B, Nt, E = q.shape  # noqa: F841
        q_scaled = q * math.sqrt(1.0 / float(E))

        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            attn_output = attn_output.squeeze(1)
        return attn_output, None


def relu(input: Tensor, inplace: bool = False) -> Tensor:
    if inplace:
        result = torch.relu_(input)
    else:
        result = torch.relu(input)
    return result


def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor:
    if training:
        _verify_batch_size(input.size())

    return torch.batch_norm(
        input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
    )


MAP_HANDLER_TO_INNER_FUNCTION: Dict[Callable[..., Any], Callable[..., Any]] = {
    torch.nn.functional.relu: relu,
    torch.nn.functional.multi_head_attention_forward: multi_head_attention_forward,
    torch.nn.functional.batch_norm: batch_norm,
}


def get_handle_inner_function(fn: Callable[..., Any]) -> Union[Callable[..., Any], None]:
    """
    Returns the reimplementation that exposes the inner function calls of the given function.

    :param fn: The function for which a handler is needed.
    :return: The reimplemented function, or None if no handler is registered for `fn`.
    """
    return MAP_HANDLER_TO_INNER_FUNCTION.get(fn)
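
The lookup itself is a plain dictionary access; a short usage sketch of the function added above:

```python
import torch
from nncf.experimental.torch2.function_hook.handle_inner_functions import get_handle_inner_function

fn = get_handle_inner_function(torch.nn.functional.batch_norm)
assert fn is not None and fn.__name__ == "batch_norm"

# Functions without a registered reimplementation return None,
# so the caller falls back to the original torch function.
assert get_handle_inner_function(torch.nn.functional.gelu) is None
```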
8 changes: 8 additions & 0 deletions nncf/experimental/torch2/function_hook/hook_executor_mode.py
@@ -30,6 +30,7 @@
from torch.overrides import TorchFunctionMode

from nncf.common.logging import nncf_logger as logger
from nncf.experimental.torch2.function_hook.handle_inner_functions import get_handle_inner_function
from nncf.experimental.torch2.function_hook.hook_storage import HookStorage
from nncf.experimental.torch2.function_hook.weak_map import WeakUnhashableKeyMap

@@ -217,6 +218,13 @@ def __torch_function__(

        fn_name = func.__name__

        # Workaround to catch nested calls for some functions
        # https://github.com/pytorch/pytorch/issues/55093
        fn_for_nested_call = get_handle_inner_function(func)
        if fn_for_nested_call is not None:
            with self:
                return fn_for_nested_call(*args, **kwargs)

        if not self.enabled or fn_name in IGNORED_FN_NAMES:
            return func(*args, **kwargs)

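The `with self:` re-entry is the key detail: while `__torch_function__` is running, the mode is temporarily disabled, so the reimplemented body has to be executed with the mode pushed back on for its inner calls to be traced. A standalone illustration (hypothetical `Tracer`/`my_relu` names, not the NNCF classes):

```python
import torch
from torch.overrides import TorchFunctionMode


def my_relu(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the reimplementations in handle_inner_functions.py.
    return torch.relu(x)


class Tracer(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("traced:", func.__name__)
        if func is torch.nn.functional.relu:
            # Re-enter the mode so the inner torch.relu call is traced too.
            with self:
                return my_relu(*args, **kwargs)
        return func(*args, **kwargs)


x = torch.randn(2)
with Tracer():
    torch.nn.functional.relu(x)
# prints: traced: relu
#         traced: relu   <- inner torch.relu, visible thanks to `with self:`
```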
@@ -0,0 +1,16 @@
digraph {
rankdir=TB;
0 [fillcolor="#adadad", fontcolor="#000000", label="{type: input|name: x|dtype: torch.float32|shape: (1, 1, 1)}", shape=record, style="filled,rounded"];
1 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: bn.weight|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
2 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: bn.bias|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
3 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: bn.running_mean|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
4 [fillcolor="#ffffff", fontcolor="#000000", label="{type: const|name: bn.running_var|dtype: torch.float32|shape: (1,)}", shape=record, style="filled,rounded"];
5 [fillcolor="#ffadad", fontcolor="#000000", label="{type: function_call|op_name: bn/batch_norm/0|fn_name: batch_norm|args: [\nTensorMeta(dtype=torch.float32, shape=(1, 1, 1), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=True),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False),\nTensorMeta(dtype=torch.float32, shape=(1,), requires_grad=False),\nFalse,\n0.1,\n1e-05,\nTrue,\n]|kwargs: \{\}}", shape=record, style="filled,rounded"];
6 [fillcolor="#adadad", fontcolor="#000000", label="{type: output|name: output|dtype: torch.float32|shape: (1, 1, 1)}", shape=record, style="filled,rounded"];
0 -> 5 [label="(1, 1, 1)\n0 → 0"];
1 -> 5 [label="(1,)\n0 → 1"];
2 -> 5 [label="(1,)\n0 → 2"];
3 -> 5 [label="(1,)\n0 → 3"];
4 -> 5 [label="(1,)\n0 → 4"];
5 -> 6 [label="(1, 1, 1)\n0 → 0"];
}
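
The reference graph above apparently corresponds to a model with a single `BatchNorm1d(1)` module named `bn`, run in eval mode on a `(1, 1, 1)` input; note that the recorded call is the inner `torch.batch_norm` (nine positional arguments ending with `torch.backends.cudnn.enabled`) rather than `torch.nn.functional.batch_norm`. A minimal model of that shape (the graph-building entry point is omitted here):

```python
import torch
from torch import nn


class BNModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.bn = nn.BatchNorm1d(1)  # yields const nodes bn.weight/bias/running_mean/running_var

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(x)  # recorded as op_name "bn/batch_norm/0" in the graph


model = BNModel().eval()  # eval mode matches the `False` (training) argument in the graph
output = model(torch.ones(1, 1, 1))
```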