Skip to content

Commit

Permalink
fix gather_sinking_transpose_reshape - check if flatten/unflatten the…
Browse files Browse the repository at this point in the history
… last dimension
  • Loading branch information
evkotov committed Mar 14, 2023
1 parent 41e1b72 commit 90b37d7
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/plugins/intel_gna/src/debug_new_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ void DebugVisualize(ov::pass::Manager& manager, const std::string& name) {
#ifdef DEBUG_VISUALIZE
static unsigned counter = 0;
std::stringstream ss;
#ifdef DEBUG_VISUALIZETREE
ss << counter << name << ".png";
manager.register_pass<ov::pass::VisualizeTree>(ss.str());
//ss << counter << name;
//manager.register_pass<ov::pass::Serialize>(ss.str() + ".xml", ss.str() + ".bin");
#else
ss << counter << name;
manager.register_pass<ov::pass::Serialize>(ss.str() + ".xml", ss.str() + ".bin");
#endif
++counter;
#endif
}
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gna/src/debug_new_pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#undef DEBUG_VISUALIZE
//#define DEBUG_VISUALIZE 1
#undef DEBUG_VISUALIZETREE
//#define DEBUG_VISUALIZETREE 1

#define EMUTEX_DEBUG_CHECKPOINT std::cout << "[EMUTEX DEBUG] CHECKPOINT " << __FILE__ << ":" << __LINE__ << std::endl;
#define EMUTEX_DEBUG_CHECKPOINT_MESSAGE(message) std::cout << "[EMUTEX DEBUG] CHECKPOINT " << __FILE__ << ":" << __LINE__ << \
Expand Down
6 changes: 6 additions & 0 deletions src/plugins/intel_gna/src/gna_transformations_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
// In OV API 2.0(IRv10) default convertion to fp32 (inputs, outputs and weights) is disabled
// and we need to run the ConvertPrecision transformation to support old networks.
manager.register_pass<ov::pass::ConvertPrecision>(precisions_array{{ngraph::element::f16, ngraph::element::f32}});
intel_gna_debug::DebugVisualize(manager, "start");
manager.register_pass<ov::pass::ConvertMVN1ToMVN6>();
manager.register_pass<ov::intel_gna::pass::DecomposeMVN>();
manager.register_pass<ov::pass::CommonOptimizations>();
Expand Down Expand Up @@ -122,15 +123,18 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
manager.register_pass<ov::intel_gna::pass::RemoveSingleInputConcat>();
manager.register_pass<ov::intel_gna::pass::SubstituteSoftsign>();
manager.register_pass<ov::intel_gna::pass::InsertCopyBeforeLayerToBeEliminated>();
intel_gna_debug::DebugVisualize(manager, "before_TransposeNCHW");
manager.register_pass<ov::intel_gna::pass::TransposeNCHW>();
manager.register_pass<ov::intel_gna::pass::ReshapeTransposeSubstitute>();
intel_gna_debug::DebugVisualize(manager, "before_TransposeSinkingGeneral");
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingGeneral>();
manager.register_pass<ov::pass::ReshapeSequenceFusion>();
manager.register_pass<ov::pass::TransposeToReshape>();
manager.register_pass<ov::intel_gna::pass::GnaConvolutionFusion>();
manager.register_pass<ov::intel_gna::pass::RemoveInputsProcessing>(subgraph_cpu_map);
manager.register_pass<ov::intel_gna::pass::RemoveOutputsProcessing>(subgraph_cpu_map);
intel_gna_debug::DebugVisualize(manager, "after_our_transformations");
manager.register_pass<ov::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ov::pass::ConvertOpSet2ToOpSet1>();
manager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
Expand Down Expand Up @@ -192,6 +196,8 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
// Operations Max and Min aren't supported
pass_config->disable<ov::pass::ConcatReduceFusion>();

intel_gna_debug::DebugVisualize(manager, "final");

manager.run_passes(model);

is_ngraph_passes_used = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ bool AreFlattenShapes(const Shape& shape1, const Shape& shape2) {
while (shape1[i] == shape2[i]) {
++i;
}
// consider only last dimension to be flatten/unflatten
if (shape1.size() - 1 != i && shape2.size() - 1 != i)
return false;
// min_shape.back() == MULTIPLY(max_shape.begin() + i, max_shape.end())
const size_t mult1 = std::accumulate(shape1.begin() + i, shape1.end(), 1, std::multiplies<size_t>());
const size_t mult2 = std::accumulate(shape2.begin() + i, shape2.end(), 1, std::multiplies<size_t>());
Expand Down

0 comments on commit 90b37d7

Please sign in to comment.