diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index e47a773298a4ef..b88c9fa6c4866b 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -506,6 +506,7 @@ snippets::Schedule snippets::op::Subgraph::generate(ngraph::pass::Manager& opt, const auto& outerTileBegin = insertTileBegin(commonParams); insertTileEnd(commonResults, outerTileBegin, outer_dim, outer_WA, 1, apply_increments); } + m_body->validate_nodes_and_infer_types(); } else { throw ngraph_error("Dynamic case is not supported yet"); } @@ -522,7 +523,6 @@ snippets::Schedule snippets::op::Subgraph::generate(ngraph::pass::Manager& opt, // } // std::cerr << "\n"; // } - m_body->validate_nodes_and_infer_types(); // std::cerr << "Tile after is dumped"; // ov::pass::Serialize("tile_after.xml", "tile_after.bin").run_on_model(m_body); diff --git a/src/common/snippets/tests/include/lowering_utils.hpp b/src/common/snippets/tests/include/lowering_utils.hpp index 5af4af2a32b099..e78bc1940f4477 100644 --- a/src/common/snippets/tests/include/lowering_utils.hpp +++ b/src/common/snippets/tests/include/lowering_utils.hpp @@ -40,8 +40,10 @@ class DummyGenerator : public ngraph::snippets::Generator { class LoweringTests : public TransformationTestsF { protected: static std::shared_ptr getSubgraph(const std::shared_ptr& f); - static std::shared_ptr getLoweredSubgraph(const std::shared_ptr& f); + static std::shared_ptr getLoweredSubgraph(const std::shared_ptr& f, + const ov::PartialShape& master_shape); static std::shared_ptr getTokenizedSubgraph(const std::shared_ptr& f); + ov::PartialShape master_shape{}; }; } // namespace snippets diff --git a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index 4aab86d5d7c07c..a0e460fc15b35e 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -30,8 +30,8 @@ DummyTargetMachine::DummyTargetMachine() { jitters[ngraph::snippets::op::Scalar::get_type_info_static()] = dummy_functor; jitters[ngraph::snippets::op::BroadcastMove::get_type_info_static()] = dummy_functor; jitters[ngraph::snippets::op::Kernel::get_type_info_static()] = dummy_functor; - jitters[ngraph::snippets::op::Tile::get_type_info_static()] = dummy_functor; - jitters[ngraph::snippets::op::TileScheduler::get_type_info_static()] = dummy_functor; + jitters[ngraph::snippets::op::TileBegin::get_type_info_static()] = dummy_functor; + jitters[ngraph::snippets::op::TileEnd::get_type_info_static()] = dummy_functor; } std::shared_ptr LoweringTests::getSubgraph(const std::shared_ptr& f) { @@ -52,9 +52,11 @@ std::shared_ptr LoweringTests::getSubgraph(const return subgraph; } -std::shared_ptr LoweringTests::getLoweredSubgraph(const std::shared_ptr &f) { +std::shared_ptr LoweringTests::getLoweredSubgraph(const std::shared_ptr &f, + const ov::PartialShape& master_shape) { auto subgraph = getTokenizedSubgraph(f); subgraph->set_generator(std::make_shared()); + subgraph->set_master_shape(master_shape); subgraph->generate(); return subgraph; } diff --git a/src/common/snippets/tests/src/pass/insert_load_store.cpp b/src/common/snippets/tests/src/pass/insert_load_store.cpp index 97fe94a3fe8d8f..25d189cd096151 100644 --- a/src/common/snippets/tests/src/pass/insert_load_store.cpp +++ b/src/common/snippets/tests/src/pass/insert_load_store.cpp @@ -32,10 +32,13 @@ void InsertLoadStoreTests::SetUp() { broadcastShapes[0], broadcastShapes[1], broadcastShapes[2]) = this->GetParam(); snippets_function = std::make_shared( std::vector {inputShapes[0], inputShapes[1], inputShapes[2]}, broadcastShapes); + master_shape = inputShapes[0]; } TEST_P(InsertLoadStoreTests, ThreeInputsEltwise) { - auto subgraph = getLoweredSubgraph(snippets_function->getOriginal()); + PartialShape scheduler_shape({master_shape[master_shape.size() - 2], + master_shape[master_shape.size() - 1]}); + auto subgraph = getLoweredSubgraph(snippets_function->getOriginal(), scheduler_shape); function = subgraph->get_body(); function_ref = snippets_function->getLowered(); } diff --git a/src/common/snippets/tests/src/pass/insert_movebroadcast.cpp b/src/common/snippets/tests/src/pass/insert_movebroadcast.cpp index 577c23df19fb70..a5886ddd474aa1 100644 --- a/src/common/snippets/tests/src/pass/insert_movebroadcast.cpp +++ b/src/common/snippets/tests/src/pass/insert_movebroadcast.cpp @@ -29,10 +29,17 @@ void InsertMoveBroadcastTests::SetUp() { std::vector broadcastShapes(2); std::tie(inputShapes[0], inputShapes[1], broadcastShapes[0], broadcastShapes[1]) = this->GetParam(); snippets_function = std::make_shared(std::vector {inputShapes[0], inputShapes[1]}, broadcastShapes); + if (inputShapes[0].size() != inputShapes[1].size()) + IE_THROW() << "Expected input shapes of the same size"; + master_shape = {}; + for (int i = 0; i < inputShapes[0].size(); i++) + master_shape.push_back(static_cast(std::max(inputShapes[0][i], inputShapes[1][i]))); } TEST_P(InsertMoveBroadcastTests, AddBroadcast) { - auto subgraph = getLoweredSubgraph(snippets_function->getOriginal()); + PartialShape scheduler_shape({master_shape[master_shape.size() - 2], + master_shape[master_shape.size() - 1]}); + auto subgraph = getLoweredSubgraph(snippets_function->getOriginal(), scheduler_shape); function = subgraph->get_body(); function_ref = snippets_function->getLowered(); } diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp index 15106c47c8c335..25dcdb7c92ff60 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp @@ -6,6 +6,7 @@ #include "common_test_utils/data_utils.hpp" #include #include "ngraph_functions/builders.hpp" +#include namespace ov { namespace test { @@ -29,7 +30,24 @@ std::shared_ptr AddFunctionLoweredBroadcast::initLowered() const { } auto add = std::make_shared(add_input0, add_input1); auto store = std::make_shared(add); - return std::make_shared(NodeVector{store}, ParameterVector{data0, data1}); + ParameterVector input_params {data0, data1}; + auto model = std::make_shared(NodeVector{store}, input_params); + + // Create dummy scheduler to pass graph comparison tests + // Note that if there is more than one results, they should be reverted + ResultVector results({model->get_results()[0]}); + const auto& innerTileBegin = ngraph::snippets::op::insertTileBegin(input_params); + std::vector apply_increments(input_params.size() + results.size(), true); + const auto& innerTileEnd = insertTileEnd(results, innerTileBegin, 1, 1, 1, apply_increments); + auto outer_WA = std::accumulate(input_shapes.begin(), input_shapes.end(), 0, + [](int64_t max_val, const PartialShape& ps) { + return std::max(ps[ps.size() - 2].get_length(), max_val); + }); + if (outer_WA > 1) { + const auto& outerTileBegin = ngraph::snippets::op::insertTileBegin(input_params); + insertTileEnd(results, outerTileBegin, 0, 1, 1, apply_increments); + } + return model; } std::shared_ptr EltwiseThreeInputsLoweredFunction::initLowered() const { // todo: implement conversion between std::vector and std::vector @@ -60,12 +78,6 @@ std::shared_ptr EltwiseThreeInputsLoweredFunction::initLowered() cons const std::vector const_values = CommonTestUtils::generate_float_numbers(1, -10., 10.); auto sub_scalar = std::make_shared(precision, Shape{1}, const_values[0]); std::shared_ptr sub_load; -// Todo: Uncomment when invalid read in vector tile will be fixed -// if (input_shapes[2].back() == 1) -// sub_load = std::make_shared(input_params[2]); -// else -// sub_load = std::make_shared(input_params[2]); -// remove when the code above is enabled: sub_load = std::make_shared(input_params[2]); auto sub = std::make_shared(sub_load, sub_scalar); std::shared_ptr sub_out; @@ -75,7 +87,23 @@ std::shared_ptr EltwiseThreeInputsLoweredFunction::initLowered() cons sub_out = std::make_shared(sub, broadcast_shapes[2]); auto mul = std::make_shared(add, sub_out); auto store = std::make_shared(mul); - return std::make_shared(NodeVector{store}, input_params); + auto model = std::make_shared(NodeVector{store}, input_params); + + // Create dummy scheduler to pass graph comparison tests + // Note that if there is more than one results, they should be reverted + ResultVector results({model->get_results()[0]}); + const auto& innerTileBegin = ngraph::snippets::op::insertTileBegin(input_params); + std::vector apply_increments(input_params.size() + results.size(), true); + const auto& innerTileEnd = insertTileEnd(results, innerTileBegin, 1, 1, 1, apply_increments); + auto outer_WA = std::accumulate(input_shapes.begin(), input_shapes.end(), 0, + [](int64_t max_val, const PartialShape& ps) { + return std::max(ps[ps.size() - 2].get_length(), max_val); + }); + if (outer_WA > 1) { + const auto& outerTileBegin = ngraph::snippets::op::insertTileBegin(input_params); + insertTileEnd(results, outerTileBegin, 0, 1, 1, apply_increments); + } + return model; } } // namespace snippets } // namespace test