Skip to content

Commit

Permalink
[LPT] Slice transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Aug 16, 2024
1 parent 8713ca2 commit 5c10202
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>

#include "layer_transformation.hpp"

namespace ov {
namespace pass {
namespace low_precision {

/**
* @ingroup ov_transformation_common_api
* @brief SliceTransformation propagates dequantization operations through Slice operation.
*
* For more details about the transformation, refer to
* [SliceTransformation](@ref openvino_docs_OV_UG_lpt_SliceTransformation) page
* in the OpenVINO Developer Guide.
*/
class LP_TRANSFORMATIONS_API SliceTransformation : public LayerTransformation {
public:
OPENVINO_RTTI("SliceTransformation", "0");
SliceTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ov::pass::pattern::Matcher& m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

} // namespace low_precision
} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
#include "low_precision/relu.hpp"
#include "low_precision/squeeze.hpp"
#include "low_precision/subtract.hpp"
#include "low_precision/slice.hpp"
#include "low_precision/space_to_batch.hpp"
#include "low_precision/split.hpp"
#include "low_precision/shuffle_channels.hpp"
Expand Down Expand Up @@ -267,6 +268,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
ADD_MATCHER(common, ReshapeTransformation, params)
ADD_MATCHER(common, SqueezeTransformation, params)
ADD_MATCHER(common, ShuffleChannelsTransformation, params)
ADD_MATCHER(common, SliceTransformation, params)
ADD_MATCHER(common, SpaceToBatchTransformation, params)
ADD_MATCHER(common, SplitTransformation, params)
ADD_MATCHER(common, StridedSliceTransformation, params)
Expand Down
67 changes: 67 additions & 0 deletions src/common/low_precision_transformations/src/slice.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>

#include "low_precision/slice.hpp"

#include "itt.hpp"
#include "openvino/util/log.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/opsets/opset8.hpp"

#include "low_precision/network_helper.hpp"

namespace ov {
namespace pass {
namespace low_precision {

SliceTransformation::SliceTransformation(const Params& params) : LayerTransformation(params) {
MATCHER_SCOPE(SliceTransformation);
auto matcher = ov::pass::pattern::wrap_type<ov::opset8::Slice>();

ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
return false;
}
return transform(*context, m);
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(matcher, matcher_name);
this->register_matcher(m, callback);
}

bool SliceTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) {
if (!SliceTransformation::canBeTransformed(context, m.get_match_root())) {
return false;
}

const auto strided_slice = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
auto dequantization = NetworkHelper::getDequantization(strided_slice, defaultPrecisions);
const auto newOperation = moveDequantizationAfter(context, strided_slice, NetworkHelper::getDequantization(strided_slice, defaultPrecisions));

OPENVINO_DEBUG("LPT: done: ", newOperation);
return true;
}

bool SliceTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!LayerTransformation::canBeTransformed(context, operation)) {
return false;
}

if (!ov::is_type<ov::opset8::Slice>(operation)) {
return false;
}

const auto dequantization = NetworkHelper::getDequantization(operation);
return dequantization.isPerTensor();
}

bool SliceTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}
} // namespace low_precision
} // namespace pass
} // namespace ov

0 comments on commit 5c10202

Please sign in to comment.