Skip to content

Commit

Permalink
fixed test and apply the last comment
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron committed Nov 25, 2024
1 parent b9df131 commit b5f3487
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con

return false;
}

inline bool canBeMatMulExecutedInInt8(const ov::element::Type& firstType, const ov::element::Type& secondType) {
return one_of(firstType, ov::element::i8, ov::element::u8) && secondType == ov::element::i8;
}
} // namespace

bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
Expand All @@ -344,9 +348,10 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
SetNodeFusingType(node, NodeFusingType::FusedWithMisc);
} else if (isSuitableMatMulParent(node)) {
const bool is_fc = isFullyConnected(node);
const bool is_i8 = canBeMatMulExecutedInInt8(node->get_input_element_type(0), node->get_input_element_type(1));
const auto out_rank = node->get_output_partial_shape(0).rank();
if (is_fc) {
SetNodeFusingType(node, NodeFusingType::FusedWithFC);
SetNodeFusingType(node, is_i8 ? NodeFusingType::FusedWithFCI8 : NodeFusingType::FusedWithFC);
channelAxis = out_rank.is_static() ? (out_rank.get_length() == 3 ? 2 : 1) : DEFAULT_AXIS;
} else {
SetNodeFusingType(node, NodeFusingType::FusedWithMatMul);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ std::vector<InputShape> static_shapes = {
InputShape{{}, {{1, 32, 16, 16}}},
};

#if defined(OPENVINO_ARCH_ARM)
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
const ExpectedResult successfull_fuse_result{1, 1, 3};
#else
const ExpectedResult successfull_fuse_result{1, 1, 2};
Expand Down

0 comments on commit b5f3487

Please sign in to comment.