Skip to content

Commit

Permalink
[LPT] Slice transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Aug 17, 2024
1 parent 8713ca2 commit 537a91a
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>

#include "layer_transformation.hpp"

namespace ov {
namespace pass {
namespace low_precision {

/**
* @ingroup ov_transformation_common_api
* @brief SliceTransformation propagates dequantization operations through Slice operation.
*
* For more details about the transformation, refer to
* [SliceTransformation](@ref openvino_docs_OV_UG_lpt_SliceTransformation) page
* in the OpenVINO Developer Guide.
*/
class LP_TRANSFORMATIONS_API SliceTransformation : public LayerTransformation {
public:
OPENVINO_RTTI("SliceTransformation", "0");
SliceTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ov::pass::pattern::Matcher& m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

} // namespace low_precision
} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
#include "low_precision/relu.hpp"
#include "low_precision/squeeze.hpp"
#include "low_precision/subtract.hpp"
#include "low_precision/slice.hpp"
#include "low_precision/space_to_batch.hpp"
#include "low_precision/split.hpp"
#include "low_precision/shuffle_channels.hpp"
Expand Down Expand Up @@ -267,6 +268,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
ADD_MATCHER(common, ReshapeTransformation, params)
ADD_MATCHER(common, SqueezeTransformation, params)
ADD_MATCHER(common, ShuffleChannelsTransformation, params)
ADD_MATCHER(common, SliceTransformation, params)
ADD_MATCHER(common, SpaceToBatchTransformation, params)
ADD_MATCHER(common, SplitTransformation, params)
ADD_MATCHER(common, StridedSliceTransformation, params)
Expand Down
67 changes: 67 additions & 0 deletions src/common/low_precision_transformations/src/slice.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>

#include "low_precision/slice.hpp"

#include "itt.hpp"
#include "openvino/util/log.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/opsets/opset8.hpp"

#include "low_precision/network_helper.hpp"

namespace ov {
namespace pass {
namespace low_precision {

SliceTransformation::SliceTransformation(const Params& params) : LayerTransformation(params) {
MATCHER_SCOPE(SliceTransformation);
auto matcher = ov::pass::pattern::wrap_type<ov::opset8::Slice>();

ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
return false;
}
return transform(*context, m);
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(matcher, matcher_name);
this->register_matcher(m, callback);
}

bool SliceTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) {
if (!SliceTransformation::canBeTransformed(context, m.get_match_root())) {
return false;
}

const auto strided_slice = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
auto dequantization = NetworkHelper::getDequantization(strided_slice, defaultPrecisions);
const auto newOperation = moveDequantizationAfter(context, strided_slice, NetworkHelper::getDequantization(strided_slice, defaultPrecisions));

OPENVINO_DEBUG("LPT: done: ", newOperation);
return true;
}

bool SliceTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!LayerTransformation::canBeTransformed(context, operation)) {
return false;
}

if (!ov::is_type<ov::opset8::Slice>(operation)) {
return false;
}

const auto dequantization = NetworkHelper::getDequantization(operation);
return dequantization.isPerTensor();
}

bool SliceTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}
} // namespace low_precision
} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include <gtest/gtest.h>
#include "low_precision_transformations/slice_transformation.hpp"


using namespace LayerTestsDefinitions;

namespace {
const std::vector<ov::element::Type> netPrecisions = {
ov::element::f32
};

const std::vector<ov::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams()
};

const std::vector<LayerTestsDefinitions::SliceTransformationParam> params = {
{
{
256ul,
ov::Shape{ 1, 1, 1, 1 },
{ 0.f },
{ 25.5f },
{ 0.f },
{ 12.8f }
},
{ 0 }, // start
{ 2147483647 }, // end
{ 2 }, // step
{ 2 }, // axes
"u8"
},
{
{
256ul,
ov::Shape{ 1, 3, 1, 1 },
{ 0.f, 0.f, 0.f },
{ 255.f / 1.f, 255.f / 2.f, 255.f / 3.f },
{ 0.f, 0.f, 0.f },
{ 255.f / 1.f, 255.f / 2.f, 255.f / 3.f }
},
{ 0 }, // start
{ 2147483647 }, // end
{ 2 }, // step
{ 2 }, // axes
"f32"
},
};

INSTANTIATE_TEST_SUITE_P(smoke_LPT, SliceTransformation,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(ov::PartialShape({ 1, 3, 24, 24 })),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::ValuesIn(trasformationParamValues),
::testing::ValuesIn(params)),
SliceTransformation::getTestCaseName);

} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
#include "ov_lpt_models/common/fake_quantize_on_data.hpp"
#include "ov_lpt_models/common/dequantization_operations.hpp"

