Skip to content

Commit

Permalink
comment fixes and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Sep 20, 2024
1 parent a1398fc commit 3576d4b
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 19 deletions.
1 change: 0 additions & 1 deletion src/common/low_precision_transformations/src/slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ bool SliceTransformation::transform(TransformationContext& context, ov::pass::pa
}

const auto strided_slice = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
auto dequantization = NetworkHelper::getDequantization(strided_slice, defaultPrecisions);
const auto newOperation = moveDequantizationAfter(context, strided_slice, NetworkHelper::getDequantization(strided_slice, defaultPrecisions));

OPENVINO_DEBUG("LPT: done: ", newOperation);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include <gtest/gtest.h>
#include "low_precision_transformations/slice_transformation.hpp"


using namespace LayerTestsDefinitions;

namespace {
const std::vector<ov::element::Type> netPrecisions = {
ov::element::f32
};

const std::vector<ov::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams()
};

const std::vector<LayerTestsDefinitions::SliceTransformationParam> params = {
{
{
256ul,
ov::Shape{ 1, 1, 1, 1 },
{ 0.f },
{ 25.5f },
{ 0.f },
{ 12.8f }
},
{ 0 }, // start
{ 2147483647 }, // end
{ 2 }, // step
{ 2 }, // axes
"u8"
},
{
{
256ul,
ov::Shape{ 1, 3, 1, 1 },
{ 0.f, 0.f, 0.f },
{ 255.f / 1.f, 255.f / 2.f, 255.f / 3.f },
{ 0.f, 0.f, 0.f },
{ 255.f / 1.f, 255.f / 2.f, 255.f / 3.f }
},
{ 0 }, // start
{ 2147483647 }, // end
{ 2 }, // step
{ 2 }, // axes
"f32"
},
};

INSTANTIATE_TEST_SUITE_P(smoke_LPT, SliceTransformation,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(ov::PartialShape({ 1, 3, 24, 24 })),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::ValuesIn(trasformationParamValues),
::testing::ValuesIn(params)),
SliceTransformation::getTestCaseName);

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,6 @@

namespace LayerTestsDefinitions {

inline std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& values) {
os << "{ ";
for (size_t i = 0; i < values.size(); ++i) {
os << values[i];
if (i != (values.size() - 1ul)) {
os << ", ";
}
}
os << " }";
return os;
}

std::string SliceTransformation::getTestCaseName(const testing::TestParamInfo<SliceTransformationParams>& obj) {
ov::element::Type netPrecision;
ov::PartialShape inputShape;
Expand All @@ -34,10 +22,10 @@ std::string SliceTransformation::getTestCaseName(const testing::TestParamInfo<Sl
std::ostringstream result;
result << get_test_case_name_by_params(netPrecision, inputShape, targetDevice, params) << "_" <<
param.fakeQuantize << "_" <<
param.start << "_" <<
param.stop << "_" <<
param.step << "_" <<
param.axes;
ov::test::utils::vec2str(param.start) << "_" <<
ov::test::utils::vec2str(param.stop) << "_" <<
ov::test::utils::vec2str(param.step) << "_" <<
ov::test::utils::vec2str(param.axes);
return result.str();
}

Expand Down
13 changes: 11 additions & 2 deletions src/tests/ov_helpers/ov_lpt_models/src/slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ std::shared_ptr<ov::Model> SliceFunction::get(
const std::vector<int64_t>& axes) {
const auto input = std::make_shared<ov::opset1::Parameter>(inputPrecision, inputShape);
input->set_friendly_name("input");
const auto fqOnData = makeFakeQuantize(input, inputPrecision, fakeQuantize);

std::shared_ptr<ov::Node> parent = input;
if (!fakeQuantize.empty()) {
parent = makeFakeQuantize(parent, inputPrecision, fakeQuantize);
}

const auto start_constant = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{ start.size() }, start);
start_constant->set_friendly_name("start");
Expand All @@ -32,7 +36,12 @@ std::shared_ptr<ov::Model> SliceFunction::get(
const auto axes_constant = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{ axes.size() }, axes);
axes_constant->set_friendly_name("axes ");

const auto stridedSlice = std::make_shared<ov::opset8::Slice>(fqOnData, start_constant, stop_constant, step_constant, axes_constant);
const auto stridedSlice = std::make_shared<ov::opset8::Slice>(
parent,
start_constant,
stop_constant,
step_constant,
axes_constant);
stridedSlice->set_friendly_name("slice");

const auto res = std::make_shared<ov::opset1::Result>(stridedSlice);
Expand Down

0 comments on commit 3576d4b

Please sign in to comment.