Skip to content

Commit

Permalink
Implement nGraph transformation to decompose Einsum-7 operation (open…
Browse files Browse the repository at this point in the history
…vinotoolkit#5529)

* Implement nGraph transformation to decompose Einsum-7 operation

Signed-off-by: Roman Kazantsev <[email protected]>

* Use MatMul instead of Eltwise-multiplication and ReduceSum

Signed-off-by: Roman Kazantsev <[email protected]>

* Add description for new methods

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix code style

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix code style #2

Signed-off-by: Roman Kazantsev <[email protected]>

* Remove unused variables.py

Signed-off-by: Roman Kazantsev <[email protected]>

* Apply feedback after review: fix comments, new_register_node use

Signed-off-by: Roman Kazantsev <[email protected]>

* Add Reshape if needed and apply code-review feedback

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix code-style

Signed-off-by: Roman Kazantsev <[email protected]>

* Remove unused variable

Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants authored and rnugmanx committed Aug 26, 2021
1 parent 06dbb68 commit f7c8718
Show file tree
Hide file tree
Showing 5 changed files with 731 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API EinsumDecomposition;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief EinsumDecomposition transformation decomposes Einsum-7 operation into a sub-graph with more simple operations:
* Transpose, Reshape, MatMul, ReduceSum, Unsqueeze, ShapeOf, ReduceProd, StridedSlice, and Concat
*/
class ngraph::pass::EinsumDecomposition : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
EinsumDecomposition();
};
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "transformations/op_conversions/convert_gelu.hpp"
#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp"
#include "transformations/op_conversions/batch_norm_decomposition.hpp"
#include "transformations/op_conversions/einsum_decomposition.hpp"
#include "transformations/op_conversions/gelu7_downgrade.hpp"
#include "transformations/op_conversions/reduce_l1_decomposition.hpp"
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
Expand Down Expand Up @@ -146,6 +147,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::MVN6Decomposition>();
decomp->add_matcher<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
decomp->add_matcher<ngraph::pass::EinsumDecomposition>();
decomp->set_name("ngraph::pass::CommonDecompositions");

// CF is required after all decompositions
Expand Down
Loading

0 comments on commit f7c8718

Please sign in to comment.