namespace LayerTestsDefinitions {
class SliceTransformationParam {
public:
ov::builder::subgraph::FakeQuantizeOnData fakeQuantize;
std::vector<int64_t> start;
std::vector<int64_t> stop;
std::vector<int64_t> step;
std::vector<int64_t> axes;
std::string expectedPrecision;
};

typedef std::tuple<
ov::element::Type,
ov::PartialShape,
std::string,
ov::pass::low_precision::LayerTransformation::Params,
SliceTransformationParam
> SliceTransformationParams;

class SliceTransformation :
public testing::WithParamInterface<SliceTransformationParams>,
public LayerTestsUtils::LayerTransformation {
public:
static std::string getTestCaseName(const testing::TestParamInfo<SliceTransformationParams>& obj);

protected:
void SetUp() override;
};

} // namespace LayerTestsDefinitions
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision_transformations/slice_transformation.hpp"
#include <sstream>
#include <string>
#include <vector>

#include "ov_lpt_models/slice.hpp"

namespace LayerTestsDefinitions {

inline std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& values) {
os << "{ ";
for (size_t i = 0; i < values.size(); ++i) {
os << values[i];
if (i != (values.size() - 1ul)) {
os << ", ";
}
}
os << " }";
return os;
}

std::string SliceTransformation::getTestCaseName(const testing::TestParamInfo<SliceTransformationParams>& obj) {
ov::element::Type netPrecision;
ov::PartialShape inputShape;
std::string targetDevice;
ov::pass::low_precision::LayerTransformation::Params params;
SliceTransformationParam param;;
std::tie(netPrecision, inputShape, targetDevice, params, param) = obj.param;

std::ostringstream result;
result << get_test_case_name_by_params(netPrecision, inputShape, targetDevice, params) << "_" <<
param.fakeQuantize << "_" <<
param.start << "_" <<
param.stop << "_" <<
param.step << "_" <<
param.axes;
return result.str();
}

void SliceTransformation::SetUp() {
ov::element::Type netPrecision;
ov::PartialShape inputShape;
ov::pass::low_precision::LayerTransformation::Params params;
SliceTransformationParam param;
std::tie(netPrecision, inputShape, targetDevice, params, param) = this->GetParam();

init_input_shapes(inputShape);

function = ov::builder::subgraph::SliceFunction::get(
netPrecision,
inputShape,
param.fakeQuantize,
param.start,
param.stop,
param.step,
param.axes);
}

TEST_P(SliceTransformation, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
run();

const auto params = std::get<4>(GetParam());
const auto& actualPrecision = get_runtime_precision_by_type("StridedSlice");
EXPECT_EQ(actualPrecision, params.expectedPrecision);
};

} // namespace LayerTestsDefinitions
30 changes: 30 additions & 0 deletions src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/slice.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <low_precision/layer_transformation.hpp>
#include "ov_lpt_models/common/builders.hpp"
#include "ov_lpt_models/common/dequantization_operations.hpp"

namespace ov {
namespace builder {
namespace subgraph {

class SliceFunction {
public:
static std::shared_ptr<ov::Model> get(
const ov::element::Type inputPrecision,
const ov::PartialShape& inputShape,
const ov::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
const std::vector<int64_t>& start,
const std::vector<int64_t>& stop,
const std::vector<int64_t>& step,
const std::vector<int64_t>& axes);
};

} // namespace subgraph
} // namespace builder
} // namespace ov
49 changes: 49 additions & 0 deletions src/tests/ov_helpers/ov_lpt_models/src/slice.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ov_lpt_models/slice.hpp"
#include "openvino/opsets/opset8.hpp"

using namespace ov::pass::low_precision;

namespace ov {
namespace builder {
namespace subgraph {

std::shared_ptr<ov::Model> SliceFunction::get(
const ov::element::Type inputPrecision,
const ov::PartialShape& inputShape,
const ov::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
const std::vector<int64_t>& start,
const std::vector<int64_t>& stop,
const std::vector<int64_t>& step,
const std::vector<int64_t>& axes) {
const auto input = std::make_shared<ov::opset1::Parameter>(inputPrecision, inputShape);
input->set_friendly_name("input");
const auto fqOnData = makeFakeQuantize(input, inputPrecision, fakeQuantize);

const auto start_constant = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{ start.size() }, start);
start_constant->set_friendly_name("start");
const auto stop_constant = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{ stop.size() }, stop);
stop_constant->set_friendly_name("stop");
const auto step_constant = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{ step.size() }, step);
step_constant->set_friendly_name("step");
const auto axes_constant = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{ axes.size() }, axes);
axes_constant->set_friendly_name("axes ");

const auto stridedSlice = std::make_shared<ov::opset8::Slice>(fqOnData, start_constant, stop_constant, step_constant, axes_constant);
stridedSlice->set_friendly_name("slice");

const auto res = std::make_shared<ov::opset1::Result>(stridedSlice);
const auto function = std::make_shared<ov::Model>(
ov::ResultVector{ res },
ov::ParameterVector{ input },
"SliceTransformation");

return function;
}

} // namespace subgraph
} // namespace builder
} // namespace ov

0 comments on commit 537a91a

Please sign in to comment.