Skip to content

Commit

Permalink
[CPU] FakeQuantize: new cases support (openvinotoolkit#5497)
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev authored May 18, 2021
1 parent 0face0e commit 49a8714
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <mkldnn_types.h>
#include <mkldnn_extension_utils.h>
#include "utils/general_utils.h"
#include "utils/cpu_utils.hpp"

#include <algorithm>
#include <set>
Expand Down Expand Up @@ -841,7 +842,7 @@ bool MKLDNNFakeQuantizeNode::isSupportedOperation(const std::shared_ptr<const ng
}
for (size_t i = 1; i < fq->get_input_size(); i++) {
size_t count_not_unit_axis = 0;
auto shape = fq->get_input_shape(i);
auto shape = getNormalizedDimsBySize(fq->get_input_shape(i), fq->get_input_shape(0).size());

if (ngraph::shape_size(shape) != 1) {
size_t not_unit_axis = 0;
Expand Down Expand Up @@ -885,9 +886,7 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr<ngraph::Nod
if (fq->get_output_size() != 1)
IE_THROW() << errorPrefix << "has incorrect number of output edges: " << fq->get_output_size();

auto initAxisIdx = [&](size_t edgeIdx) {
const auto &inputDims = fq->get_input_shape(edgeIdx);

auto initAxisIdx = [&](const ngraph::Shape& inputDims) {
size_t axisIdx = 0;
for (int i = 1; i < inputDims.size(); i++) {
if (inputDims[i] > 1) {
Expand All @@ -898,35 +897,36 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr<ngraph::Nod
return axisIdx;
};

axis = fq->get_input_shape(0).size() == 1 ? 0 : 1;
const size_t dataNDims = fq->get_input_shape(0).size();
axis = dataNDims == 1 ? 0 : 1;
int axisSize = -1;

auto inputLowAxis = initAxisIdx(1);
const auto ilShape = fq->get_input_shape(1);
const auto ilShape = getNormalizedDimsBySize(fq->get_input_shape(1), dataNDims);
auto inputLowAxis = initAxisIdx(ilShape);
isInputLowBroadcasted = (ngraph::is_scalar(ilShape) || ilShape[inputLowAxis] == 1);
if (!isInputLowBroadcasted) {
axis = inputLowAxis;
axisSize = ilShape[inputLowAxis];
}

auto inputHighAxis = initAxisIdx(2);
const auto ihShape = fq->get_input_shape(2);
const auto ihShape = getNormalizedDimsBySize(fq->get_input_shape(2), dataNDims);
auto inputHighAxis = initAxisIdx(ihShape);
isInputHighBroadcasted = (ngraph::is_scalar(ihShape) || ihShape[inputHighAxis] == 1);
if (!isInputHighBroadcasted) {
axis = inputHighAxis;
axisSize = ihShape[inputHighAxis];
}

auto outputLowAxis = initAxisIdx(3);
const auto olShape = fq->get_input_shape(3);
const auto olShape = getNormalizedDimsBySize(fq->get_input_shape(3), dataNDims);
auto outputLowAxis = initAxisIdx(olShape);
isOutputLowBroadcasted = (ngraph::is_scalar(olShape) || olShape[outputLowAxis] == 1);
if (!isOutputLowBroadcasted) {
axis = outputLowAxis;
axisSize = olShape[outputLowAxis];
}

auto outputHighAxis = initAxisIdx(4);
const auto ohShape = fq->get_input_shape(4);
const auto ohShape = getNormalizedDimsBySize(fq->get_input_shape(4), dataNDims);
auto outputHighAxis = initAxisIdx(ohShape);
isOutputHighBroadcasted = (ngraph::is_scalar(ohShape) || ohShape[outputHighAxis] == 1);
if (!isOutputHighBroadcasted) {
axis = outputHighAxis;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,27 @@ INSTANTIATE_TEST_CASE_P(smoke_FakeQuantizePerChannelAxis1, FakeQuantizeLayerTest
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(config)),
FakeQuantizeLayerTest::getTestCaseName);

// Parameter set for FakeQuantize tests on a 2D data input.

// Data input shapes under test: a single 2D shape {1, 10}.
const std::vector<std::vector<size_t>> inputShapesPerChannel2D = {{1, 10}};
// Constant shapes under test: rank-1 with 10 elements, rank-2 {1, 10}, and a single-element {1}.
// NOTE(review): presumably these are the shapes of the FakeQuantize range constants
// (input/output low/high), covering constants whose rank is lower than the data rank — confirm
// against FakeQuantizeLayerTest.
const std::vector<std::vector<size_t>> constShapesPerChannel2D = { {10}, {1, 10}, {1} };
// Cartesian product of the FakeQuantize-specific parameters: every value of `levels` with every
// constant shape above; `fqArgs` and `inputParams` are each passed as a single fixed value.
const auto fqParamsPerChannel2D = ::testing::Combine(
::testing::ValuesIn(levels),
::testing::ValuesIn(constShapesPerChannel2D),
::testing::Values(fqArgs),
::testing::Values(inputParams)
);

// Instantiates FakeQuantizeLayerTest for the 2D parameter set on the CPU device, for every entry
// of `netPrecisions`.
// NOTE(review): the two UNSPECIFIED precisions and two ANY layouts are presumably the
// input/output precision and input/output layout slots of the test tuple — confirm against the
// FakeQuantizeLayerTest parameter definition.
INSTANTIATE_TEST_CASE_P(smoke_FakeQuantizePerChannel2D, FakeQuantizeLayerTest,
::testing::Combine(
fqParamsPerChannel2D,
::testing::ValuesIn(netPrecisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::ValuesIn(inputShapesPerChannel2D),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(config)),
FakeQuantizeLayerTest::getTestCaseName);

} // namespace

0 comments on commit 49a8714

Please sign in to comment.