diff --git a/inference-engine/src/transformations/include/transformations/low_precision/split.hpp b/inference-engine/src/transformations/include/transformations/low_precision/split.hpp new file mode 100644 index 00000000000000..4d505951f320ac --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/low_precision/split.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "layer_transformation.hpp" +#include "ngraph/node.hpp" + +namespace ngraph { +namespace pass { +namespace low_precision { + +class TRANSFORMATIONS_API SplitTransformation : public LayerTransformation { +public: + SplitTransformation(const Params& params); + void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override; + void transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override; + bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; + void updateOutputs( + TransformationContext& context, + std::vector> lastNodes, + std::shared_ptr originalNode) const; +protected: + ngraph::Shape getConstSplitShape( + const std::vector& constSplitLengths, + const ngraph::Shape& constShape, const size_t axis, + const size_t idx) const; + virtual std::vector getConstSplitLengths( + const OutputVector& inputs, + const ngraph::Shape& constShape, + const size_t outputSize) const; +}; +} // namespace low_precision +} // namespace pass +} // namespace ngraph diff --git a/inference-engine/src/transformations/include/transformations/low_precision/variadic_split.hpp b/inference-engine/src/transformations/include/transformations/low_precision/variadic_split.hpp index bb4ee522751281..acb626a36231a0 100644 --- a/inference-engine/src/transformations/include/transformations/low_precision/variadic_split.hpp +++ b/inference-engine/src/transformations/include/transformations/low_precision/variadic_split.hpp @@ -22,7 +22,7 @@ class TRANSFORMATIONS_API VariadicSplitTransformation : public SplitTransformati const OutputVector& inputs, const ngraph::Shape& constShape, const size_t outputSize) const override; -}; +}; } // namespace low_precision } // namespace pass } // namespace ngraph