forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adedd GroupConv transformations (openvinotoolkit#122)
- Loading branch information
1 parent
16a55e9
commit 176c7e3
Showing
5 changed files
with
190 additions
and
0 deletions.
There are no files selected for viewing
65 changes: 65 additions & 0 deletions
65
src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Copyright (C) 2020-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
#include "convert_group_conv.hpp" | ||
|
||
#include <numeric> | ||
|
||
#include <openvino/opsets/opset1.hpp> | ||
#include <openvino/opsets/opset8.hpp> | ||
#include <ngraph/rt_info.hpp> | ||
|
||
ov::intel_cpu::ConvertGroupConvolution::ConvertGroupConvolution() { | ||
auto gconv = ngraph::pattern::wrap_type<opset8::GroupConvolution>(); | ||
|
||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { | ||
enum Inputs {Data, Weights}; | ||
auto gconv = std::dynamic_pointer_cast<opset8::GroupConvolution>(m.get_match_root()); | ||
if (!gconv) { | ||
return false; | ||
} | ||
|
||
auto data_shape = gconv->get_input_shape(Inputs::Data); | ||
// Weights layout GOIYX | ||
size_t groups = gconv->get_input_shape(Inputs::Weights)[0]; | ||
if (groups == data_shape.at(1) && groups == gconv->get_output_shape(0)[1]) { // depthwise case | ||
return false; | ||
} | ||
|
||
ngraph::NodeVector replace_nodes; | ||
auto split_weights = std::make_shared<ov::opset1::Split>(gconv->input_value(Inputs::Weights), | ||
ov::opset8::Constant::create<int64_t>(ngraph::element::i64, ngraph::Shape{}, {0}), | ||
groups); | ||
replace_nodes.push_back(split_weights); | ||
|
||
auto axis = ov::opset8::Constant::create<int64_t>(ngraph::element::i64, ngraph::Shape{}, {1}); | ||
auto split = std::make_shared<ov::opset1::Split>(gconv->input_value(Inputs::Data), axis, groups); | ||
replace_nodes.push_back(split); | ||
|
||
ngraph::NodeVector concat_inputs; | ||
for (size_t g = 0; g < groups; g++) { | ||
auto out = split->output(g); | ||
auto filter = std::make_shared<ov::opset1::Squeeze>(split_weights->output(g), | ||
ov::opset8::Constant::create<int64_t>(ngraph::element::i64, ngraph::Shape{}, {0})); | ||
auto conv = std::make_shared<ov::opset8::Convolution>(out, | ||
filter, | ||
gconv->get_strides(), | ||
gconv->get_pads_begin(), | ||
gconv->get_pads_end(), | ||
gconv->get_dilations(), | ||
gconv->get_auto_pad()); | ||
concat_inputs.push_back(conv); | ||
replace_nodes.push_back(conv); | ||
} | ||
auto concat = std::make_shared<ov::opset8::Concat>(concat_inputs, 1); | ||
replace_nodes.push_back(concat); | ||
|
||
concat->set_friendly_name(gconv->get_friendly_name()); | ||
ngraph::copy_runtime_info(gconv, replace_nodes); | ||
ngraph::replace_node(gconv, concat); | ||
return true; | ||
}; | ||
auto m = std::make_shared<ngraph::pattern::Matcher>(gconv, "ConvertGroupConvolution"); | ||
register_matcher(m, callback); | ||
} |
18 changes: 18 additions & 0 deletions
18
src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// Copyright (C) 2020-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
#include <ngraph/pass/graph_rewrite.hpp> | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
|
||
class ConvertGroupConvolution: public ngraph::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConvertGroupConvolution", "0"); | ||
ConvertGroupConvolution(); | ||
}; | ||
} // namespace intel_cpu | ||
} // namespace ov |
73 changes: 73 additions & 0 deletions
73
src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (C) 2020-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
#include "convert_group_conv1d.hpp" | ||
|
||
#include <numeric> | ||
|
||
#include <openvino/opsets/opset1.hpp> | ||
#include <openvino/opsets/opset8.hpp> | ||
#include <ngraph/rt_info.hpp> | ||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
|
||
template <class Conv> | ||
ngraph::matcher_pass_callback ov::intel_cpu::ConvertConv1DBase::convert_conv1d_to_conv2d() { | ||
return [&](ngraph::pattern::Matcher& m) { | ||
auto conv = std::dynamic_pointer_cast<Conv>(m.get_match_root()); | ||
if (!conv) { | ||
return false; | ||
} | ||
|
||
auto input_shape = conv->get_input_shape(0); | ||
// is Conv1D | ||
if (input_shape.size() != 3) { | ||
return false; | ||
} | ||
|
||
auto input = conv->input_value(0); | ||
auto weights = conv->input_value(1); | ||
auto input2d_shape = input_shape; | ||
input2d_shape.push_back(1); | ||
auto in2d_shape = std::make_shared<ov::opset8::Constant>(ngraph::element::i64, ngraph::Shape{4}, input2d_shape); | ||
|
||
auto weights2d_shape = weights.get_shape(); | ||
weights2d_shape.push_back(1); | ||
auto w_shape = std::make_shared<ov::opset8::Constant>(ngraph::element::i64, ngraph::Shape{weights2d_shape.size()}, weights2d_shape); | ||
|
||
auto input2d = std::make_shared<ov::opset8::Reshape>(input, in2d_shape, true); | ||
auto weights2d = std::make_shared<ov::opset8::Reshape>(weights, w_shape, true); | ||
|
||
auto conv2d = std::make_shared<Conv>(input2d, | ||
weights2d, | ||
ngraph::Strides{conv->get_strides()[0], 1}, | ||
ngraph::CoordinateDiff{conv->get_pads_begin()[0], 0}, | ||
ngraph::CoordinateDiff{conv->get_pads_end()[0], 0}, | ||
ngraph::Strides{conv->get_dilations()[0], 1}, | ||
conv->get_auto_pad()); | ||
|
||
auto in_shape = std::make_shared<ov::opset8::Constant>(ngraph::element::i64, ngraph::Shape{3}, conv->get_output_shape(0)); | ||
auto reshape = std::make_shared<ov::opset8::Reshape>(conv2d, in_shape, true); | ||
|
||
reshape->set_friendly_name(conv->get_friendly_name()); | ||
ngraph::copy_runtime_info(conv, {input2d, weights2d, conv2d, reshape}); | ||
ngraph::replace_node(conv, reshape); | ||
return true; | ||
}; | ||
} | ||
|
||
ov::intel_cpu::ConvertConv1D::ConvertConv1D() { | ||
auto m = std::make_shared<ngraph::pattern::Matcher>( | ||
ngraph::pattern::wrap_type<ov::opset8::Convolution>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), | ||
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}, | ||
ngraph::pattern::has_static_shape()), "ConvertConvolutionToArm"); | ||
register_matcher(m, convert_conv1d_to_conv2d<ov::opset8::Convolution>()); | ||
} | ||
|
||
ov::intel_cpu::ConvertGroupConv1D::ConvertGroupConv1D() { | ||
auto m = std::make_shared<ngraph::pattern::Matcher>( | ||
ngraph::pattern::wrap_type<ov::opset8::GroupConvolution>({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), | ||
ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}, | ||
ngraph::pattern::has_static_shape()), "ConvertGroupConvolutionToArm"); | ||
register_matcher(m, convert_conv1d_to_conv2d<ov::opset8::GroupConvolution>()); | ||
} |
29 changes: 29 additions & 0 deletions
29
src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (C) 2020-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <ngraph/pass/graph_rewrite.hpp> | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
class ConvertConv1DBase: public ngraph::pass::MatcherPass { | ||
protected: | ||
OPENVINO_RTTI("ConvertConv1DBase", "0"); | ||
template <class Conv> | ||
ngraph::matcher_pass_callback convert_conv1d_to_conv2d(); | ||
}; | ||
|
||
class ConvertConv1D: public ConvertConv1DBase { | ||
public: | ||
OPENVINO_RTTI("ConvertConv1D", "0"); | ||
ConvertConv1D(); | ||
}; | ||
|
||
class ConvertGroupConv1D: public ConvertConv1DBase { | ||
public: | ||
OPENVINO_RTTI("ConvertGroupConv1D", "0"); | ||
ConvertGroupConv1D(); | ||
}; | ||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters