Skip to content

Commit

Permalink
[LPT] SpaceToBatch & BatchToSpace implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Sep 12, 2023
1 parent e3f1ff7 commit e28c0c6
Show file tree
Hide file tree
Showing 30 changed files with 1,238 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/IE_PLUGIN_DG/layout.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
<tab type="user" title="Step 3. Main transformations" url="@ref openvino_docs_OV_UG_lpt_step3_main">
<tab type="user" title="AddTransformation" url="@ref openvino_docs_OV_UG_lpt_AddTransformation"/>
<tab type="user" title="AvgPoolTransformation" url="@ref openvino_docs_OV_UG_lpt_AvgPoolTransformation"/>
<tab type="user" title="BatchToSpaceTransformation" url="@ref openvino_docs_OV_UG_lpt_BatchToSpaceTransformation"/>
<tab type="user" title="ClampTransformation" url="@ref openvino_docs_OV_UG_lpt_ClampTransformation"/>
<tab type="user" title="ConcatTransformation" url="@ref openvino_docs_OV_UG_lpt_ConcatTransformation"/>
<tab type="user" title="ConvolutionTransformation" url="@ref openvino_docs_OV_UG_lpt_ConvolutionTransformation"/>
Expand All @@ -62,6 +63,7 @@
<tab type="user" title="ReshapeTransformation" url="@ref openvino_docs_OV_UG_lpt_ReshapeTransformation"/>
<tab type="user" title="SqueezeTransformation" url="@ref openvino_docs_OV_UG_lpt_SqueezeTransformation"/>
<tab type="user" title="ShuffleChannelsTransformation" url="@ref openvino_docs_OV_UG_lpt_ShuffleChannelsTransformation"/>
<tab type="user" title="SpaceToBatchTransformation" url="@ref openvino_docs_OV_UG_lpt_SpaceToBatchTransformation"/>
<tab type="user" title="SplitTransformation" url="@ref openvino_docs_OV_UG_lpt_SplitTransformation"/>
<tab type="user" title="StridedSliceTransformation" url="@ref openvino_docs_OV_UG_lpt_StridedSliceTransformation"/>
<tab type="user" title="TransposeTransformation" url="@ref openvino_docs_OV_UG_lpt_TransposeTransformation"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ Transformations:
* :doc:`AddTransformation <openvino_docs_OV_UG_lpt_AddTransformation>`
* :doc:`AvgPoolTransformation <openvino_docs_OV_UG_lpt_AvgPoolTransformation>`
* :doc:`ClampTransformation <openvino_docs_OV_UG_lpt_AvgPoolTransformation>`
* :doc:`BatchToSpaceTransformation <openvino_docs_OV_UG_lpt_BatchToSpaceTransformation>`
* :doc:`ConcatTransformation <openvino_docs_OV_UG_lpt_ConcatTransformation>`
* :doc:`ConvolutionTransformation <openvino_docs_OV_UG_lpt_ConvolutionTransformation>`
* :doc:`ConvolutionBackpropDataTransformation <openvino_docs_OV_UG_lpt_ConvolutionBackpropDataTransformation>`
Expand All @@ -211,6 +212,7 @@ Transformations:
* :doc:`ReshapeTransformation <openvino_docs_OV_UG_lpt_ReshapeTransformation>`
* :doc:`SqueezeTransformation <openvino_docs_OV_UG_lpt_SqueezeTransformation>`
* :doc:`ShuffleChannelsTransformation <openvino_docs_OV_UG_lpt_ShuffleChannelsTransformation>`
* :doc:`SpaceToBatchTransformation <openvino_docs_OV_UG_lpt_SpaceToBatchTransformation>`
* :doc:`SplitTransformation <openvino_docs_OV_UG_lpt_SplitTransformation>`
* :doc:`StridedSliceTransformation <openvino_docs_OV_UG_lpt_StridedSliceTransformation>`
* :doc:`TransposeTransformation <openvino_docs_OV_UG_lpt_TransposeTransformation>`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Main transformations are the majority of low precision transformations. Transfor

* :doc:`AddTransformation <openvino_docs_OV_UG_lpt_AddTransformation>`
* :doc:`AvgPoolTransformation <openvino_docs_OV_UG_lpt_AvgPoolTransformation>`
* :doc:`BatchToSpaceTransformation <openvino_docs_OV_UG_lpt_BatchToSpaceTransformation>`
* :doc:`ClampTransformation <openvino_docs_OV_UG_lpt_AvgPoolTransformation>`
* :doc:`ConcatTransformation <openvino_docs_OV_UG_lpt_ConcatTransformation>`
* :doc:`ConvolutionTransformation <openvino_docs_OV_UG_lpt_ConvolutionTransformation>`
Expand All @@ -34,6 +35,7 @@ Main transformations are the majority of low precision transformations. Transfor
* :doc:`ReduceSumTransformation <openvino_docs_OV_UG_lpt_ReduceSumTransformation>`
* :doc:`ReluTransformation <openvino_docs_OV_UG_lpt_ReluTransformation>`
* :doc:`ReshapeTransformation <openvino_docs_OV_UG_lpt_ReshapeTransformation>`
* :doc:`SpaceToBatchTransformation <openvino_docs_OV_UG_lpt_SpaceToBatchTransformation>`
* :doc:`SqueezeTransformation <openvino_docs_OV_UG_lpt_SqueezeTransformation>`
* :doc:`ShuffleChannelsTransformation <openvino_docs_OV_UG_lpt_ShuffleChannelsTransformation>`
* :doc:`SplitTransformation <openvino_docs_OV_UG_lpt_SplitTransformation>`
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# BatchToSpaceTransformation transformation {#openvino_docs_OV_UG_lpt_BatchToSpaceTransformation}

