Skip to content

Commit

Permalink
SE block HW fusing pattern is presented
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 5, 2023
1 parent eaddf18 commit 2b2a288
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 1 deletion.
109 changes: 109 additions & 0 deletions nncf/torch/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,112 @@ def create_h_sigmoid_act() -> GraphPattern:
main_pattern.add_pattern_alternative(pattern)

return main_pattern


@PT_HW_FUSED_PATTERNS.register(HWFusedPatternNames.SE_BLOCK)
def create_se_block() -> GraphPattern:
    """
    Build the hardware-fused pattern registered under ``HWFusedPatternNames.SE_BLOCK``.

    Every alternative is a chain that starts at an arbitrary (non-pattern) node,
    goes through a mean/average-pool reduction, two linear operations (the first
    followed by an atomic activation, the second by a sigmoid) and ends in a
    ``__mul__`` node. The starting node is additionally connected directly to
    the ``__mul__`` node, so the multiplication combines the original input
    with the output of the chain.

    Four alternatives are registered, in this order:

    1. plain chain;
    2. chain with an explicit bias ``__add__``/``__sub__`` after each linear op;
    3. chain with a reshape-like op after the reduction and after the sigmoid;
    4. chain with both the bias nodes and the reshape-like nodes.

    :return: A ``GraphPattern`` holding all four alternatives.
    """
    MEAN_OPERATIONS = {
        GraphPattern.LABEL_ATTR: "REDUCE_MEAN",
        GraphPattern.METATYPE_ATTR: ["avg_pool2d", "adaptive_avg_pool2d", "avg_pool3d", "adaptive_avg_pool3d", "mean"],
    }
    RESHAPE_NODES = {
        GraphPattern.LABEL_ATTR: "RESHAPE",
        GraphPattern.METATYPE_ATTR: ["reshape", "view", "flatten", "unsqueeze"],
    }

    def _bias_attrs() -> dict:
        # A fresh dict/list per node, so no two nodes share a mutable type list.
        return {"label": "ADD_BIAS", "type": ["__add__", "__sub__"]}

    def _build_alternative(with_reshape: bool, with_bias: bool) -> GraphPattern:
        """
        Build a single alternative of the pattern.

        :param with_reshape: Insert a reshape-like node after the reduction and
            another one after the sigmoid.
        :param with_bias: Insert a bias add/sub node after each linear operation.
        :return: The constructed alternative.
        """
        # Attributes of the chain nodes, listed in the order they are added.
        chain_attrs = [
            {"label": "ANY", "type": GraphPattern.NON_PATTERN_NODE_TYPE},
            MEAN_OPERATIONS,
        ]
        if with_reshape:
            chain_attrs.append(RESHAPE_NODES)
        chain_attrs.append(LINEAR_OPERATIONS)
        if with_bias:
            chain_attrs.append(_bias_attrs())
        chain_attrs.append(ATOMIC_ACTIVATIONS_OPERATIONS)
        chain_attrs.append(LINEAR_OPERATIONS)
        if with_bias:
            chain_attrs.append(_bias_attrs())
        chain_attrs.append({"label": "SIGMOID", "type": "sigmoid"})
        if with_reshape:
            chain_attrs.append(RESHAPE_NODES)
        chain_attrs.append({"label": "MUL", "type": "__mul__"})

        pattern = GraphPattern()
        chain = [pattern.add_node(**attrs) for attrs in chain_attrs]

        # Connect consecutive nodes of the chain.
        for src, dst in zip(chain, chain[1:]):
            pattern.add_edge(src, dst)
        # Skip connection: the input node also feeds the final multiplication.
        pattern.add_edge(chain[0], chain[-1])
        return pattern

    main_pattern = GraphPattern()
    # The order of alternatives is kept: plain, bias, reshape, reshape + bias.
    for with_reshape, with_bias in ((False, False), (False, True), (True, False), (True, True)):
        main_pattern.add_pattern_alternative(_build_alternative(with_reshape, with_bias))

    return main_pattern
1 change: 0 additions & 1 deletion tests/torch/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
HWFusedPatternNames.MVN_SCALE_SHIFT: "Not relevant for Torch.",
HWFusedPatternNames.NORMALIZE_L2_MULTIPLY: "Not relevant for Torch.",
HWFusedPatternNames.SCALE_SHIFT: "Not relevant for Torch.",
HWFusedPatternNames.SE_BLOCK: "Not relevant for Torch.",
HWFusedPatternNames.SOFTMAX_DIV: "Not relevant for Torch.",
HWFusedPatternNames.HSWISH_ACTIVATION: "Not relevant for Torch.",
HWFusedPatternNames.HSWISH_ACTIVATION_V2: "Not relevant for Torch.",
Expand Down

0 comments on commit 2b2a288

Please sign in to comment.