Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RandomUniformFusion transformation. #7187

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9cae746
Added RandomUniformFusion transformation.
popovaan Aug 20, 2021
92395c7
Extended transformations for case with Convert, extended transformati…
popovaan Aug 31, 2021
8444101
Set to const unchanged variables.
popovaan Aug 31, 2021
bda512e
Merge remote-tracking branch 'upstream/master' into random_uniform_tr…
popovaan Aug 31, 2021
fc383d3
Apply suggestions from code review
popovaan Sep 9, 2021
ec84aa0
Reformat code, small corrections.
popovaan Sep 9, 2021
da9ce14
Merge remote-tracking branch 'upstream/master' into random_uniform_tr…
popovaan Sep 9, 2021
4cb096c
Added const shape checks.
popovaan Sep 10, 2021
5ab0156
Fixed transformation for case of different const ranks.
popovaan Sep 10, 2021
b805c53
Added type checks.
popovaan Sep 13, 2021
e2e040e
Merge remote-tracking branch 'upstream/master' into random_uniform_tr…
popovaan Sep 13, 2021
b45e495
Merge remote-tracking branch 'upstream/master' into random_uniform_tr…
popovaan Sep 13, 2021
c2cdc26
Merge remote-tracking branch 'upstream/master' into random_uniform_tr…
popovaan Sep 13, 2021
13cba4f
Apply suggestions from code review
popovaan Sep 14, 2021
2f907cc
United RandomUniformMulFusion and RandomUniformAddFusion to single tr…
popovaan Sep 14, 2021
9882a9a
Added negative tests.
popovaan Sep 14, 2021
8f5b0be
Used get_constant_from_source().
popovaan Sep 14, 2021
75fce5e
Moved transformation to common fusions.
popovaan Sep 14, 2021
c5a01a3
Fixed conflicts.
popovaan Sep 14, 2021
f8e5994
Merge branch 'master' into random_uniform_transformation
popovaan Sep 14, 2021
131a6e7
Added const refs.
popovaan Sep 14, 2021
4254444
Merge branch 'random_uniform_transformation' of https://github.com/po…
popovaan Sep 14, 2021
e6d496d
Update inference-engine/src/transformations/src/transformations/commo…
popovaan Sep 14, 2021
1e187a6
Changed to single class.
popovaan Sep 14, 2021
5ca25bf
Merge branch 'random_uniform_transformation' of https://github.com/po…
popovaan Sep 14, 2021
39e86fe
Merge remote-tracking branch 'upstream/master' into random_uniform_tr…
popovaan Sep 15, 2021
316d058
Corrected IRs checks in layer tests.
popovaan Sep 15, 2021
bc55779
Small corrections.
popovaan Sep 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp>
#include <transformations/common_optimizations/leaky_relu_fusion.hpp>
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
#include <transformations/common_optimizations/random_uniform_fusion.hpp>
#include "transformations/common_optimizations/mul_conv_fusion.hpp"

NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
Expand Down Expand Up @@ -92,6 +93,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
common_fusions->add_matcher<ngraph::pass::DilatedConvolutionConverter>();
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
common_fusions->add_matcher<ngraph::pass::RandomUniformFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");

manager.register_pass<ngraph::pass::BinarizeWeights>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API RandomUniformFusion;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief RandomUniformFusion transformation replaces RandomUniform -> Add or
* RandomUniform -> Mul subgraph with a RandomUniform and replaces min and max const
* with corrected values.
*/
class ngraph::pass::RandomUniformFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
RandomUniformFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/relu_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/disable_random_uniform_constant_folding.hpp"
#include "transformations/common_optimizations/random_uniform_fusion.hpp"
#include "transformations/common_optimizations/add_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/mul_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/clamp_fusion.hpp"
Expand Down Expand Up @@ -139,6 +140,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
common_fusions->add_matcher<ngraph::pass::RandomUniformFusion>();
popovaan marked this conversation as resolved.
Show resolved Hide resolved
common_fusions->set_name("ngraph::pass::CommonFusions");

manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/random_uniform_fusion.hpp"

#include <memory>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/ngraph.hpp>

#include "itt.hpp"

NGRAPH_RTTI_DEFINITION(ngraph::pass::RandomUniformFusion, "RandomUniformFusion", 0);

ngraph::pass::RandomUniformFusion::RandomUniformFusion() {
MATCHER_SCOPE(RandomUniformFusion);
const auto data_pattern = ngraph::pattern::any_input();
const auto ru_min_input_pattern = ngraph::pattern::any_input();
const auto ru_max_input_pattern = ngraph::pattern::any_input();
const auto random_uniform_pattern =
ngraph::pattern::wrap_type<opset8::RandomUniform>({data_pattern, ru_min_input_pattern, ru_max_input_pattern},
pattern::consumers_count(1));
const auto const_pattern = ngraph::pattern::wrap_type<opset8::Constant>();

const auto convert_pattern = ngraph::pattern::wrap_type<opset8::Convert>({random_uniform_pattern});
const auto random_uniform_or_convert_pattern =
std::make_shared<pattern::op::Or>(OutputVector{random_uniform_pattern, convert_pattern});

const auto mul_add_pattern =
ngraph::pattern::wrap_type<opset8::Multiply, opset8::Add>({random_uniform_or_convert_pattern, const_pattern});

popovaan marked this conversation as resolved.
Show resolved Hide resolved
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto data = pattern_map.at(data_pattern);
const auto random_uniform = pattern_map.at(random_uniform_pattern);
const auto constant = pattern_map.at(const_pattern);
const auto ru = std::dynamic_pointer_cast<opset8::RandomUniform>(random_uniform.get_node_shared_ptr());
if (!ru)
popovaan marked this conversation as resolved.
Show resolved Hide resolved
return false;
if (!ru->get_out_type().is_real())
return false;

const auto old_const = std::dynamic_pointer_cast<opset8::Constant>(constant.get_node_shared_ptr());
if (!old_const)
return false;
if (!old_const->get_element_type().is_real())
return false;

auto const_shape = old_const->get_shape();
if (shape_size(const_shape) != 1)
return false;

const auto& value = old_const->cast_vector<double>();
auto new_const = op::Constant::create(ru->get_out_type(), Shape{}, value);

const auto& mul_add = pattern_map.at(mul_add_pattern);
const auto mul_add_ptr = std::dynamic_pointer_cast<ngraph::Node>(mul_add.get_node_shared_ptr());
const auto new_mul_add1 = mul_add_ptr->clone_with_new_inputs({ru->input_value(1), new_const});
const auto new_mul_add2 = mul_add_ptr->clone_with_new_inputs({ru->input_value(2), new_const});

const auto& folded_const1 = ngraph::get_constant_from_source(new_mul_add1);
const auto& folded_const2 = ngraph::get_constant_from_source(new_mul_add2);

const auto new_ru = ru->clone_with_new_inputs({data,
folded_const1 ? folded_const1 : new_mul_add1,
folded_const2 ? folded_const2 : new_mul_add2});
new_ru->set_friendly_name(m.get_match_root()->get_friendly_name());

if (pattern_map.count(convert_pattern)) {
const auto& convert = pattern_map.at(convert_pattern);
const auto cvt = std::dynamic_pointer_cast<opset8::Convert>(convert.get_node_shared_ptr());
if (!cvt)
return false;
if (!cvt->get_element_type().is_real())
return false;
const auto new_ru_conv = cvt->clone_with_new_inputs({new_ru});
copy_runtime_info({ru, cvt, mul_add.get_node_shared_ptr()}, {new_mul_add1, new_mul_add2, new_ru, new_ru_conv});
ngraph::replace_node(m.get_match_root(), new_ru_conv);
} else {
copy_runtime_info({ru, mul_add.get_node_shared_ptr()}, {new_mul_add1, new_mul_add2, new_ru});
ngraph::replace_node(m.get_match_root(), new_ru);
}

return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(mul_add_pattern, matcher_name);
this->register_matcher(m, callback);
}
Loading