ngraph::pass::low_precision::BatchToSpaceTransformation class represents the `BatchToSpace` operation transformation.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SpaceToBatchTransformation transformation {#openvino_docs_OV_UG_lpt_SpaceToBatchTransformation}

ngraph::pass::low_precision::SpaceToBatchTransformation class represents the `SpaceToBatch` operation transformation.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/ngraph.hpp>
#include "low_precision/layer_transformation.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

/**
* @ingroup ie_transformation_common_api
* @brief BatchToSpaceTransformation propagates dequantization operations through BatchToSpace operation.
*
* For more details about the transformation, refer to
* [BatchToSpaceTransformation](@ref openvino_docs_OV_UG_lpt_BatchToSpaceTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API BatchToSpaceTransformation : public LayerTransformation {
public:
OPENVINO_RTTI("BatchToSpaceTransformation", "0");
BatchToSpaceTransformation(const Params& params = Params());
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

} // namespace low_precision
} // namespace pass
} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class LP_TRANSFORMATIONS_API FakeQuantizeDequantization {
bool multiplyHasZeroOrDenormal() const;
bool isShared() const;
bool isLowPrecision() const;
bool isPerTensor() const;
std::shared_ptr<Node> copyWithNewInput(const std::shared_ptr<Node>& input) const;

bool checkElementwise(const std::shared_ptr<ngraph::Node>& elementwise) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* @brief A macro used to throw the exception with a notable description for low precision transformations
*/
#define THROW_IE_LPT_EXCEPTION(node) throw ::ngraph::pass::low_precision::InferenceEngineLptException(__FILE__, __LINE__, node)
#define THROW_IE_LPT_EXCEPTION_BASE throw ::ngraph::pass::low_precision::InferenceEngineLptException(__FILE__, __LINE__)

namespace ngraph {
namespace pass {
Expand Down Expand Up @@ -49,6 +50,10 @@ class LP_TRANSFORMATIONS_API InferenceEngineLptException : public Exception {
<< filename << ":" << line << " Exception during low precision transformation for "
<< node << " node with type '" << node.get_type_name() << "', name '" << node.get_friendly_name() << "'. ";
}

InferenceEngineLptException(const std::string& filename, const size_t line) {
*this << filename << ":" << line << " Exception during low precision transformation. ";
}
};

} // namespace low_precision
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/ngraph.hpp>
#include "low_precision/layer_transformation.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

/**
* @ingroup ie_transformation_common_api
* @brief SpaceToBatchTransformation propagates dequantization operations through SpaceToBatch operation.
*
* For more details about the transformation, refer to
* [SpaceToBatchTransformation](@ref openvino_docs_OV_UG_lpt_SpaceToBatchTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API SpaceToBatchTransformation : public LayerTransformation {
public:
OPENVINO_RTTI("SpaceToBatchTransformation", "0");
SpaceToBatchTransformation(const Params& params = Params());
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

} // namespace low_precision
} // namespace pass
} // namespace ngraph
66 changes: 66 additions & 0 deletions src/common/low_precision_transformations/src/batch_to_space.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/batch_to_space.hpp"

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>

#include <ngraph/pattern/op/wrap_type.hpp>

#include "low_precision/network_helper.hpp"
#include "itt.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

BatchToSpaceTransformation::BatchToSpaceTransformation(const Params& params) : LayerTransformation(params) {
MATCHER_SCOPE(BatchToSpaceTransformation);
auto matcher = pattern::wrap_type<opset2::BatchToSpace>();

ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
return false;
}
return transform(*context, m);
};

auto m = std::make_shared<ngraph::pattern::Matcher>(matcher, matcher_name);
this->register_matcher(m, callback);
}

bool BatchToSpaceTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
if (!LayerTransformation::canBeTransformed(context, op)) {
return false;
}

const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, defaultPrecisions);
if (dequantization.empty()) {
return false;
}

return dequantization.isPerTensor();
}

bool BatchToSpaceTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
if (!canBeTransformed(context, m.get_match_root())) {
return false;
}

const std::shared_ptr<Node> pooling = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
moveDequantizationAfter(context, pooling, NetworkHelper::getDequantization(pooling, defaultPrecisions), false);
return true;
}

bool BatchToSpaceTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ bool FakeQuantizeDequantization::isLowPrecision() const {
return DataPrecision::isSupported(data.get_element_type());
}

bool FakeQuantizeDequantization::isPerTensor() const {
if (multiplyConstant == nullptr) {
THROW_IE_LPT_EXCEPTION_BASE << "multiply constant can not be empty";
}

const std::vector<float>& scales = multiplyConstant->cast_vector<float>();
if (scales.size() != 1ull) {
return false;
}

if (subtractConstant != nullptr) {
const std::vector<float>& scales = subtractConstant->cast_vector<float>();
if (scales.size() != 1ull) {
return false;
}
}

return true;
}

bool FakeQuantizeDequantization::checkShape(const std::shared_ptr<ngraph::Node>& elementwise) {
std::shared_ptr<ngraph::opset1::Convert> convert;
std::shared_ptr<ngraph::opset1::Constant> constant;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "low_precision/add.hpp"
#include "low_precision/assign_and_read_value.hpp"
#include "low_precision/avg_pool.hpp"
#include "low_precision/batch_to_space.hpp"
#include "low_precision/clamp.hpp"
#include "low_precision/convolution.hpp"
#include "low_precision/convolution_backprop_data.hpp"
Expand All @@ -66,6 +67,7 @@
#include "low_precision/relu.hpp"
#include "low_precision/squeeze.hpp"
#include "low_precision/subtract.hpp"
#include "low_precision/space_to_batch.hpp"
#include "low_precision/split.hpp"
#include "low_precision/shuffle_channels.hpp"
#include "low_precision/strided_slice.hpp"
Expand Down Expand Up @@ -237,6 +239,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p
ADD_MATCHER(common, AddTransformation, params)
ADD_MATCHER(common, AssignAndReadValueTransformation, f, params)
ADD_MATCHER(common, AvgPoolTransformation, params)
ADD_MATCHER(common, BatchToSpaceTransformation, params)
ADD_MATCHER(common, ClampTransformation, params)
ADD_MATCHER(common, ConcatTransformation, params)
ADD_MATCHER(common, ConvolutionTransformation, params)
Expand All @@ -262,6 +265,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p
ADD_MATCHER(common, ReshapeTransformation, params)
ADD_MATCHER(common, SqueezeTransformation, params)
ADD_MATCHER(common, ShuffleChannelsTransformation, params)
ADD_MATCHER(common, SpaceToBatchTransformation, params)
ADD_MATCHER(common, SplitTransformation, params)
ADD_MATCHER(common, StridedSliceTransformation, params)
ADD_MATCHER(common, TransposeTransformation, params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <vector>

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp>
Expand Down Expand Up @@ -152,10 +153,12 @@ bool ngraph::pass::low_precision::MarkupPrecisions::isPrecisionPreserved(const s
{ name<opset1::ReduceMin>() },
{ name<opset1::Relu>() },
// TODO: there are conditions
{ name<opset2::BatchToSpace>() },
{ name<opset1::Pad>() },
{ name<ov::opset12::Pad>() },
{ name<opset1::Reshape>() },
{ name<opset1::Squeeze>() },
{ name<opset2::SpaceToBatch>() },
{ name<opset1::Split>() },
{ name<opset1::StridedSlice>() },
{ name<opset1::ShuffleChannels>() },
Expand Down
66 changes: 66 additions & 0 deletions src/common/low_precision_transformations/src/space_to_batch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/space_to_batch.hpp"

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>

#include <ngraph/pattern/op/wrap_type.hpp>

#include "low_precision/network_helper.hpp"
#include "itt.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

SpaceToBatchTransformation::SpaceToBatchTransformation(const Params& params) : LayerTransformation(params) {
MATCHER_SCOPE(SpaceToBatchTransformation);
auto matcher = pattern::wrap_type<opset2::SpaceToBatch>();

ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
return false;
}
return transform(*context, m);
};

auto m = std::make_shared<ngraph::pattern::Matcher>(matcher, matcher_name);
this->register_matcher(m, callback);
}

bool SpaceToBatchTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
if (!LayerTransformation::canBeTransformed(context, op)) {
return false;
}

const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(op, defaultPrecisions);
if (dequantization.empty()) {
return false;
}

return dequantization.isPerTensor();
}

bool SpaceToBatchTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
if (!canBeTransformed(context, m.get_match_root())) {
return false;
}

const std::shared_ptr<Node> pooling = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
moveDequantizationAfter(context, pooling, NetworkHelper::getDequantization(pooling, defaultPrecisions), false);
return true;
}

bool SpaceToBatchTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
Loading

0 comments on commit e28c0c6

Please sign in to comment.