Skip to content

Commit

Permalink
[Snippets] Added Select support
Browse files Browse the repository at this point in the history
[Snippets] Added Broadcast support
  • Loading branch information
a-sidorova committed Dec 5, 2022
1 parent 026360c commit 7c89867
Show file tree
Hide file tree
Showing 21 changed files with 586 additions and 30 deletions.
12 changes: 12 additions & 0 deletions src/common/snippets/include/snippets/pass/insert_movebroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ class InsertMoveBroadcast: public ngraph::pass::MatcherPass {
InsertMoveBroadcast();
};

/**
* @interface BroadcastToMoveBroadcast
* @brief Inserts explicit MoveBroadcast instruction if broadcasting by most varying dimension is needed instead of Broadcast,
* otherwise pass removes Brodcast operation.
* The pass is used to convert model to a canonical form for code generation
* @ingroup snippets
*/
class BroadcastToMoveBroadcast: public ngraph::pass::MatcherPass {
public:
BroadcastToMoveBroadcast();
};

} // namespace pass
} // namespace snippets
} // namespace ngraph
10 changes: 10 additions & 0 deletions src/common/snippets/include/snippets/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ std::vector<size_t> get_node_output_layout(const Node* node);
inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *(shape.rbegin()); }
inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *(shape.rbegin() + 1); }

