Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement nGraph transformation to decompose Einsum-7 operation #5529

Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API EinsumDecomposition;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief EinsumDecomposition transformation decomposes Einsum-7 operation into a sub-graph with more simple operations:
* Transpose, Reshape, MatMul, ReduceSum, Unsqueeze, ShapeOf, ReduceProd, StridedSlice, and Concat
*/
class ngraph::pass::EinsumDecomposition : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
EinsumDecomposition();
};
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "transformations/op_conversions/convert_gelu.hpp"
#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp"
#include "transformations/op_conversions/batch_norm_decomposition.hpp"
#include "transformations/op_conversions/einsum_decomposition.hpp"
#include "transformations/op_conversions/gelu7_downgrade.hpp"
#include "transformations/op_conversions/reduce_l1_decomposition.hpp"
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
Expand Down Expand Up @@ -146,6 +147,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::MVN6Decomposition>();
decomp->add_matcher<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
decomp->add_matcher<ngraph::pass::EinsumDecomposition>();
rkazants marked this conversation as resolved.
Show resolved Hide resolved
decomp->set_name("ngraph::pass::CommonDecompositions");

// CF is required after all decompositions
Expand Down
Loading