Add Concat to _add_softmax_matmul ignored pattern (#2244)
### Changes

Added `Concat` to the `MULTIHEAD_ATTENTION_OUTPUT` ignored pattern for the OV, ONNX, and Torch backends.

### Reason for changes

To improve the accuracy of the https://huggingface.co/EleutherAI/gpt-neo-1.3B model.
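
The pattern takes effect during post-training quantization. A rough, non-authoritative sketch of how it would be exercised for this model: the IR path, calibration samples, and transform function below are placeholders, and it is assumed the transformer ignored patterns are picked up via `model_type=TRANSFORMER`, as for the other transformer patterns.

```python
import nncf
from openvino.runtime import Core

# Placeholders -- adapt to the real GPT-Neo IR and calibration data.
model = Core().read_model("gpt-neo-1.3b.xml")  # hypothetical IR path
calibration_items = [...]                      # real tokenized samples go here


def transform_fn(item):
    # Map one calibration sample to the model inputs; placeholder identity.
    return item


calibration_dataset = nncf.Dataset(calibration_items, transform_fn)

# model_type=TRANSFORMER enables the transformer-specific ignored patterns,
# including MULTIHEAD_ATTENTION_OUTPUT, which this change extends with Concat.
quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    model_type=nncf.ModelType.TRANSFORMER,
)
```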

### Related tickets

* 117617
l-bat authored Nov 7, 2023
1 parent 074e749 commit 96de397
Showing 3 changed files with 44 additions and 22 deletions.
18 changes: 10 additions & 8 deletions nncf/onnx/quantization/ignored_patterns.py
@@ -19,27 +19,28 @@


 def _add_softmax_matmul(pattern: GraphPattern) -> None:
-    # SOFTMAX  RESHAPE||TRANSPOSE||GATHER||SQUEEZE
+    # SOFTMAX  RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT
     #       \         /
     #        \       /
     #         \     /
     #          \   /
     #           \ /
     #          MATMUL
-    reshape_transpose_gather_squeeze = [
+    branch_matmul_nodes = [
         om.ONNXReshapeMetatype,
         om.ONNXTransposeMetatype,
         om.ONNXGatherMetatype,
         om.ONNXSqueezeMetatype,
+        om.ONNXConcatMetatype,
     ]
     softmax = pattern.add_node(
         **{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.ONNXSoftmaxMetatype}
     )
     matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: MATMUL_METATYPES})
     matmul_branch_nodes = pattern.add_node(
         **{
-            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE",
-            GraphPattern.METATYPE_ATTR: reshape_transpose_gather_squeeze,
+            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT",
+            GraphPattern.METATYPE_ATTR: branch_matmul_nodes,
         }
     )
     pattern.add_edge(softmax, matmul)
@@ -51,19 +52,20 @@ def _add_softmax_reshape_matmul(pattern: GraphPattern) -> None:
     #            \
     #             \
     #              \
-    #  RESHAPE  RESHAPE||TRANSPOSE||GATHER||SQUEEZE
+    #  RESHAPE  RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT
     #        \           /
     #         \         /
     #          \       /
     #           \     /
     #            \   /
     #             \ /
     #           MATMUL
-    reshape_transpose_gather_squeeze = [
+    branch_matmul_nodes = [
         om.ONNXReshapeMetatype,
         om.ONNXTransposeMetatype,
         om.ONNXGatherMetatype,
         om.ONNXSqueezeMetatype,
+        om.ONNXConcatMetatype,
     ]
     softmax = pattern.add_node(
         **{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.ONNXSoftmaxMetatype}
@@ -74,8 +76,8 @@ def _add_softmax_reshape_matmul(pattern: GraphPattern) -> None:
     matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: MATMUL_METATYPES})
     matmul_branch_nodes = pattern.add_node(
         **{
-            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE",
-            GraphPattern.METATYPE_ATTR: reshape_transpose_gather_squeeze,
+            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT",
+            GraphPattern.METATYPE_ATTR: branch_matmul_nodes,
         }
     )
     pattern.add_edge(softmax, reshape)
23 changes: 15 additions & 8 deletions nncf/openvino/quantization/ignored_patterns.py
@@ -18,25 +18,26 @@


 def _add_softmax_matmul(pattern: GraphPattern) -> None:
-    # SOFTMAX  RESHAPE||TRANSPOSE||GATHER||SQUEEZE
+    # SOFTMAX  RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT
     #       \         /
     #        \       /
     #         \     /
     #          \   /
     #           \ /
     #          MATMUL
-    reshape_transpose_gather_squeeze = [
+    branch_matmul_nodes = [
         om.OVReshapeMetatype,
         om.OVTransposeMetatype,
         om.OVGatherMetatype,
         om.OVSqueezeMetatype,
+        om.OVConcatMetatype,
     ]
     softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.OVSoftmaxMetatype})
     matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.OVMatMulMetatype})
     matmul_branch_nodes = pattern.add_node(
         **{
-            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE",
-            GraphPattern.METATYPE_ATTR: reshape_transpose_gather_squeeze,
+            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT",
+            GraphPattern.METATYPE_ATTR: branch_matmul_nodes,
         }
     )
     pattern.add_edge(softmax, matmul)
@@ -48,22 +49,28 @@ def _add_softmax_reshape_matmul(pattern: GraphPattern) -> None:
     #            \
     #             \
     #              \
-    #  RESHAPE  RESHAPE||TRANSPOSE||GATHER||SQUEEZE
+    #  RESHAPE  RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT
     #        \           /
     #         \         /
     #          \       /
     #           \     /
     #            \   /
     #             \ /
     #           MATMUL
-    reshape_transpose_gather = [om.OVReshapeMetatype, om.OVTransposeMetatype, om.OVGatherMetatype, om.OVSqueezeMetatype]
+    branch_matmul_nodes = [
+        om.OVReshapeMetatype,
+        om.OVTransposeMetatype,
+        om.OVGatherMetatype,
+        om.OVSqueezeMetatype,
+        om.OVConcatMetatype,
+    ]
     softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: om.OVSoftmaxMetatype})
     reshape = pattern.add_node(**{GraphPattern.LABEL_ATTR: "RESHAPE", GraphPattern.METATYPE_ATTR: om.OVReshapeMetatype})
     matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.OVMatMulMetatype})
     matmul_branch_nodes = pattern.add_node(
         **{
-            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE",
-            GraphPattern.METATYPE_ATTR: reshape_transpose_gather,
+            GraphPattern.LABEL_ATTR: "RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT",
+            GraphPattern.METATYPE_ATTR: branch_matmul_nodes,
         }
     )
     pattern.add_edge(softmax, reshape)
25 changes: 19 additions & 6 deletions nncf/torch/quantization/ignored_patterns.py
@@ -18,16 +18,21 @@


 def _add_softmax_matmul(
-    pattern: GraphPattern, matmul_aliases, reshape_squeeze_aliases, gather_aliases, transpose_aliases
+    pattern: GraphPattern,
+    matmul_aliases,
+    reshape_squeeze_aliases,
+    gather_aliases,
+    transpose_aliases,
+    concat_aliases,
 ) -> None:
-    # SOFTMAX  RESHAPE||TRANSPOSE||GATHER||SQUEEZE
+    # SOFTMAX  RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT
     #       \         /
     #        \       /
     #         \     /
     #          \   /
     #           \ /
     #          MATMUL
-    branch_matmul_nodes = reshape_squeeze_aliases + gather_aliases + transpose_aliases
+    branch_matmul_nodes = reshape_squeeze_aliases + gather_aliases + transpose_aliases + concat_aliases
     softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: "softmax"})
     matmul = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: matmul_aliases})
     matmul_branch_nodes = pattern.add_node(
@@ -38,21 +43,26 @@ def _add_softmax_matmul(


 def _add_softmax_reshape_matmul(
-    pattern: GraphPattern, matmul_aliases, reshape_squeeze_aliases, gather_aliases, transpose_aliases
+    pattern: GraphPattern,
+    matmul_aliases,
+    reshape_squeeze_aliases,
+    gather_aliases,
+    transpose_aliases,
+    concat_aliases,
 ) -> None:
     # SOFTMAX
     #            \
     #             \
     #              \
-    #  RESHAPE  RESHAPE||TRANSPOSE||GATHER||SQUEEZE
+    #  RESHAPE  RESHAPE||TRANSPOSE||GATHER||SQUEEZE||CONCAT
     #        \           /
     #         \         /
     #          \       /
     #           \     /
     #            \   /
     #             \ /
     #           MATMUL
-    branch_matmul_nodes = reshape_squeeze_aliases + gather_aliases + transpose_aliases
+    branch_matmul_nodes = reshape_squeeze_aliases + gather_aliases + transpose_aliases + concat_aliases
     softmax = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SOFTMAX", GraphPattern.METATYPE_ATTR: "softmax"})
     reshape = pattern.add_node(
         **{GraphPattern.LABEL_ATTR: "RESHAPE", GraphPattern.METATYPE_ATTR: reshape_squeeze_aliases}
@@ -80,6 +90,7 @@ def create_multihead_attention_output() -> GraphPattern:
     ]
     gather_aliases = ["gather", "index_select", "where", "index_select", "__getitem__"]
     transpose_aliases = ["transpose", "permute", "transpose_"]
+    concat_aliases = ["cat", "stack"]

     pattern = GraphPattern()
     _add_softmax_matmul(
@@ -88,13 +99,15 @@ def create_multihead_attention_output() -> GraphPattern:
         reshape_squeeze_aliases=reshape_squeeze_aliases,
         gather_aliases=gather_aliases,
         transpose_aliases=transpose_aliases,
+        concat_aliases=concat_aliases,
     )
     _add_softmax_reshape_matmul(
         pattern,
         matmul_aliases=matmul_aliases,
         reshape_squeeze_aliases=reshape_squeeze_aliases,
         gather_aliases=gather_aliases,
         transpose_aliases=transpose_aliases,
+        concat_aliases=concat_aliases,
     )
     return pattern

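For context, a hypothetical PyTorch attention-output fragment (not taken from the GPT-Neo sources) of the kind the new `cat`/`stack` aliases are meant to match: the second matmul input passes through a concatenation, so the whole softmax-to-matmul subgraph now falls inside the ignored pattern.

```python
import torch


def attention_output(scores: torch.Tensor, value_chunks: list[torch.Tensor]) -> torch.Tensor:
    # SOFTMAX branch: attention probabilities.
    probs = torch.softmax(scores, dim=-1)
    # CONCAT branch: values assembled from several chunks before the matmul.
    value = torch.cat(value_chunks, dim=-2)
    # MATMUL fed by both branches -- the node the ignored pattern anchors on.
    return torch.matmul(probs, value)
```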
