Skip to content

Commit

Permalink
fix pattern match may stop early
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 committed Aug 22, 2024
1 parent 0d504cc commit c8fdd42
Showing 1 changed file with 22 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,25 @@ RMSFusion::RMSFusion() {
auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});

// compress RMS result
auto convert = wrap_type<ov::op::v0::Convert>({mul2});

auto comp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul2, convert});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto node = m.get_match_root();
if (transformation_callback(node)) {
return false;
}

auto find_opt_convert = [&](const ov::Output<ov::Node>& out) -> std::shared_ptr<ov::Node> {
auto present_to = out.get_target_inputs();
if (present_to.size() == 1) {
auto convert_raw = dynamic_cast<ov::op::v0::Convert*>(present_to.begin()->get_node());
if (convert_raw)
return convert_raw->shared_from_this();
return nullptr;
}
// if multiple children, skip finding convert even there is one.
return nullptr;
};

auto x_output = pattern_map.at(x);

auto const_eps_node =
Expand All @@ -101,16 +108,22 @@ RMSFusion::RMSFusion() {
return false;
}

auto output_type = m.get_match_root()->get_output_element_type(0);
auto root = m.get_match_root();
// compress RMS result
auto convert = find_opt_convert(root);
if (convert)
root = convert;

auto output_type = root->get_output_element_type(0);
auto rms = std::make_shared<ov::op::internal::RMS>(x_output, gamma_node, eps_value, output_type);
rms->set_friendly_name(m.get_match_root()->get_friendly_name());
rms->set_friendly_name(root->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), rms);
ov::replace_node(m.get_match_root(), rms);
ov::replace_node(root, rms);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(comp, "RMSFusion");
auto m = std::make_shared<ov::pass::pattern::Matcher>(mul2, "RMSFusion");
this->register_matcher(m, callback);
}

Expand Down

0 comments on commit c8fdd42

Please sign in to comment.