Skip to content

Commit

Permalink
Add softmax -> dropout -> mm <- non pattern pattern / add new names t…
Browse files Browse the repository at this point in the history
…o pattern
  • Loading branch information
daniil-lyakhov committed Sep 29, 2023
1 parent 1920398 commit d411fd8
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions nncf/torch/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@ def _add_softmax_matmul(
pattern.add_edge(matmul_branch_nodes, matmul)


def _add_softmax_dropout_matmul(
    pattern: GraphPattern, matmul_aliases, reshape_squeeze_aliases, gather_aliases, transpose_aliases
) -> None:
    """
    Add a SOFTMAX -> DROPOUT -> MATMUL <- NON_PATTERN subgraph to ``pattern``.

    Layout of the nodes and edges that are added::

        SOFTMAX
           |
        DROPOUT     RESHAPE||TRANSPOSE||GATHER||SQUEEZE
            \\          /
             \\        /
               MATMUL

    :param pattern: Graph pattern that receives the four new nodes and three new edges.
    :param matmul_aliases: Metatype value assigned to the MATMUL node.
    :param reshape_squeeze_aliases: First part of the metatype list of the NON_PATTERN node.
    :param gather_aliases: Second part of the metatype list of the NON_PATTERN node.
    :param transpose_aliases: Third part of the metatype list of the NON_PATTERN node.
    """
    # Metatypes accepted on the second (non-softmax) input of the matmul.
    side_input_metatypes = reshape_squeeze_aliases + gather_aliases + transpose_aliases

    label_key = GraphPattern.LABEL_ATTR
    metatype_key = GraphPattern.METATYPE_ATTR

    # Nodes are created in the same order as before: softmax, dropout, matmul, side input.
    softmax_node = pattern.add_node(**{label_key: "SOFTMAX", metatype_key: "softmax"})
    dropout_node = pattern.add_node(**{label_key: "DROPOUT", metatype_key: "dropout"})
    matmul_node = pattern.add_node(**{label_key: "MATMUL", metatype_key: matmul_aliases})
    side_input_node = pattern.add_node(**{label_key: "NON_PATTERN", metatype_key: side_input_metatypes})

    # Main chain first, then the side branch feeding the matmul.
    for source_node, target_node in (
        (softmax_node, dropout_node),
        (dropout_node, matmul_node),
        (side_input_node, matmul_node),
    ):
        pattern.add_edge(source_node, target_node)


def _add_softmax_reshape_matmul(
pattern: GraphPattern, matmul_aliases, reshape_squeeze_aliases, gather_aliases, transpose_aliases
) -> None:
Expand Down Expand Up @@ -67,8 +94,18 @@ def _add_softmax_reshape_matmul(

@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.MULTIHEAD_ATTENTION_OUTPUT)
def create_multihead_attention_output() -> GraphPattern:
matmul_aliases = ["linear", "addmm", "matmul", "bmm", "mm", "baddbmm"]
reshape_squeeze_aliases = ["reshape", "view", "flatten", "squeeze", "unsqueeze", "squeeze", "flatten", "unsqueeze"]
matmul_aliases = ["linear", "addmm", "matmul", "bmm", "mm", "baddbmm", "__matmul__"]
reshape_squeeze_aliases = [
"reshape",
"view",
"flatten",
"squeeze",
"unsqueeze",
"squeeze",
"flatten",
"unsqueeze",
"unbind",
]
gather_aliases = ["gather", "index_select", "where", "index_select", "__getitem__"]
transpose_aliases = ["transpose", "permute", "transpose_"]

Expand All @@ -80,6 +117,13 @@ def create_multihead_attention_output() -> GraphPattern:
gather_aliases=gather_aliases,
transpose_aliases=transpose_aliases,
)
_add_softmax_dropout_matmul(
pattern,
matmul_aliases=matmul_aliases,
reshape_squeeze_aliases=reshape_squeeze_aliases,
gather_aliases=gather_aliases,
transpose_aliases=transpose_aliases,
)
_add_softmax_reshape_matmul(
pattern,
matmul_aliases=matmul_aliases,
Expand Down

0 comments on commit d411fd8

Please sign in to comment.