Skip to content

Commit

Permalink
fix pattern match may stop early
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 committed Aug 22, 2024
1 parent 0d504cc commit 08bbaac
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,25 @@ RMSFusion::RMSFusion() {
auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});

// compress RMS result
auto convert = wrap_type<ov::op::v0::Convert>({mul2});

auto comp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul2, convert});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto node = m.get_match_root();
if (transformation_callback(node)) {
return false;
}

auto find_opt_convert = [&](const ov::Output<ov::Node>& out) -> std::shared_ptr<ov::Node> {
auto present_to = out.get_target_inputs();
if (present_to.size() == 1) {
auto convert_raw = dynamic_cast<ov::op::v0::Convert*>(present_to.begin()->get_node());
if (convert_raw)
return convert_raw->shared_from_this();
return nullptr;
}
// if multiple children, skip finding convert even there is one.
return nullptr;
};

auto x_output = pattern_map.at(x);

auto const_eps_node =
Expand All @@ -101,16 +108,22 @@ RMSFusion::RMSFusion() {
return false;
}

auto output_type = m.get_match_root()->get_output_element_type(0);
auto root = m.get_match_root();
// compress RMS result
auto convert = find_opt_convert(root);
if (convert)
root = convert;

auto output_type = root->get_output_element_type(0);
auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type);
rms->set_friendly_name(m.get_match_root()->get_friendly_name());
rms->set_friendly_name(root->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), rms);
ov::replace_node(m.get_match_root(), rms);
ov::replace_node(root, rms);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(comp, "RMSFusion");
auto m = std::make_shared<ov::pass::pattern::Matcher>(mul2, "RMSFusion");
this->register_matcher(m, callback);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,36 @@ TEST_F(TransformationTestsF, RMSNormFusionTest5) {
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rms}, ov::ParameterVector{input});
}
}

// no convert at the end of the subgraph
TEST_F(TransformationTestsF, RMSNormFusionTest6) {
{
auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
auto power_const = ov::opset10::Constant::create(ov::element::f32, {}, {2.f});
auto power = std::make_shared<ov::opset10::Power>(input, power_const);
auto mean_axes = ov::opset10::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
auto mean = std::make_shared<ov::opset10::ReduceMean>(power, mean_axes, true);
auto eps = ov::opset10::Constant::create(ov::element::f32, {}, {1e-5f});
auto add_eps = std::make_shared<ov::opset10::Add>(mean, eps);
auto sqrt = std::make_shared<ov::opset10::Sqrt>(add_eps);
auto div_const = ov::opset10::Constant::create(ov::element::f32, {}, {-1});
auto div = std::make_shared<ov::opset10::Power>(sqrt, div_const);
auto mul1 = std::make_shared<ov::opset10::Multiply>(input, div);
auto gamma = ov::opset10::Constant::create(ov::element::f32,
ov::Shape{6},
{0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f});
auto mul2 = std::make_shared<ov::opset10::Multiply>(gamma, mul1);

model = std::make_shared<ov::Model>(ov::NodeVector{mul2}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>();
}
{
auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
auto rms_const = ov::opset10::Constant::create(ov::element::f32,
ov::Shape{6},
{0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f});
auto rms = std::make_shared<ov::op::internal::RMS>(input, rms_const, 1e-5f);

model_ref = std::make_shared<ov::Model>(ov::NodeVector{rms}, ov::ParameterVector{input});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ OPENVINO_SUPPRESS_DEPRECATED_START
#include "openvino/opsets/opset15_tbl.hpp"

#include "ov_ops/opset_private_tbl.hpp"

#undef _OPENVINO_OP_REG
};
OPENVINO_SUPPRESS_DEPRECATED_END
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,6 @@ InputsMap getInputMap() {
#include "openvino/opsets/opset15_tbl.hpp"

#include "ov_ops/opset_private_tbl.hpp"

#undef _OPENVINO_OP_REG
};
return inputsMap;
Expand Down

0 comments on commit 08bbaac

Please sign in to comment.