// Non-scalar Constants are tokenized as Parameters inside Subgraph body but some of the operations which Constant inputs
// should have explicit Constants even if they're non-scalar (Reshape, Transpose, Broadcast)
// This check returns True if Constant op of this op should be inside Subgraph body
inline auto constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool {
return ov::is_type<ov::op::v0::FakeQuantize>(node) ||
ov::is_type<ov::op::v1::Transpose>(node) ||
ov::is_type<ov::op::v1::Broadcast>(node) ||
ov::is_type<ov::op::v1::Reshape>(node);
}

} // namespace utils
} // namespace snippets
} // namespace ngraph
20 changes: 11 additions & 9 deletions src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ auto snippets::op::Subgraph::wrap_node_as_subgraph(const std::shared_ptr<ov::Nod

for (const auto& input : node->input_values()) {
if (ov::is_type<ngraph::opset1::Constant>(input.get_node_shared_ptr()) &&
(ngraph::shape_size(input.get_shape()) == 1 ||
ov::is_type<ov::op::v0::FakeQuantize>(node) ||
ov::is_type<ov::op::v1::Transpose>(node))) {
(ngraph::shape_size(input.get_shape()) == 1 || utils::constant_input_should_be_inside_body(node))) {
body_inputs.push_back(input);
} else {
auto parameter = std::make_shared<ngraph::opset1::Parameter>(input.get_element_type(), input.get_partial_shape());
Expand Down Expand Up @@ -382,11 +380,13 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu
// - Insert Convert before operations that doesn't support original element type for execution
// - Insert reverse Convert before operations that support original element type
// but have inputs that doesn't support it (because before them will be inserted Convert with exec_type - first point)
// Then we should use ConstantFolding pass to convert element type of Scalars before inference.
// - Then we should use ConstantFolding pass to convert element type of Scalars before inference.
// - Eliminate redundant Converts which can be inserted in AlignElementType() pass
ngraph::pass::Manager manager;
if (config.m_is_needed_to_align_precision) {
manager.register_pass<snippets::pass::AlignElementType>(execution_element_type);
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::EliminateConvert>();
}
manager.run_passes(m_body);
}
Expand Down Expand Up @@ -415,6 +415,7 @@ void snippets::op::Subgraph::convert_to_snippet_dialect() {
manager.register_pass<snippets::pass::InsertBuffer>(tileRank);
manager.register_pass<snippets::pass::SoftmaxDecomposition>(count, tileRank);
manager.register_pass<snippets::pass::TransposeDecomposition>();
manager.register_pass<snippets::pass::BroadcastToMoveBroadcast>();
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
manager.register_pass<snippets::pass::ConvertPowerToPowerStatic>();
manager.register_pass<snippets::pass::InsertLoad>(count);
Expand Down Expand Up @@ -505,13 +506,14 @@ snippets::Schedule snippets::op::Subgraph::generate(ngraph::pass::Manager& opt,
// check that body doesn't have constants for scheduling
std::vector<std::shared_ptr<opset1::Constant>> constants;
for (auto op : m_body->get_ordered_ops()) {
if (auto constant = ov::as_type_ptr<opset1::Constant>(op)) {
if (ngraph::shape_size(constant->get_shape()) != 1 && constant->get_shape() != Shape()) {
constants.push_back(constant);
}
if ((ov::is_type<opset1::Constant>(op) && ov::shape_size(op->get_shape()) != 1 && op->get_shape() != Shape()) ||
ov::is_type<ov::op::v1::Softmax>(op) ||
ov::is_type<ov::op::v8::Softmax>(op) ||
ov::is_type<ov::op::v1::Transpose>(op) ||
ov::is_type<ov::op::v1::Broadcast>(op)) {
throw ngraph::ngraph_error("External op detected: " + std::string(op->get_type_name()) + ". Snippet is illigal for scheduling");
}
}
NGRAPH_CHECK(!constants.size(), "External constants detected. Snippet is illigal for scheduling");

return {master_shape, false /*canBeLinearized*/, ptr};
}
Expand Down
14 changes: 8 additions & 6 deletions src/common/snippets/src/pass/align_element_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ inline auto is_in_op(const std::shared_ptr<ov::Node>& n) -> bool {
|| ov::is_type<ov::op::v0::Constant>(n);
}

// At the moment Subgraph supports only Eltwise, Convert, FQ (which is decomposed into Eltwises and Convert) and
// Softmax (which is decompsed into Eltwises as well)
// And only Eltwises supports execution only in "exec_type". So we can check op type from the opposite
// At the moment Subgraph supports only Eltwise, Select, Convert, Broadcast and FQ (which is decomposed into Eltwises and Convert) with
// Softmax (which is decomposed into Eltwises as well)
// And only Eltwise and Select ops supports execution only in "exec_type". So we can check op type from the opposite
// NOTE: This check is only for executable which isn't Parameter/Constant/Result
inline auto op_supports_only_exec_type(const std::shared_ptr<ov::Node>& n) -> bool {
return !is_in_op(n) &&
!ov::is_type<ov::op::v0::Result>(n) &&
!ov::is_type<ov::op::v1::Transpose>(n) &&
!ov::is_type<ov::op::v0::Convert>(n);
!ov::is_type<ov::op::v0::Convert>(n) &&
!ov::is_type<ov::op::v1::Broadcast>(n);
}

} // namespace
Expand Down Expand Up @@ -60,7 +61,8 @@ bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_pt
// - Input is Convert with unsupported destination type
// - Input is Op which support any element type
// We couldn't unite these conditions and just check that element type isn't supported exec type
// because we don't call validate_and_infer_types() so we don't know new precisions
// because we don't call validate_and_infer_types() so we don't know new precisions after setting of original
// input and output element types
if ((existing_convert && existing_convert->get_destination_type() != exec_type) ||
(!op_supports_only_exec_type(shared_input))) {
insertConvert(op, i, exec_type);
Expand Down Expand Up @@ -91,6 +93,6 @@ bool ngraph::snippets::pass::AlignElementType::run_on_model(const std::shared_pt
}

bool ngraph::snippets::pass::AlignElementType::opNeedsAlignElementType(const std::shared_ptr<ov::Node>& op, const ov::element::Type exec_type) {
// At the moment Snippets support only Eltwise/Convert/FQ which one output so we can just call get_element_type()
// At the moment Snippets support only Eltwise/Convert/FQ/Select/Softmax/Broadcast which one output so we can just call get_element_type()
return op_supports_only_exec_type(op) && op->get_element_type() != exec_type;
}
44 changes: 32 additions & 12 deletions src/common/snippets/src/pass/collapse_subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
is_type<opset1::Constant>(n->get_input_node_shared_ptr(4));
};

auto is_supported_ternary_eltwise_op = [](const std::shared_ptr<const Node> &n) -> bool {
return ov::is_type<opset1::Select>(n);
};

auto is_supported_binary_eltwise_op = [](const std::shared_ptr<const Node> &n) -> bool {
return ov::is_type<opset1::Add>(n)
|| ov::is_type<opset1::Divide>(n)
Expand Down Expand Up @@ -138,23 +142,41 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
return axis >= 0 && axis == (rank.get_length() - 1);
};

return is_supported_fq_op(n)
|| is_supported_unary_eltwise_op(n)
|| is_supported_binary_eltwise_op(n)
|| is_supported_transpose(n)
|| is_supported_softmax(n)
|| is_supported_matmul(n);
auto is_supported_broadcast_op = [](const std::shared_ptr<const Node> &n) -> bool {
// TODO: Add check for broadcastable input shapes of Broadcast children
// Codogen removes Broadcast op, insert BroadcastMove if needed and save just last dim.
// But if Broadcast child output shape depends on Broadcast we can loss needed output shape
// Example:
// in0 [1, 1, 1] in0 [1, 1, 1] in0 [1, 1, 1] in0 [1, 1, 1]
// Broadcast [1, 10, 1] / \ /
// \ / --->>> Add
// Add |
// Result [1, 10, 1] Result [1, 1, 1]
auto broadcast = ov::as_type_ptr<const opset1::Broadcast>(n);
return broadcast && broadcast->get_broadcast_spec().m_type == ov::op::AutoBroadcastType::NUMPY;
};

return is_supported_fq_op(n) ||
is_supported_unary_eltwise_op(n) ||
is_supported_binary_eltwise_op(n) ||
is_supported_ternary_eltwise_op(n) ||
is_supported_transpose(n) ||
is_supported_softmax(n) ||
is_supported_matmul(n) ||
is_supported_broadcast_op(n);
}

auto has_supported_in_out(const std::shared_ptr<const Node> &n) -> bool {
auto supported = [&n](descriptor::Tensor& t) -> bool {
static const std::set<ngraph::element::Type> supported_data_types =
{ ngraph::element::f32, ngraph::element::bf16, ngraph::element::i8, ngraph::element::u8 };
// Todo: int32 isn't supported in general because i32 emitters are required for bit-exact i32 calculations in some cases
// So i32 is supported exclusively for transposes
// So i32 is supported exclusively for transposes and broadcast
return t.get_partial_shape().is_static() &&
(supported_data_types.count(t.get_element_type()) != 0 ||
(ov::is_type<const opset1::Transpose>(n) && t.get_element_type() == ngraph::element::i32));
(supported_data_types.count(t.get_element_type()) != 0 ||
(t.get_element_type() == ngraph::element::i32 &&
(ov::is_type<const opset1::Transpose>(n) ||
ov::is_type<const opset1::Broadcast>(n))));
};
const auto & inputs = n->inputs();
const auto & outputs = n->outputs();
Expand Down Expand Up @@ -491,9 +513,7 @@ TokenizeSnippets::TokenizeSnippets() {
// [*] We support Transpose with second Constant input (represents order). This Constant will not be scheduled
// and will only be used to decompose Transpose into a proper Load, Store and Loop combination.
if (ov::is_type<ngraph::opset1::Constant>(input_node) &&
(ngraph::shape_size(input_value.get_shape()) == 1 ||
ov::is_type<ov::op::v0::FakeQuantize>(node) ||
ov::is_type<ov::op::v1::Transpose>(node))) {
(ngraph::shape_size(input_value.get_shape()) == 1 || utils::constant_input_should_be_inside_body(node))) {
internal_inputs.push_back(input_node->output(0));
} else {
external_inputs.push_back(input_value);
Expand Down
3 changes: 2 additions & 1 deletion src/common/snippets/src/pass/common_optimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "snippets/pass/fq_decomposition.hpp"
#include "snippets/pass/softmax_reshape_elimination.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"
#include "snippets/itt.hpp"

NGRAPH_RTTI_DEFINITION(ngraph::snippets::pass::CommonOptimizations, "Snippets::CommonOptimizations", 0);
Expand All @@ -35,7 +36,7 @@ void ConvertConstantsToParameters(const std::shared_ptr<ngraph::snippets::op::Su
continue;

const auto child = constant->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
if (ov::is_type<ov::op::v1::Transpose>(child) || ov::is_type<ov::op::v1::Reshape>(child))
if (utils::constant_input_should_be_inside_body(child) && !ov::is_type<ov::op::v0::FakeQuantize>(child))
continue;

auto parameter = std::make_shared<opset1::Parameter>(constant->get_element_type(), constant->output(0).get_partial_shape());
Expand Down
23 changes: 23 additions & 0 deletions src/common/snippets/src/pass/insert_movebroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "snippets/pass/insert_movebroadcast.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/utils.hpp"
#include <ngraph/pattern/op/wrap_type.hpp>

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
Expand Down Expand Up @@ -121,3 +122,25 @@ ngraph::snippets::pass::InsertMoveBroadcast::InsertMoveBroadcast() {

register_matcher(std::make_shared<ngraph::pattern::Matcher>(any, matcher_name), callback);
}

ngraph::snippets::pass::BroadcastToMoveBroadcast::BroadcastToMoveBroadcast() {
MATCHER_SCOPE(BroadcastToMoveBroadcast);

register_matcher(std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>(), matcher_name),
[this](ngraph::pattern::Matcher &m) {
OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::op::BroadcastToMoveBroadcast")
auto root = m.get_match_root();
auto broadcast = ov::as_type_ptr<ngraph::opset1::Broadcast>(root);
if (broadcast->get_broadcast_spec() != ngraph::op::AutoBroadcastType::NUMPY) {
return false;
}

auto broadcast_move = broadcast_node_last_dim(broadcast->input_value(0), broadcast->get_output_shape(0), broadcast->get_input_shape(0));
auto target_output = ov::is_type<ngraph::snippets::op::BroadcastMove>(broadcast_move) ? broadcast_move->output(0) :
broadcast->input_value(0);
replace_output_update_name(broadcast->output(0), target_output);
ngraph::copy_runtime_info(root, broadcast_move);

return true;
});
}
14 changes: 14 additions & 0 deletions src/common/snippets/tests/include/pass/insert_movebroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ typedef std::tuple<
Shape // Broadcast shape 1
> insertMoveBroadcastParams;

typedef std::tuple<
Shape, // Input shape 0
Shape, // Input shape 1
Shape // Broadcast shape
> BroadcastParams;

using ngraph::snippets::op::Subgraph;
class InsertMoveBroadcastTests : public LoweringTests, public testing::WithParamInterface<insertMoveBroadcastParams> {
public:
Expand All @@ -31,6 +37,14 @@ class InsertMoveBroadcastTests : public LoweringTests, public testing::WithParam
std::shared_ptr<SnippetsFunctionBase> snippets_function;
};

class BroadcastToMoveBroadcastTests : public LoweringTests, public testing::WithParamInterface<BroadcastParams> {
public:
static std::string getTestCaseName(testing::TestParamInfo<BroadcastParams> obj);
protected:
void SetUp() override;
std::shared_ptr<SnippetsFunctionBase> snippets_function;
};

} // namespace snippets
} // namespace test
} // namespace ov
44 changes: 44 additions & 0 deletions src/common/snippets/tests/src/pass/insert_movebroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,28 @@ void InsertMoveBroadcastTests::SetUp() {
master_shape.push_back(static_cast<int64_t>(std::max(inputShapes[0][i], inputShapes[1][i])));
}

std::string BroadcastToMoveBroadcastTests::getTestCaseName(testing::TestParamInfo<BroadcastParams> obj) {
std::vector<Shape> inputShapes(2);
Shape broadcast_shape;
std::tie(inputShapes[0], inputShapes[1], broadcast_shape) = obj.param;
std::ostringstream result;
for (size_t i = 0; i < inputShapes.size(); i++)
result << "IS[" << i << "]=" << CommonTestUtils::vec2str(inputShapes[i]) << "_";
result << "BS=" << CommonTestUtils::vec2str(broadcast_shape) << "_";
return result.str();
}

void BroadcastToMoveBroadcastTests::SetUp() {
TransformationTestsF::SetUp();
std::vector<PartialShape> inputShapes(2);
PartialShape broadcast_shape;
std::tie(inputShapes[0], inputShapes[1], broadcast_shape) = this->GetParam();
snippets_function = std::make_shared<BroadcastAddLoweredFunction>(inputShapes, broadcast_shape);
master_shape = {};
for (int i = 0; i < inputShapes[0].size(); i++)
master_shape.push_back(static_cast<int64_t>(std::max(inputShapes[0].get_shape()[i], inputShapes[1].get_shape()[i])));
}

TEST_P(InsertMoveBroadcastTests, AddBroadcast) {
PartialShape scheduler_shape({master_shape[master_shape.size() - 2],
master_shape[master_shape.size() - 1]});
Expand All @@ -44,6 +66,14 @@ TEST_P(InsertMoveBroadcastTests, AddBroadcast) {
function_ref = snippets_function->getLowered();
}

TEST_P(BroadcastToMoveBroadcastTests, BroadcastSelect) {
PartialShape scheduler_shape({master_shape[master_shape.size() - 2],
master_shape[master_shape.size() - 1]});
auto subgraph = getLoweredSubgraph(snippets_function->getOriginal(), scheduler_shape);
function = subgraph->get_body();
function_ref = snippets_function->getLowered();
}

namespace InsertMoveBroadcastTestsInstantiation {
using ov::Shape;
std::vector<Shape> inputShapes0 {{1, 8, 2, 1}};
Expand Down Expand Up @@ -85,6 +115,20 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_NoBroadcast, InsertMoveBroadcastTests,
::testing::ValuesIn(paramsNo),
InsertMoveBroadcastTests::getTestCaseName);
} // namespace InsertMoveBroadcastTestsInstantiation


namespace BroadcastToMoveBroadcastTestsInstantiation {
using ov::Shape;
std::vector<Shape> inputShapes0 {{1, 8, 2, 10}, {1, 8, 2, 1}, {1, 1, 1, 1}};
std::vector<Shape> inputShapes1 {{1, 8, 2, 10}, {1, 8, 2, 1}, {1, 1, 1, 1}};
Shape broadcastShape {1, 8, 2, 10};
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_Broadcast, BroadcastToMoveBroadcastTests,
::testing::Combine(
::testing::ValuesIn(inputShapes0),
::testing::ValuesIn(inputShapes1),
::testing::Values(broadcastShape)),
BroadcastToMoveBroadcastTests::getTestCaseName);
} // namespace BroadcastToMoveBroadcastTestsInstantiation
} // namespace snippets
} // namespace test
} // namespace ov
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/emitters/cpu_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_
jitters[ngraph::snippets::op::ConvertSaturation::get_type_info_static()] = CREATE_EMITTER(ov::intel_cpu::jit_convert_saturation_emitter);
// jitters[ngraph::opset1::FakeQuantize::get_type_info_static()] = CREATE_EMITTER(); // not supported

// ternary
jitters[ngraph::opset1::Select::get_type_info_static()] = CREATE_EMITTER(ov::intel_cpu::jit_select_emitter);

// binary
jitters[ngraph::opset1::Add::get_type_info_static()] = CREATE_EMITTER(ov::intel_cpu::jit_add_emitter);
jitters[ngraph::opset1::Divide::get_type_info_static()] = CREATE_EMITTER(ov::intel_cpu::jit_divide_emitter);
Expand Down
Loading

0 comments on commit 7c89867

Please sign in to comment.