diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp deleted file mode 100644 index b910b7d3e5fcfd..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pass.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API TransposeSinkingBinaryForward; -class TRANSFORMATIONS_API TransposeSinkingBinaryBackward; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic, - * BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction. - */ -class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0"); - TransposeSinkingBinaryForward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic, - * BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction. - */ -class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0"); - TransposeSinkingBinaryBackward(); -}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_concat.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_concat.hpp deleted file mode 100644 index 66c2c7ffe98021..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_concat.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pass.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API TransposeSinkingConcatForward; -class TRANSFORMATIONS_API TransposeSinkingConcatBackward; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingConcatForward transformation sinks Transpose through Concat operation - * in the forward direction. - */ -class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0"); - TransposeSinkingConcatForward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingConcatBackward transformation sinks Transpose through Concat operation - * in the backward direction. - */ -class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0"); - TransposeSinkingConcatBackward(); -}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_data_movement.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_data_movement.hpp deleted file mode 100644 index 602f862f77d994..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_data_movement.hpp +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pass.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API TransposeSinkingDataMovementForward; -class TRANSFORMATIONS_API TransposeSinkingDataMovementBackward; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch - * and Pad operations in the forward direction. - * These operations are categorized as "DataMovement" and are handled in a similar way in this transformation. - */ -class ov::pass::TransposeSinkingDataMovementForward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementForward", "0"); - TransposeSinkingDataMovementForward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch - * and Pad operations in the backward direction. - * These operations are categorized as "DataMovement" and are handled in a similar way in this transformation. - */ -class ov::pass::TransposeSinkingDataMovementBackward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingDataMovementBackward", "0"); - TransposeSinkingDataMovementBackward(); -}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_general.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_general.hpp deleted file mode 100644 index 722469530e45fd..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_general.hpp +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API TransposeSinkingGeneralForward; -class TRANSFORMATIONS_API TransposeSinkingGeneralBackward; -class TRANSFORMATIONS_API TransposeSinkingGeneral; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingGeneralForward transformation combines all TransposeSinkingForward* transformations into - * single GraphRewrite pass. - */ -class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite { -public: - OPENVINO_RTTI("TransposeSinkingGeneralForward", "0"); - TransposeSinkingGeneralForward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingGeneralBackward transformation combines all TransposeSinkingBackward* transformations into - * single GraphRewrite pass. - */ -class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite { -public: - OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0"); - TransposeSinkingGeneralBackward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingGeneral transformation combines TransposeSinkingGeneralForward and - * TransposeSinkingGeneralBackward transformations into single ModelPass pass and inserts - * ConstantFolding pass after them. - */ -class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass { -public: - OPENVINO_RTTI("TransposeSinkingGeneral", "0"); - bool run_on_model(const std::shared_ptr& m) override; -}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_interpolate.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_interpolate.hpp deleted file mode 100644 index f51c09649ed91e..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_interpolate.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pass.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API TransposeSinkingInterpolateForward; -class TRANSFORMATIONS_API TransposeSinkingInterpolateBackward; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingInterpolateForward transformation sinks Transpose through Interpolate operation - * in the forward direction. - */ -class ov::pass::TransposeSinkingInterpolateForward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateForward", "0"); - TransposeSinkingInterpolateForward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingInterpolateBackward transformation sinks Transpose through Interpolate operation - * in the backward direction. - */ -class ov::pass::TransposeSinkingInterpolateBackward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingInterpolateBackward", "0"); - TransposeSinkingInterpolateBackward(); -}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_split.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_split.hpp deleted file mode 100644 index b40eb2741470fb..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_split.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "openvino/pass/pass.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API TransposeSinkingSplitBackward; -class TRANSFORMATIONS_API TransposeSinkingSplitForward; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingSplitForward transformation sinks Transpose through Split, VariadicSplit operations - * in the forward direction. - */ -class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0"); - TransposeSinkingSplitForward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingSplitBackward transformation sinks Transpose through Split, VariadicSplit operations - * in the backward direction. - */ -class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0"); - TransposeSinkingSplitBackward(); -}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_unary.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_unary.hpp deleted file mode 100644 index 1d0098ee9cc0dc..00000000000000 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_unary.hpp +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" -#include "transformations_visibility.hpp" - -namespace ov { -namespace pass { - -class TRANSFORMATIONS_API TransposeSinkingUnaryForward; -class TRANSFORMATIONS_API TransposeSinkingUnaryBackward; - -} // namespace pass -} // namespace ov - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu, - * SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction. - */ -class ov::pass::TransposeSinkingUnaryForward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("TransposeSinkingUnaryForward", "0"); - TransposeSinkingUnaryForward(); -}; - -/** - * @ingroup ie_transformation_common_api - * @brief TransposeSinkingUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu, - * SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction. - */ -class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass { -public: - OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0"); - TransposeSinkingUnaryBackward(); -}; diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_binary.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_binary.hpp new file mode 100644 index 00000000000000..a1b559f3c4d682 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_binary.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSBinaryForward; +class TRANSFORMATIONS_API TSBinaryBackward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSBinaryForward transformation sinks Transpose through BinaryElementwiseArithmetic, + * BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the forward direction. + */ +class ov::pass::transpose_sinking::TSBinaryForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSBinaryForward", "0"); + TSBinaryForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSBinaryBackward transformation sinks Transpose through BinaryElementwiseArithmetic, + * BinaryElementwiseComparison, BinaryElementwiseLogical and PRelu operations in the backward direction. + */ +class ov::pass::transpose_sinking::TSBinaryBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0"); + TSBinaryBackward(); +}; diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_concat.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_concat.hpp new file mode 100644 index 00000000000000..904f68ec4fa8f5 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_concat.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSConcatForward; +class TRANSFORMATIONS_API TSConcatBackward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSConcatForward transformation sinks Transpose through Concat operation + * in the forward direction. + */ +class ov::pass::transpose_sinking::TSConcatForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSConcatForward", "0"); + TSConcatForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSConcatBackward transformation sinks Transpose through Concat operation + * in the backward direction. + */ +class ov::pass::transpose_sinking::TSConcatBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSConcatBackward", "0"); + TSConcatBackward(); +}; diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_data_movement.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_data_movement.hpp new file mode 100644 index 00000000000000..090f67492f4f0b --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_data_movement.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSDataMovementForward; +class TRANSFORMATIONS_API TSDataMovementBackward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSDataMovementForward transformation sinks Transpose through BatchToSpace, SpaceToBatch + * and Pad operations in the forward direction. + * These operations are categorized as "DataMovement" and are handled in a similar way in this transformation. + */ +class ov::pass::transpose_sinking::TSDataMovementForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSDataMovementForward", "0"); + TSDataMovementForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSDataMovementBackward transformation sinks Transpose through BatchToSpace, SpaceToBatch + * and Pad operations in the backward direction. + * These operations are categorized as "DataMovement" and are handled in a similar way in this transformation. + */ +class ov::pass::transpose_sinking::TSDataMovementBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSDataMovementBackward", "0"); + TSDataMovementBackward(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_fuse.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_fuse.hpp similarity index 55% rename from src/common/transformations/include/transformations/common_optimizations/transpose_sinking_fuse.hpp rename to src/common/transformations/include/transformations/transpose_sinking/ts_fuse.hpp index 4bb2c1238bbd4f..84294d9640e24d 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_fuse.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_fuse.hpp @@ -10,19 +10,21 @@ namespace ov { namespace pass { +namespace transpose_sinking { -class TRANSFORMATIONS_API TransposeSinkingFuse; +class TRANSFORMATIONS_API TSFuse; +} // namespace transpose_sinking } // namespace pass } // namespace ov /** * @ingroup ie_transformation_common_api - * @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input + * @brief TSFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input * or fuses them to single Transpose if input gets changed */ -class ov::pass::TransposeSinkingFuse : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSFuse : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("TransposeSinkingFuse", "0"); - TransposeSinkingFuse(); + OPENVINO_RTTI("TSFuse", "0"); + TSFuse(); }; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_general.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_general.hpp new file mode 100644 index 00000000000000..8199f4c378dc9a --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_general.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSGeneralForward; +class TRANSFORMATIONS_API TSGeneralBackward; +class TRANSFORMATIONS_API TSGeneral; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSGeneralForward transformation combines all TransposeSinkingForward* transformations into + * single GraphRewrite pass. + */ +class ov::pass::transpose_sinking::TSGeneralForward : public ov::pass::GraphRewrite { +public: + OPENVINO_RTTI("TSGeneralForward", "0"); + TSGeneralForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSGeneralBackward transformation combines all TransposeSinkingBackward* transformations into + * single GraphRewrite pass. + */ +class ov::pass::transpose_sinking::TSGeneralBackward : public ov::pass::GraphRewrite { +public: + OPENVINO_RTTI("TSGeneralBackward", "0"); + TSGeneralBackward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSGeneral transformation combines TSGeneralForward and + * TSGeneralBackward transformations into single ModelPass pass and inserts + * ConstantFolding pass after them. + */ +class ov::pass::transpose_sinking::TSGeneral : public ov::pass::ModelPass { +public: + OPENVINO_RTTI("TSGeneral", "0"); + bool run_on_model(const std::shared_ptr& m) override; +}; diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_interpolate.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_interpolate.hpp new file mode 100644 index 00000000000000..519154626a99d9 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_interpolate.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSInterpolateForward; +class TRANSFORMATIONS_API TSInterpolateBackward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSInterpolateForward transformation sinks Transpose through Interpolate operation + * in the forward direction. + */ +class ov::pass::transpose_sinking::TSInterpolateForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSInterpolateForward", "0"); + TSInterpolateForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSInterpolateBackward transformation sinks Transpose through Interpolate operation + * in the backward direction. + */ +class ov::pass::transpose_sinking::TSInterpolateBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSInterpolateBackward", "0"); + TSInterpolateBackward(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_reduction.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_reduction.hpp similarity index 58% rename from src/common/transformations/include/transformations/common_optimizations/transpose_sinking_reduction.hpp rename to src/common/transformations/include/transformations/transpose_sinking/ts_reduction.hpp index 313a92a38c0bc3..6f462875b7f6ba 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_reduction.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_reduction.hpp @@ -10,10 +10,12 @@ namespace ov { namespace pass { +namespace transpose_sinking { -class TRANSFORMATIONS_API TransposeSinkingReductionForward; -class TRANSFORMATIONS_API TransposeSinkingReductionBackward; +class TRANSFORMATIONS_API TSReductionForward; +class TRANSFORMATIONS_API TSReductionBackward; +} // namespace transpose_sinking } // namespace pass } // namespace ov @@ -22,10 +24,10 @@ class TRANSFORMATIONS_API TransposeSinkingReductionBackward; * @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations * in the forward direction. */ -class ov::pass::TransposeSinkingReductionForward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("ov::pass::TransposeSinkingReductionForward", "0"); - TransposeSinkingReductionForward(); + OPENVINO_RTTI("ov::pass::TSReductionForward", "0"); + TSReductionForward(); }; /** @@ -33,8 +35,8 @@ class ov::pass::TransposeSinkingReductionForward : public ov::pass::MatcherPass * @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations * in the backward direction. */ -class ov::pass::TransposeSinkingReductionBackward : public ov::pass::MatcherPass { +class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("ov::pass::TransposeSinkingReductionBackward", "0"); - TransposeSinkingReductionBackward(); + OPENVINO_RTTI("ov::pass::TSReductionBackward", "0"); + TSReductionBackward(); }; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_split.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_split.hpp new file mode 100644 index 00000000000000..ba75ac6566229b --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_split.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSSplitBackward; +class TRANSFORMATIONS_API TSSplitForward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSSplitForward transformation sinks Transpose through Split, VariadicSplit operations + * in the forward direction. + */ +class ov::pass::transpose_sinking::TSSplitForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSSplitForward", "0"); + TSSplitForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSSplitBackward transformation sinks Transpose through Split, VariadicSplit operations + * in the backward direction. + */ +class ov::pass::transpose_sinking::TSSplitBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TSSplitBackward", "0"); + TSSplitBackward(); +}; diff --git a/src/common/transformations/include/transformations/transpose_sinking/ts_unary.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_unary.hpp new file mode 100644 index 00000000000000..9c6e93356f7ab6 --- /dev/null +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_unary.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { +namespace transpose_sinking { + +class TRANSFORMATIONS_API TSUnaryForward; +class TRANSFORMATIONS_API TSUnaryBackward; + +} // namespace transpose_sinking +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief TSUnaryForward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu, + * SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite operations in the forward direction. + */ +class ov::pass::transpose_sinking::TSUnaryForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TSUnaryForward", "0"); + TSUnaryForward(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief TSUnaryBackward transformation sinks Transpose through UnaryElementwiseArithmetic, Clamp, Elu, + * SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite in the backward direction. + */ +class ov::pass::transpose_sinking::TSUnaryBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TSUnaryBackwardMultiConsumers", "0"); + TSUnaryBackward(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp b/src/common/transformations/include/transformations/transpose_sinking/ts_utils.hpp similarity index 95% rename from src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp rename to src/common/transformations/include/transformations/transpose_sinking/ts_utils.hpp index e67f6ec97f5f97..7c14886f893cd2 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp +++ b/src/common/transformations/include/transformations/transpose_sinking/ts_utils.hpp @@ -4,14 +4,17 @@ #pragma once -#include #include #include "openvino/op/util/op_types.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/util/common_util.hpp" +#include "transformations/utils/utils.hpp" +namespace ov { +namespace pass { namespace transpose_sinking { +namespace utils { struct TransposeInputsInfo { std::shared_ptr transpose; @@ -106,4 +109,7 @@ ov::Output ChangeValuesOrder(const ov::Output& input, const ov::AxisVector& transpose_axis_order, const std::shared_ptr& axis); +} // namespace utils } // namespace transpose_sinking +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp index 1e06d99b4dc101..6fb9a12cc6e095 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp @@ -14,7 +14,6 @@ #include #include "itt.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/utils/utils.hpp" using namespace ov; @@ -32,7 +31,7 @@ std::shared_ptr get_reduced_order_constant(const std::shared_p order.erase(order.begin() + i); } else { // if 2nd input for Squeeze op is not provided, we should remove all 1 dims - // this case will be supported in new TransposeSinkingGeneral transformation. + // this case will be supported in new TSGeneral transformation. return nullptr; } @@ -318,8 +317,6 @@ ov::pass::TransposeFuse::TransposeFuse() { new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name()); ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose); ngraph::replace_node(m.get_match_root(), new_transpose); - - transpose_sinking::UpdateForwardSinkingAbility(new_transpose); } return true; diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp deleted file mode 100644 index 619e7e4326e163..00000000000000 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (C) 2022-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "transformations/common_optimizations/transpose_sinking_general.hpp" - -#include -#include -#include -#include - -#include "itt.hpp" -#include "transformations/common_optimizations/transpose_sinking_binary.hpp" -#include "transformations/common_optimizations/transpose_sinking_concat.hpp" -#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp" -#include "transformations/common_optimizations/transpose_sinking_fuse.hpp" -#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp" -#include "transformations/common_optimizations/transpose_sinking_reduction.hpp" -#include "transformations/common_optimizations/transpose_sinking_split.hpp" -#include "transformations/common_optimizations/transpose_sinking_unary.hpp" -#include "transformations/utils/utils.hpp" - -ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() { - MATCHER_SCOPE(TransposeSinkingGeneralForward); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); -} - -ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() { - MATCHER_SCOPE(TransposeSinkingGeneralBackward); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); - add_matcher(); -} - -bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr& f) { - RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral); - { - ov::pass::Manager manager(get_pass_config()); - manager.register_pass(); - manager.register_pass(); - manager.run_passes(f); - } - - { - ov::pass::Manager manager(get_pass_config()); - manager.register_pass(); - manager.register_pass(); - manager.run_passes(f); - } - - return false; -} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp similarity index 84% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp index 06b0329188e310..ef46b4143c2305 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp @@ -2,24 +2,24 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_binary.hpp" - -#include -#include +#include "transformations/transpose_sinking/ts_binary.hpp" #include "itt.hpp" #include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" -using namespace ov::pass::pattern; using namespace ov; using namespace ov::opset10; -using namespace transpose_sinking; +using namespace ov::pass::pattern; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; -ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() { - MATCHER_SCOPE(TransposeSinkingBinaryForward); +TSBinaryForward::TSBinaryForward() { + MATCHER_SCOPE(TSBinaryForward); auto main_node_label = wrap_typevalidate_and_infer_types(); for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); - transpose_sinking::UpdateForwardSinkingAbility(new_node); + UpdateForwardSinkingAbility(new_node); } return true; }; @@ -51,8 +51,8 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() { register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() { - MATCHER_SCOPE(TransposeSinkingBinaryBackward); +TSBinaryBackward::TSBinaryBackward() { + MATCHER_SCOPE(TSBinaryBackward); auto main_node_label = wrap_type(IfNodeHasTransposeInputs); @@ -47,7 +48,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() { main_node->validate_and_infer_types(); for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); - transpose_sinking::UpdateForwardSinkingAbility(new_node); + UpdateForwardSinkingAbility(new_node); } return true; @@ -57,8 +58,8 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() { register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() { - MATCHER_SCOPE(TransposeSinkingConcatBackward); +TSConcatBackward::TSConcatBackward() { + MATCHER_SCOPE(TSConcatBackward); auto main_node_label = wrap_type([](const Output& output) -> bool { return has_static_rank()(output) && HasSameOutputTransposeNodes(output); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_data_movement.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp similarity index 89% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_data_movement.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp index 4a13d4179c2333..482841d2c6532f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_data_movement.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp @@ -2,25 +2,24 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp" - -#include +#include "transformations/transpose_sinking/ts_data_movement.hpp" #include "itt.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/util/common_util.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" using namespace ov; using namespace ov::opset10; using namespace ov::pass::pattern; -using namespace transpose_sinking; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; -ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForward() { - MATCHER_SCOPE(TransposeSinkingDataMovementForward); +TSDataMovementForward::TSDataMovementForward() { + MATCHER_SCOPE(TSDataMovementForward); auto const_label = wrap_type(); auto transpose_label = wrap_type({any_input(), const_label}); auto main_node_label = @@ -65,7 +64,7 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); - transpose_sinking::UpdateForwardSinkingAbility(new_node); + UpdateForwardSinkingAbility(new_node); } return true; }; @@ -74,8 +73,8 @@ ov::pass::TransposeSinkingDataMovementForward::TransposeSinkingDataMovementForwa register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingDataMovementBackward::TransposeSinkingDataMovementBackward() { - MATCHER_SCOPE(TransposeSinkingDataMovementBackward); +TSDataMovementBackward::TSDataMovementBackward() { + MATCHER_SCOPE(TSDataMovementBackward); auto main_node_label = wrap_type([](const Output& output) -> bool { return has_static_rank()(output) && HasSameOutputTransposeNodes(output); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_fuse.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_fuse.cpp similarity index 85% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_fuse.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_fuse.cpp index 3e0309b09e108e..89805f82211ed5 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_fuse.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_fuse.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_fuse.hpp" +#include "transformations/transpose_sinking/ts_fuse.hpp" #include #include @@ -11,16 +11,18 @@ #include "openvino/core/validation_util.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" #include "transformations/utils/utils.hpp" using namespace ov; using namespace opset10; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; -ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() { +TSFuse::TSFuse() { MATCHER_SCOPE(TransposeFuse); auto transpose_1_label = pattern::wrap_type({pattern::any_input(), pattern::wrap_type()}, - transpose_sinking::HasSameOutputTransposeNodes); + HasSameOutputTransposeNodes); auto transpose_2_label = pattern::wrap_type({transpose_1_label, pattern::wrap_type()}); ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { const auto& pattern_to_output = m.get_pattern_map(); @@ -62,11 +64,11 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() { auto new_transpose = register_new_node(input, new_order); new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name()); - transpose_sinking::RemoveSingleOutputConsumers(transpose1); + RemoveSingleOutputConsumers(transpose1); copy_runtime_info(transpose1, new_transpose); ov::replace_node(transpose1, new_transpose); - transpose_sinking::UpdateForwardSinkingAbility(new_transpose); + UpdateForwardSinkingAbility(new_transpose); } return true; }; diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp new file mode 100644 index 00000000000000..507b874d78b3a9 --- /dev/null +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp @@ -0,0 +1,65 @@ +// Copyright (C) 2022-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/transpose_sinking/ts_general.hpp" + +#include "itt.hpp" +#include "openvino/pass/constant_folding.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/manager.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/transpose_sinking/ts_binary.hpp" +#include "transformations/transpose_sinking/ts_concat.hpp" +#include "transformations/transpose_sinking/ts_data_movement.hpp" +#include "transformations/transpose_sinking/ts_fuse.hpp" +#include "transformations/transpose_sinking/ts_interpolate.hpp" +#include "transformations/transpose_sinking/ts_reduction.hpp" +#include "transformations/transpose_sinking/ts_split.hpp" +#include "transformations/transpose_sinking/ts_unary.hpp" +#include "transformations/utils/utils.hpp" + +using namespace ov::pass::transpose_sinking; + +TSGeneralForward::TSGeneralForward() { + MATCHER_SCOPE(TSGeneralForward); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); +} + +TSGeneralBackward::TSGeneralBackward() { + MATCHER_SCOPE(TSGeneralBackward); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); +} + +bool TSGeneral::run_on_model(const std::shared_ptr& f) { + RUN_ON_FUNCTION_SCOPE(TSGeneral); + { + Manager manager(get_pass_config()); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + } + + { + Manager manager(get_pass_config()); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + } + + return false; +} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_interpolate.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp similarity index 91% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_interpolate.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp index 003a1ee6bb1aa1..0a9c2b7458f019 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_interpolate.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_interpolate.cpp @@ -2,25 +2,25 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp" - -#include +#include "transformations/transpose_sinking/ts_interpolate.hpp" #include "itt.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/opsets/opset10.hpp" +#include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/util/common_util.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" using namespace ov; using namespace ov::opset10; using namespace ov::pass::pattern; -using namespace transpose_sinking; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; -ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward() { - MATCHER_SCOPE(TransposeSinkingInterpolateForward); +TSInterpolateForward::TSInterpolateForward() { + MATCHER_SCOPE(TSInterpolateForward); auto const_label = wrap_type(); auto transpose_label = wrap_type({any_input(), const_label}); auto main_node_label = wrap_type({transpose_label, any_input(), any_input(), any_input()}); @@ -74,7 +74,7 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); - transpose_sinking::UpdateForwardSinkingAbility(new_node); + UpdateForwardSinkingAbility(new_node); } return true; }; @@ -83,8 +83,8 @@ ov::pass::TransposeSinkingInterpolateForward::TransposeSinkingInterpolateForward register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingInterpolateBackward::TransposeSinkingInterpolateBackward() { - MATCHER_SCOPE(TransposeSinkingInterpolateBackward); +TSInterpolateBackward::TSInterpolateBackward() { + MATCHER_SCOPE(TSInterpolateBackward); auto main_node_label = wrap_type([](const Output& output) -> bool { return has_static_rank()(output) && HasSameOutputTransposeNodes(output); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_reduction.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp similarity index 94% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_reduction.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp index 1fc0a8e3efc53b..ab4e76994ece7f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_reduction.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_reduction.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_reduction.hpp" +#include "transformations/transpose_sinking/ts_reduction.hpp" #include #include @@ -13,11 +13,13 @@ #include "openvino/op/util/logical_reduction_keep_dims.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" #include "transformations/utils/utils.hpp" using namespace ov; using namespace opset10; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; namespace { std::vector get_updated_order_forward(const std::vector& axes_values, @@ -80,8 +82,8 @@ bool get_keep_dims(const std::shared_ptr& reduction) { } } // namespace -ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() { - MATCHER_SCOPE(TransposeSinkingReductionForward); +TSReductionForward::TSReductionForward() { + MATCHER_SCOPE(TSReductionForward); auto transpose_label = pattern::wrap_type({pattern::any_input(), pattern::wrap_type()}, pattern::consumers_count(1)); @@ -150,7 +152,7 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() { replace_node(reduction, new_transpose); new_reduction->set_friendly_name(transpose->get_friendly_name()); new_transpose->set_friendly_name(reduction->get_friendly_name()); - transpose_sinking::UpdateForwardSinkingAbility(new_transpose); + UpdateForwardSinkingAbility(new_transpose); register_new_node(new_transpose); copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction}); return true; @@ -160,13 +162,13 @@ ov::pass::TransposeSinkingReductionForward::TransposeSinkingReductionForward() { register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() { - MATCHER_SCOPE(TransposeSinkingReductionBackward); +TSReductionBackward::TSReductionBackward() { + MATCHER_SCOPE(TSReductionBackward); auto reduce_or_squeeze_label = pattern:: wrap_type( {pattern::any_input(), pattern::wrap_type()}, - transpose_sinking::HasSameOutputTransposeNodes); + HasSameOutputTransposeNodes); auto transpose_label = pattern::wrap_type({reduce_or_squeeze_label, pattern::wrap_type()}); ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); @@ -225,7 +227,7 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() } if (!unsqueeze) { - auto reversed_order_values = transpose_sinking::ReverseTransposeOrder(transpose_order_values); + auto reversed_order_values = ReverseTransposeOrder(transpose_order_values); for (const auto& axis : non_negative_axes) { new_values.push_back(reversed_order_values[axis]); } @@ -246,7 +248,7 @@ ov::pass::TransposeSinkingReductionBackward::TransposeSinkingReductionBackward() } replace_node(transpose, new_reduction); copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction}); - transpose_sinking::UpdateForwardSinkingAbility(new_transpose); + UpdateForwardSinkingAbility(new_transpose); new_reduction->set_friendly_name(transpose->get_friendly_name()); register_new_node(new_transpose); return true; diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp similarity index 92% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp index 386a3d909a5bc3..3aeb74436e7783 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_split.cpp @@ -2,24 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_split.hpp" - -#include -#include -#include +#include "transformations/transpose_sinking/ts_split.hpp" #include "itt.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/pass/pattern/op/label.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" +#include "transformations/utils/utils.hpp" using namespace ov::pass::pattern; using namespace ov; using namespace ov::opset10; -using namespace transpose_sinking; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; namespace { @@ -107,7 +105,7 @@ bool GetSplitAxis(const std::shared_ptr& split_axis, const ov::Rank& r * Consider case Split (1) -> Split (2) -> Transpose * If specify Split as main searched node after first transformation work we will have * Split (1) -> Transpose -> Split(2) - * Matcher pass will not call TransposeSinkingSplitBackward since + * Matcher pass will not call TSSplitBackward since * - matcher pattern has no Transpose label * - Split (1) has already been proceeded * Adding Split(2) into the working queue as register_new_node(split) @@ -121,8 +119,8 @@ bool GetSplitAxis(const std::shared_ptr& split_axis, const ov::Rank& r * - add reversed Transpose operations on all outputs except sinking Transpose * nothing to do with new added output Transposes */ -ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() { - MATCHER_SCOPE(TransposeSinkingSplitBackward); +TSSplitBackward::TSSplitBackward() { + MATCHER_SCOPE(TSSplitBackward); auto transpose_const_label = wrap_type(); auto transpose_label = wrap_type({any_input(), transpose_const_label}, IsSplitSinked); @@ -192,8 +190,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() { register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() { - MATCHER_SCOPE(TransposeSinkingSplitForward); +TSSplitForward::TSSplitForward() { + MATCHER_SCOPE(TSSplitForward); auto main_node_label = wrap_type(IfNodeHasTransposeInputs); @@ -225,7 +223,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() { for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); - transpose_sinking::UpdateForwardSinkingAbility(new_node); + UpdateForwardSinkingAbility(new_node); } return true; diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp similarity index 87% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp index de9cb0bf3d5118..ed543ca9d92139 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_unary.cpp @@ -2,22 +2,23 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_unary.hpp" +#include "transformations/transpose_sinking/ts_unary.hpp" -#include #include #include "itt.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" +#include "transformations/utils/utils.hpp" using namespace ov; using namespace ov::opset10; using namespace ov::pass::pattern; using namespace ov::op::util; -using namespace transpose_sinking; +using namespace ov::pass::transpose_sinking; +using namespace ov::pass::transpose_sinking::utils; namespace { @@ -51,8 +52,8 @@ NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) { } // namespace -ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() { - MATCHER_SCOPE(TransposeSinkingUnaryForward); +TSUnaryForward::TSUnaryForward() { + MATCHER_SCOPE(TSUnaryForward); auto transpose_label = wrap_type({any_input(), any_input()}); auto unary_label = @@ -73,7 +74,7 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() { return true; }; - auto m = std::make_shared(unary_label, "ov::pass::TransposeSinkingUnaryForward"); + auto m = std::make_shared(unary_label, "ov::pass::TSUnaryForward"); register_matcher(m, matcher_pass_callback); } @@ -83,8 +84,8 @@ bool IfSinkingEnabled(const Output& output) { } } // namespace -ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { - MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers); +TSUnaryBackward::TSUnaryBackward() { + MATCHER_SCOPE(TSUnaryBackwardMultiConsumers); auto unary_restrictions = [](const Output& output) -> bool { return HasSameOutputTransposeNodes(output); @@ -115,6 +116,6 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { return true; }; - auto m = std::make_shared(transpose_label, "ov::pass::TransposeSinkingUnaryBackward"); + auto m = std::make_shared(transpose_label, "ov::pass::TSUnaryBackward"); register_matcher(m, matcher_pass_callback); } diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp similarity index 98% rename from src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp rename to src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp index dcdc5175ad85e3..9925087be7d9ce 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp @@ -2,17 +2,19 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_utils.hpp" - -#include +#include "transformations/transpose_sinking/ts_utils.hpp" #include "itt.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/util/common_util.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" +#include "transformations/utils/utils.hpp" +namespace ov { +namespace pass { namespace transpose_sinking { +namespace utils { using namespace ov; using namespace ov::opset10; @@ -377,4 +379,7 @@ void RemoveSingleOutputConsumers(const NodePtr& node) { } } +} // namespace utils } // namespace transpose_sinking +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_binary_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_binary_test.cpp similarity index 95% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_binary_test.cpp rename to src/common/transformations/tests/transpose_sinking/ts_binary_test.cpp index 05e2efd48505ce..6340b3b41e9e5b 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_binary_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_binary_test.cpp @@ -1,40 +1,24 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include "transformations/transpose_sinking/ts_binary.hpp" + #include -#include -#include -#include -#include -#include #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" -#include "transpose_sinking_test_utils.hpp" +#include "openvino/frontend/manager.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" +#include "ts_test_utils.hpp" using namespace ov; using namespace ov::opset10; -using namespace transpose_sinking::testing; - -namespace transpose_sinking_binary_eltwise { +using namespace ov::pass::transpose_sinking; +using namespace transpose_sinking::testing::utils; namespace { -namespace { -std::string to_string(const Shape& shape) { - std::ostringstream result; - result << "{"; - for (size_t idx = 0; idx < shape.size(); ++idx) { - if (idx) - result << ","; - result << shape[idx]; - } - result << "}"; - return result.str(); -} -} // namespace - -// ---------------------------------------------------------------------------- template class BinaryFactory : public IFactory { @@ -87,6 +71,10 @@ std::vector binary_transpose_input_indexes = {0, 1}; } // namespace +namespace transpose_sinking { +namespace testing { +namespace binary { + namespace single_consumer { namespace forward { namespace one_input_transpose { @@ -207,7 +195,7 @@ class TransposeSinkingBinaryTwoTransposeInputsTestFixture : public ::testing::WithParamInterface, public TransformationTestsF { public: - static std::string get_test_name(const testing::TestParamInfo& obj) { + static std::string get_test_name(const ::testing::TestParamInfo& obj) { FactoryPtr binary_factory; PassFactoryPtr pass_factory; size_t num_binary_ops; @@ -247,7 +235,7 @@ TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) { INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite, TransposeSinkingBinaryTwoTransposeInputsTestFixture, ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)), ::testing::ValuesIn(binary_operations_numbers), ::testing::Values(CreateFunction), ::testing::Values(CreateReferenceFunction), @@ -327,7 +315,7 @@ using TestBinaryParams = std::tuple, public TransformationTestsF { public: - static std::string get_test_name(const testing::TestParamInfo& obj) { + static std::string get_test_name(const ::testing::TestParamInfo& obj) { FactoryPtr binary_factory; PassFactoryPtr pass_factory; size_t num_binary_ops; @@ -377,10 +365,10 @@ TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) { } INSTANTIATE_TEST_SUITE_P( - TransposeSinkingBinaryForwardTestSuite, + TSBinaryForwardTestSuite, TransposeSinkingBinaryTestFixture, ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)), ::testing::ValuesIn(binary_operations_numbers), ::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction), ::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction), @@ -389,10 +377,10 @@ INSTANTIATE_TEST_SUITE_P( TransposeSinkingBinaryTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( - TransposeSinkingBinaryBackwardTestSuite, + TSBinaryBackwardTestSuite, TransposeSinkingBinaryTestFixture, ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)), ::testing::ValuesIn(binary_operations_numbers), ::testing::Values(single_consumer::backward::one_input_transpose::CreateFunction), ::testing::Values(single_consumer::backward::one_input_transpose::CreateReferenceFunction), @@ -421,7 +409,7 @@ class TransposeSinkingBinaryIncompatShapesTestFixture : public ::testing::WithParamInterface, public TransformationTestsF { public: - static std::string get_test_name(const testing::TestParamInfo& obj) { + static std::string get_test_name(const ::testing::TestParamInfo& obj) { FactoryPtr binary_factory; PassFactoryPtr pass_factory; Shape input_shape; @@ -600,7 +588,7 @@ INSTANTIATE_TEST_SUITE_P( TransposeSinkingBinaryIncompatShapesBackwardTestSuite, TransposeSinkingBinaryIncompatShapesTestFixture, ::testing::Combine(::testing::ValuesIn(binary_elementwise_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)), ::testing::Values(Shape{1, 96, 55, 55}), ::testing::ValuesIn(binary::single_consumer::backward::incompat_shapes::constant_shapes), ::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction), @@ -613,7 +601,7 @@ INSTANTIATE_TEST_SUITE_P( TransposeSinkingBinaryIncompatShapesForwardTestSuite, TransposeSinkingBinaryIncompatShapesTestFixture, ::testing::Combine(::testing::ValuesIn(binary_elementwise_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)), ::testing::Values(Shape{1, 96, 55, 55}), ::testing::ValuesIn(binary::single_consumer::forward::incompat_shapes::constant_shapes), ::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction), @@ -626,7 +614,7 @@ INSTANTIATE_TEST_SUITE_P( TransposeSinkingPReluIncompatShapesBackwardTestSuite, TransposeSinkingBinaryIncompatShapesTestFixture, ::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)), ::testing::Values(Shape{1, 3, 16, 16}), ::testing::ValuesIn(std::vector{Shape{3}}), ::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction), @@ -639,7 +627,7 @@ INSTANTIATE_TEST_SUITE_P( TransposeSinkingPReluIncompatShapesForwardTestSuite, TransposeSinkingBinaryIncompatShapesTestFixture, ::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)), ::testing::Values(Shape{1, 3, 16, 16}), ::testing::ValuesIn(std::vector{Shape{3}}), ::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction), @@ -1090,7 +1078,7 @@ using TestBinaryParams = std::tuple, public TransformationTestsF { public: - static std::string get_test_name(const testing::TestParamInfo& obj) { + static std::string get_test_name(const ::testing::TestParamInfo& obj) { FactoryPtr binary_factory; PassFactoryPtr pass_factory; CreateGraphFunctionDesc function_desc; @@ -1139,19 +1127,19 @@ std::vector backward_subtests = { #undef SUBTEST -INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardMultiConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSBinaryForwardMultiConsumersTestSuite, TransposeBinaryMultiSinkingFixture, ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryForward)), ::testing::ValuesIn(forward_subtests), ::testing::Values(element::f32), ::testing::ValuesIn(binary_transpose_input_indexes)), TransposeBinaryMultiSinkingFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardMultiConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSBinaryBackwardMultiConsumersTestSuite, TransposeBinaryMultiSinkingFixture, ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)), ::testing::ValuesIn(backward_subtests), ::testing::Values(element::f32), ::testing::ValuesIn(binary_transpose_input_indexes)), @@ -1177,7 +1165,7 @@ using TestBinaryParams = std::tuple, public TransformationTestsF { public: - static std::string get_test_name(const testing::TestParamInfo& obj) { + static std::string get_test_name(const ::testing::TestParamInfo& obj) { FactoryPtr binary_factory; PassFactoryPtr pass_factory; CreateGraphFunctionDesc function_desc; @@ -1219,10 +1207,10 @@ std::vector backward_subtests_binary_consumers = { }; #undef SUBTEST -INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardBinaryMultiConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSBinaryBackwardBinaryMultiConsumersTestSuite, TransposeBinaryMultiSinkingBinaryMultiConsumersFixture, ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::Values(CREATE_PASS_FACTORY(TSBinaryBackward)), ::testing::ValuesIn(backward_subtests_binary_consumers), ::testing::Values(element::f32), ::testing::ValuesIn(binary_transpose_input_indexes)), @@ -1232,4 +1220,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardBinaryMultiConsumersTestS } // namespace mult_consumers -} // namespace transpose_sinking_binary_eltwise +} // namespace binary +} // namespace testing +} // namespace transpose_sinking diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_common_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp similarity index 86% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_common_test.cpp rename to src/common/transformations/tests/transpose_sinking/ts_common_test.cpp index 03fdd5463fdc1c..10de11c2893071 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_common_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp @@ -2,26 +2,27 @@ // SPDX-License-Identifier: Apache-2.0 // -#include -#include - #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" -#include "transformations/common_optimizations/transpose_sinking_binary.hpp" -#include "transformations/common_optimizations/transpose_sinking_concat.hpp" -#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp" -#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp" -#include "transformations/common_optimizations/transpose_sinking_reduction.hpp" -#include "transformations/common_optimizations/transpose_sinking_split.hpp" -#include "transformations/common_optimizations/transpose_sinking_unary.hpp" -#include "transpose_sinking_test_utils.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/transpose_sinking/ts_binary.hpp" +#include "transformations/transpose_sinking/ts_concat.hpp" +#include "transformations/transpose_sinking/ts_data_movement.hpp" +#include "transformations/transpose_sinking/ts_interpolate.hpp" +#include "transformations/transpose_sinking/ts_reduction.hpp" +#include "transformations/transpose_sinking/ts_split.hpp" +#include "transformations/transpose_sinking/ts_unary.hpp" +#include "ts_test_utils.hpp" using namespace std; using namespace ov; using namespace ov::opset10; -using namespace transpose_sinking::testing; +using namespace ov::pass::transpose_sinking; +using namespace transpose_sinking::testing::utils; namespace transpose_sinking { +namespace testing { namespace common { template @@ -362,7 +363,7 @@ auto test_forward_unary = [](const vector& factories, const vector& factories, const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -470,7 +471,7 @@ auto test_forward_split = []() { test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}}; test_case.model_ref.main_op = {CREATE_SPLIT_FACTORY(Split)}; test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0, 1, 2}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -481,7 +482,7 @@ auto test_forward_pad = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward); + test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward); test_case.num_main_ops = {1, 2}; test_case.inputs_to_main = { parameter(element::f32, {1, 3, 55, 55}), @@ -492,13 +493,13 @@ auto test_forward_pad = []() { // Test model description: test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model.main_op = {CREATE_PAD_FACTORY(Pad)}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}}; test_case.model_ref.main_op = {CREATE_PAD_FACTORY(Pad)}; test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -509,7 +510,7 @@ auto test_forward_batch_to_space = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward); + test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward); test_case.num_main_ops = {1, 2}; test_case.inputs_to_main = { parameter(element::f32, {128, 55, 3, 128}), @@ -521,13 +522,13 @@ auto test_forward_batch_to_space = []() { // Test model description: test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}}; test_case.model_ref.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)}; test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -540,7 +541,7 @@ auto test_forward_space_to_batch = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward); + test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {64, 9, 8, 1}), @@ -552,13 +553,13 @@ auto test_forward_space_to_batch = []() { // Test model description: test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}}; test_case.model_ref.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)}; test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -571,7 +572,7 @@ auto test_forward_reduction = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward); + test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {32, 4, 2, 1}), @@ -581,7 +582,7 @@ auto test_forward_reduction = []() { // Test model description: test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model.main_op = reduction_factories; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto new_constant = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -594,7 +595,7 @@ auto test_forward_reduction = []() { test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}}; test_case.model_ref.main_op = reduction_factories; test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -605,7 +606,7 @@ auto test_forward_interpolate = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateForward); + test_case.transformation = CREATE_PASS_FACTORY(TSInterpolateForward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {1, 2, 48, 80}), @@ -617,7 +618,7 @@ auto test_forward_interpolate = []() { // Test model description: test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, false)}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto set_specific_gather_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -637,7 +638,7 @@ auto test_forward_interpolate = []() { test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{3}}}; test_case.model_ref.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, true)}; test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -650,7 +651,7 @@ auto test_forward_squeeze = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward); + test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {32, 1, 2, 1}), @@ -660,7 +661,7 @@ auto test_forward_squeeze = []() { // Test model description: test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto new_constant = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -673,7 +674,7 @@ auto test_forward_squeeze = []() { test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}}; test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)}; test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -684,7 +685,7 @@ auto test_forward_unsqueeze = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward); + test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {32, 3, 2, 1}), @@ -694,7 +695,7 @@ auto test_forward_unsqueeze = []() { // Test model description: test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto new_constant = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -714,7 +715,7 @@ auto test_forward_unsqueeze = []() { return new_out_vec; }; test_case.model_ref.preprocess_outputs_of_main = {{new_transpose}, {{0}}}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -727,7 +728,7 @@ auto test_backward_unary = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward); test_case.num_main_ops = {1, 10}; test_case.inputs_to_main = { parameter(element::f32, {1, 96, 55, 55}), @@ -736,12 +737,12 @@ auto test_backward_unary = []() { // Test model description: test_case.model.main_op = unary_factories; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; test_case.model_ref.main_op = unary_factories; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -752,7 +753,7 @@ auto test_backward_binary = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSBinaryBackward); test_case.num_main_ops = {1, 10}; test_case.inputs_to_main = { parameter(element::f32, {1, 96, 55, 55}), @@ -762,12 +763,12 @@ auto test_backward_binary = []() { // Test model description: test_case.model.main_op = binary_factories; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0, 1}}}; test_case.model_ref.main_op = binary_factories; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -778,7 +779,7 @@ auto test_backward_concat = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSConcatBackward); test_case.num_main_ops = {1, 3}; test_case.inputs_to_main = { parameter(element::f32, {1, 96, 55, 55}), @@ -789,12 +790,12 @@ auto test_backward_concat = []() { // Test model description: test_case.model.main_op = {CREATE_CONCAT_FACTORY(Concat)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{0, 1, 2}}}; test_case.model_ref.main_op = {CREATE_CONCAT_REF_FACTORY(Concat)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -805,7 +806,7 @@ auto test_backward_split = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSSplitBackward); test_case.num_main_ops = {1, 2}; test_case.inputs_to_main = { parameter(element::f32, {1, 9, 55, 55}), @@ -815,7 +816,7 @@ auto test_backward_split = []() { // Test model description: test_case.model.main_op = {CREATE_SPLIT_FACTORY(Split)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0, 1, 2}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto new_constant = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -827,7 +828,7 @@ auto test_backward_split = []() { }; test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}}; test_case.model_ref.main_op = {CREATE_SPLIT_FACTORY(Split)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -837,7 +838,7 @@ auto test_backward_pad = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward); test_case.num_main_ops = {1, 2}; test_case.inputs_to_main = { parameter(element::f32, {1, 3, 55, 55}), @@ -848,12 +849,12 @@ auto test_backward_pad = []() { // Test model description: test_case.model.main_op = {CREATE_PAD_FACTORY(Pad)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2}}}; test_case.model_ref.main_op = {CREATE_PAD_FACTORY(Pad)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -864,7 +865,7 @@ auto test_backward_batch_to_space = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {128, 55, 3, 128}), @@ -876,12 +877,12 @@ auto test_backward_batch_to_space = []() { // Reference model description: test_case.model.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Test model description: test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2, 3}}}; test_case.model_ref.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -894,7 +895,7 @@ auto test_backward_space_to_batch = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementBackward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {1, 8, 9, 64}), @@ -906,12 +907,12 @@ auto test_backward_space_to_batch = []() { // Test model description: test_case.model.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_gather_for}, {{0}, {1, 2, 3}}}; test_case.model_ref.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -923,7 +924,7 @@ auto test_backward_reduction = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {32, 4, 2, 1}), @@ -933,7 +934,7 @@ auto test_backward_reduction = []() { // Test model description: test_case.model.main_op = reduction_factories; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto new_constant = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -945,7 +946,7 @@ auto test_backward_reduction = []() { }; test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}}; test_case.model_ref.main_op = reduction_factories; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -958,7 +959,7 @@ auto test_backward_interpolate = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSInterpolateBackward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {1, 2, 48, 80}), @@ -970,7 +971,7 @@ auto test_backward_interpolate = []() { // Test model description: test_case.model.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, true)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto set_specific_gather_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -989,7 +990,7 @@ auto test_backward_interpolate = []() { }; test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {3}}}; test_case.model_ref.main_op = {CREATE_INTERPOLATE_FACTORY(Interpolate, false)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -1002,7 +1003,7 @@ auto test_backward_squeeze = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {32, 1, 2, 1}), @@ -1012,7 +1013,7 @@ auto test_backward_squeeze = []() { // Test model description: test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto new_transpose = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -1024,7 +1025,7 @@ auto test_backward_squeeze = []() { }; test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}}; test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -1035,7 +1036,7 @@ auto test_backward_unsqueeze = []() { TestCase test_case; // Initialize common attributes - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward); test_case.num_main_ops = {1}; test_case.inputs_to_main = { parameter(element::f32, {32, 3, 2, 1}), @@ -1045,7 +1046,7 @@ auto test_backward_unsqueeze = []() { // Test model description: test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)}; test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; - test_case.model.model_template = transpose_sinking::common::create_model; + test_case.model.model_template = create_model; // Reference model description: auto new_constant = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { @@ -1057,7 +1058,7 @@ auto test_backward_unsqueeze = []() { }; test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}}; test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)}; - test_case.model_ref.model_template = transpose_sinking::common::create_model; + test_case.model_ref.model_template = create_model; return wrapper(test_case); }; @@ -1066,4 +1067,5 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_unsqueeze()); } // namespace common +} // namespace testing } // namespace transpose_sinking \ No newline at end of file diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_concat_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_concat_test.cpp similarity index 97% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_concat_test.cpp rename to src/common/transformations/tests/transpose_sinking/ts_concat_test.cpp index 516d8f151f4549..d9c8ccad341e62 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_concat_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_concat_test.cpp @@ -1,27 +1,28 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include "transformations/transpose_sinking/ts_concat.hpp" + #include -#include -#include -#include -#include -#include -#include #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" -#include "transpose_sinking_test_utils.hpp" +#include "openvino/frontend/manager.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/init_node_info.hpp" +#include "transformations/transpose_sinking/ts_utils.hpp" +#include "ts_test_utils.hpp" using namespace ov; using namespace ov::opset10; -using namespace transpose_sinking::testing; +using namespace ov::pass::transpose_sinking; +using namespace transpose_sinking::testing::utils; namespace { std::vector concat_operations_numbers = {1, 10}; - std::vector concat_transpose_input_indexes = {0, 2}; NodePtr CreateConcatChain(NodePtr input_node, @@ -331,9 +332,9 @@ TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) { } INSTANTIATE_TEST_SUITE_P( - TransposeSinkingConcatForwardTestSuite, + TSConcatForwardTestSuite, TransposeSinkingConcatTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction), ::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction), @@ -342,9 +343,9 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(5)), TransposeSinkingConcatTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite, +INSTANTIATE_TEST_SUITE_P(TSConcatBackwardTestSuite, TransposeSinkingConcatTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::Values(single_consumer::backward::CreateFunction), ::testing::Values(single_consumer::backward::CreateReferenceFunction), @@ -408,9 +409,9 @@ TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) { } INSTANTIATE_TEST_SUITE_P( - TransposeSinkingConcatForwardAllTransposesTestSuite, + TSConcatForwardAllTransposesTestSuite, TransposeSinkingConcatAllTransposesInputTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::Values(single_consumer::forward::double_transpose::CreateFunction), ::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction), @@ -937,9 +938,9 @@ std::vector backward_subtests = { #undef SUBTEST -INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardMultiConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSConcatForwardMultiConsumersTestSuite, TransposeConcatMultiSinkingFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatForward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::ValuesIn(forward_subtests), ::testing::Values(element::f32), @@ -947,9 +948,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardMultiConsumersTestSuite, ::testing::Values(5)), TransposeConcatMultiSinkingFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSConcatBackwardMultiConsumersTestSuite, TransposeConcatMultiSinkingFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::ValuesIn(backward_subtests), ::testing::Values(element::f32), @@ -1029,9 +1030,9 @@ std::vector backward_subtests_no_sinking = { #undef SUBTEST -INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSConcatBackwardMultiConsumersTestSuite, TransposeConcatMultiSinkingConcatConsumersFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSConcatBackward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::ValuesIn(backward_subtests_no_sinking), ::testing::Values(element::f32), diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_general_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp similarity index 92% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_general_test.cpp rename to src/common/transformations/tests/transpose_sinking/ts_general_test.cpp index cc18e93387f3e4..ee077f667340ab 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_general_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_general_test.cpp @@ -1,26 +1,28 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include "transformations/transpose_sinking/ts_general.hpp" + #include -#include -#include -#include -#include -#include #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" +#include "openvino/frontend/manager.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/init_node_info.hpp" using namespace testing; using namespace ov::opset10; +using namespace ov::pass::transpose_sinking; using NodePtr = std::shared_ptr; namespace transpose_sinking { namespace testing { namespace general { -TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward) { +TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesForward) { ov::Shape input_shape = {1, 96, 55, 55}; ov::element::Type input_type = ov::element::f32; size_t num_unary_ops = 10; @@ -53,10 +55,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward function_ref = std::make_shared(in_op, ov::ParameterVector{X}); } - manager.register_pass(); + manager.register_pass(); } -TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackward) { +TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesBackward) { ov::Shape input_shape = {1, 96, 55, 55}; ov::element::Type input_type = ov::element::f32; size_t num_unary_ops = 10; @@ -88,10 +90,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackwar function_ref = std::make_shared(in_op, ov::ParameterVector{X}); } - manager.register_pass(); + manager.register_pass(); } -TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral) { +TEST_F(TransformationTestsF, TSGeneralTestUnariesTransposesGeneral) { ov::Shape input_shape = {1, 96, 55, 55}; ov::element::Type input_type = ov::element::f32; size_t num_unary_ops = 10; @@ -130,10 +132,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral function_ref = std::make_shared(transpose0, ov::ParameterVector{X}); } - manager.register_pass(); + manager.register_pass(); } -TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) { +TEST_F(TransformationTestsF, TSGeneralTestBinaryGeneral) { ov::Shape input_shape = {1, 96, 55, 55}; ov::element::Type input_type = ov::element::f32; size_t num_binary_ops = 10; @@ -171,10 +173,10 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) { function_ref = std::make_shared(transpose0, ov::ParameterVector{X}); } - manager.register_pass(); + manager.register_pass(); } -TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) { +TEST_F(TransformationTestsF, TSGeneralTestConcatGeneral) { ov::Shape input_shape = {1, 96, 55, 55}; ov::element::Type input_type = ov::element::f32; const size_t num_concat_ops = 3; @@ -224,7 +226,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) { function_ref = std::make_shared(transpose0, ov::ParameterVector{X}); } - manager.register_pass(); + manager.register_pass(); } // ---------------------------------------------------------------------------------------------------------------------- @@ -364,7 +366,7 @@ NodePtr MakeAllNodesSubgraph(NodePtr parent, size_t split_axis, size_t concat_ax return in_op; } -TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) { +TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) { using namespace transpose_sinking::testing::general; ov::Shape input_shape = {1, 96, 40, 55}; ov::element::Type input_type = ov::element::f32; @@ -407,7 +409,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) { function_ref = std::make_shared(transpose1, ov::ParameterVector{X}); } - manager.register_pass(); + manager.register_pass(); } } // namespace general diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_split_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_split_test.cpp similarity index 95% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_split_test.cpp rename to src/common/transformations/tests/transpose_sinking/ts_split_test.cpp index 944b9aeab5f6ee..67e506566cbfe5 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_split_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_split_test.cpp @@ -1,20 +1,23 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include "transformations/transpose_sinking/ts_split.hpp" + #include -#include -#include -#include -#include -#include #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" -#include "transpose_sinking_test_utils.hpp" +#include "openvino/frontend/manager.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/init_node_info.hpp" +#include "ts_test_utils.hpp" using namespace ov; using namespace ov::opset10; +using namespace ov::pass::transpose_sinking; +using namespace transpose_sinking::testing::utils; namespace transpose_sinking { namespace testing { @@ -527,9 +530,9 @@ TEST_P(TransposeSinkingSplitTestFixture, CompareFunctions) { pass_factory->registerPass(manager); } -INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitForwardSingleConsumerTestSuite, +INSTANTIATE_TEST_SUITE_P(TSSplitForwardSingleConsumerTestSuite, TransposeSinkingSplitTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)), ::testing::ValuesIn(split_operations_numbers), ::testing::ValuesIn(split_outputs_numbers), ::testing::Values(forward::single_consumer::CreateFunction), @@ -538,9 +541,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitForwardSingleConsumerTestSuite, TransposeSinkingSplitTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( - TransposeSinkingSplitForwardMultInputNodeConsumersTestSuite, + TSSplitForwardMultInputNodeConsumersTestSuite, TransposeSinkingSplitTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)), ::testing::ValuesIn(split_operations_numbers), ::testing::ValuesIn(split_outputs_numbers), ::testing::Values(forward::mult_consumers::input_node_consumers::CreateFunction), @@ -549,9 +552,9 @@ INSTANTIATE_TEST_SUITE_P( TransposeSinkingSplitTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( - TransposeSinkingSplitForwardMultInputTransposeConsumersTestSuite, + TSSplitForwardMultInputTransposeConsumersTestSuite, TransposeSinkingSplitTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)), ::testing::ValuesIn(split_operations_numbers), ::testing::ValuesIn(split_outputs_numbers), ::testing::Values(forward::mult_consumers::input_transpose_consumers::CreateFunction), @@ -560,9 +563,9 @@ INSTANTIATE_TEST_SUITE_P( TransposeSinkingSplitTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( - TransposeSinkingSplitForwardMultOutputConsumersTestSuite, + TSSplitForwardMultOutputConsumersTestSuite, TransposeSinkingSplitTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitForward)), ::testing::ValuesIn(split_operations_numbers), ::testing::ValuesIn(split_outputs_numbers), ::testing::Values(forward::mult_consumers::output_consumers::CreateFunction), @@ -570,9 +573,9 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(element::f32)), TransposeSinkingSplitTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardTestSuite, +INSTANTIATE_TEST_SUITE_P(TSSplitBackwardTestSuite, TransposeSinkingSplitTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)), ::testing::ValuesIn(split_tree_depth_nums), ::testing::ValuesIn(split_outputs_numbers), ::testing::Values(backward::single_consumer::CreateFunction), @@ -580,9 +583,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardTestSuite, ::testing::Values(element::f32)), TransposeSinkingSplitTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultOutputConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSSplitBackwardMultOutputConsumersTestSuite, TransposeSinkingSplitTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)), ::testing::ValuesIn(split_tree_depth_nums), ::testing::ValuesIn(split_outputs_numbers), ::testing::Values(backward::mult_output_consumers::CreateFunction), @@ -590,9 +593,9 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultOutputConsumersTestSui ::testing::Values(element::f32)), TransposeSinkingSplitTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultSplitConsumersTestSuite, +INSTANTIATE_TEST_SUITE_P(TSSplitBackwardMultSplitConsumersTestSuite, TransposeSinkingSplitTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)), ::testing::ValuesIn(split_tree_depth_nums), ::testing::ValuesIn(split_outputs_numbers), ::testing::Values(backward::mult_split_consumers::CreateFunction), @@ -764,9 +767,8 @@ using TestSplitBackwardRestrictParams = std::tuple; /* insert transpose function */ -class TransposeSinkingSplitBackwardRestrictTestFixture - : public ::testing::WithParamInterface, - public TransformationTestsF { +class TSSplitBackwardRestrictTestFixture : public ::testing::WithParamInterface, + public TransformationTestsF { public: static std::string get_test_name(const ::testing::TestParamInfo& obj) { PassFactoryPtr pass_factory; @@ -794,7 +796,7 @@ class TransposeSinkingSplitBackwardRestrictTestFixture } }; -TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) { +TEST_P(TSSplitBackwardRestrictTestFixture, CompareFunctions) { PassFactoryPtr pass_factory; size_t split_tree_depth; size_t num_split_outputs; @@ -821,15 +823,15 @@ std::vector insertTransposeFactories = {FUNC(OnlyFirstT #undef FUNC -INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardRestrictTestSuite, - TransposeSinkingSplitBackwardRestrictTestFixture, - ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), +INSTANTIATE_TEST_SUITE_P(TSSplitBackwardRestrictTestSuite, + TSSplitBackwardRestrictTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TSSplitBackward)), ::testing::Values(1), ::testing::Values(5), ::testing::Values(backward::restrictions::CreateFunction), ::testing::Values(element::f32), ::testing::ValuesIn(insertTransposeFactories)), - TransposeSinkingSplitBackwardRestrictTestFixture::get_test_name); + TSSplitBackwardRestrictTestFixture::get_test_name); } // namespace restrictions diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_test_utils.cpp b/src/common/transformations/tests/transpose_sinking/ts_test_utils.cpp similarity index 93% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_test_utils.cpp rename to src/common/transformations/tests/transpose_sinking/ts_test_utils.cpp index 616ded5ac0d0d2..b0fd9d45bcd8e5 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_test_utils.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_test_utils.cpp @@ -2,13 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transpose_sinking_test_utils.hpp" - -#include -#include -#include +#include "ts_test_utils.hpp" #include "gtest/gtest.h" +#include "openvino/frontend/manager.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" using namespace std; using namespace ov; @@ -16,6 +15,7 @@ using namespace ov::opset10; namespace transpose_sinking { namespace testing { +namespace utils { shared_ptr create_main_node(const OutputVector& inputs, size_t num_ops, const FactoryPtr& creator) { OutputVector current_inputs = inputs; @@ -83,5 +83,6 @@ std::shared_ptr parameter(ov::element::Type el_type, const PartialShap return std::make_shared(el_type, ps); } +} // namespace utils } // namespace testing } // namespace transpose_sinking \ No newline at end of file diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_test_utils.hpp b/src/common/transformations/tests/transpose_sinking/ts_test_utils.hpp similarity index 91% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_test_utils.hpp rename to src/common/transformations/tests/transpose_sinking/ts_test_utils.hpp index a6654cce2852f6..5bae4e8a8109d1 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_test_utils.hpp +++ b/src/common/transformations/tests/transpose_sinking/ts_test_utils.hpp @@ -4,15 +4,15 @@ #pragma once -#include -#include -#include - #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" +#include "openvino/frontend/manager.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" namespace transpose_sinking { namespace testing { +namespace utils { using NodePtr = std::shared_ptr; @@ -53,7 +53,7 @@ class PassFactory : public IPassFactory { } }; using PassFactoryPtr = std::shared_ptr; -#define CREATE_PASS_FACTORY(pass_name) std::make_shared>(#pass_name) +#define CREATE_PASS_FACTORY(pass_name) std::make_shared>(#pass_name) std::string to_string(const ov::Shape& shape); ov::OutputVector set_transpose_for(const std::vector& idxs, const ov::OutputVector& out_vec); @@ -67,5 +67,6 @@ std::shared_ptr constant(ov::element::Type el_type, const ov::Shape& s return ov::opset10::Constant::create(el_type, shape, value); } +} // namespace utils } // namespace testing } // namespace transpose_sinking \ No newline at end of file diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_unary_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_unary_test.cpp similarity index 92% rename from src/common/transformations/tests/common_optimizations/transpose_sinking_unary_test.cpp rename to src/common/transformations/tests/transpose_sinking/ts_unary_test.cpp index 70a153b8005092..7a931d65bf66aa 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_unary_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_unary_test.cpp @@ -2,19 +2,19 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "transformations/common_optimizations/transpose_sinking_unary.hpp" - -#include -#include -#include +#include "transformations/transpose_sinking/ts_unary.hpp" #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" -#include "transpose_sinking_test_utils.hpp" +#include "openvino/frontend/manager.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/manager.hpp" +#include "ts_test_utils.hpp" using namespace ov; using namespace ov::opset10; -using namespace transpose_sinking::testing; +using namespace ov::pass::transpose_sinking; +using namespace transpose_sinking::testing::utils; namespace transpose_sinking { namespace testing { @@ -407,7 +407,7 @@ auto wrapper = [](const TestCase& test_case) { auto test_forward = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward); test_case.num_main_ops = {1, 10}; test_case.test_model = CreateFunctionTransposeBefore; test_case.ref_model = CreateFunctionTransposeAfter; @@ -419,7 +419,7 @@ auto test_forward = []() { auto test_backward = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward); test_case.num_main_ops = {1, 10}; test_case.test_model = CreateFunctionTransposeAfter; test_case.ref_model = CreateFunctionTransposeBefore; @@ -431,7 +431,7 @@ auto test_backward = []() { auto test_forward_multiple_consumers_reshape = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward); test_case.num_main_ops = {1, 10}; test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore; test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter; @@ -443,7 +443,7 @@ auto test_forward_multiple_consumers_reshape = []() { auto test_backward_multiple_consumers_reshape = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward); test_case.num_main_ops = {1, 10}; test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter; test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore; @@ -456,7 +456,7 @@ auto test_backward_multiple_consumers_reshape = []() { auto test_forward_multiple_consumers_eltwise = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward); test_case.num_main_ops = {1, 10}; test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore; test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter; @@ -468,7 +468,7 @@ auto test_forward_multiple_consumers_eltwise = []() { auto test_backward_multiple_consumers_eltwise = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward); test_case.num_main_ops = {1, 10}; test_case.test_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter; test_case.ref_model = mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore; @@ -480,7 +480,7 @@ auto test_backward_multiple_consumers_eltwise = []() { auto test_backward_multiple_consumers_first_node = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward); test_case.num_main_ops = {1, 10}; test_case.test_model = mult_consumers_first_node::backward::CreateFunction; test_case.ref_model = mult_consumers_first_node::backward::CreateFunction; @@ -492,7 +492,7 @@ auto test_backward_multiple_consumers_first_node = []() { auto test_backward_multiple_transposes_first_node = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryBackward); test_case.num_main_ops = {1, 10}; test_case.test_model = mult_consumers_first_node::backward_mult_transposes::CreateFunction; test_case.ref_model = mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction; @@ -504,7 +504,7 @@ auto test_backward_multiple_transposes_first_node = []() { auto test_forward_multiple_consumers_first_node = []() { TestCase test_case; test_case.main_node = unary_factories; - test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward); + test_case.transformation = CREATE_PASS_FACTORY(TSUnaryForward); test_case.num_main_ops = {1, 10}; test_case.test_model = mult_consumers_first_node::forward::CreateFunction; test_case.ref_model = mult_consumers_first_node::forward::CreateReferenceFunction; @@ -513,47 +513,47 @@ auto test_forward_multiple_consumers_first_node = []() { return wrapper(test_case); }; -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite, +INSTANTIATE_TEST_SUITE_P(TSUnaryForwardTestSuite, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_forward(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite, +INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardTestSuite, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_backward(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape, +INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultConsumersTestSuiteLastNodeReshape, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_forward_multiple_consumers_reshape(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape, +INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteLastNodeReshape, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_backward_multiple_consumers_reshape(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise, +INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultConsumersTestSuiteLastNodeEltwise, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_forward_multiple_consumers_eltwise(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteEltwise, +INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteEltwise, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_backward_multiple_consumers_eltwise(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode, +INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultConsumersTestSuiteFirstNode, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_backward_multiple_consumers_first_node(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode, +INSTANTIATE_TEST_SUITE_P(TSUnaryBackwardMultTransposeConsumersTestSuiteFirstNode, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_backward_multiple_transposes_first_node(), TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode, +INSTANTIATE_TEST_SUITE_P(TSUnaryForwardMultTransposeConsumersTestSuiteFirstNode, TransposeSinkingUnaryTestFixture, transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(), TransposeSinkingUnaryTestFixture::get_test_name); diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index cde35a0aa64f2b..4893ea0a9236c8 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -21,7 +21,7 @@ #include "so_extension.hpp" #include "tf_framework_node.hpp" #include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp" -#include "transformations/common_optimizations/transpose_sinking_general.hpp" +#include "transformations/transpose_sinking/ts_general.hpp" #include "translate_session.hpp" #include "utils.hpp" @@ -239,7 +239,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.run_passes(model); } - // TODO: TransposeSinkingGeneral can fail on models with Framework nodes (not converted to OV opset) + // TODO: TSGeneral can fail on models with Framework nodes (not converted to OV opset) auto unsupported_ops = get_unconverted_types_from_model(model); if (unsupported_ops.size() > 0) { return; @@ -248,7 +248,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { { // perform transpose sinking and reverse infer if the model contains only OpenVINO operations ov::pass::Manager manager; - manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.run_passes(model); } diff --git a/src/frontends/tensorflow_lite/src/frontend.cpp b/src/frontends/tensorflow_lite/src/frontend.cpp index d946e041274124..b708d711ec9f76 100644 --- a/src/frontends/tensorflow_lite/src/frontend.cpp +++ b/src/frontends/tensorflow_lite/src/frontend.cpp @@ -17,7 +17,7 @@ #include "tflite_transformations/rfft2d_complex_abs.h" #include "tflite_transformations/tflite_quantize_resolver.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" -#include "transformations/common_optimizations/transpose_sinking_general.hpp" +#include "transformations/transpose_sinking/ts_general.hpp" using namespace ov; using namespace ov::frontend::tensorflow_lite; @@ -268,7 +268,7 @@ void FrontEnd::normalize(const std::shared_ptr& function) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); + manager.register_pass(); manager.run_passes(function); }