[GNA] Added exception for scale factor <= 0
elilobanova committed Sep 2, 2021
1 parent 2884472 commit 31a0f4b
Showing 3 changed files with 126 additions and 49 deletions.
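
In short: scale-factor lookups that previously defaulted to 1.0f behind an inline null check are centralized in a new getScaleFactor() helper, which also throws for any stored factor that is negative or (within a 1e-5 epsilon) zero. A condensed before/after sketch of the pattern applied throughout gna_graph_compiler.cpp (variable names abbreviated for illustration, not taken verbatim from the diff):

    // Before: each call site re-implemented the lookup and silently fell back to 1.0f.
    auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
    float output_sf = quantized != nullptr ? quantized->_dst_quant.GetScale() : 1.0f;

    // After: a single helper performs the lookup and rejects scale factors <= 0.
    float output_sf = getScaleFactor(layer, QuantizedDataType::output);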
47 changes: 47 additions & 0 deletions inference-engine/src/gna_plugin/frontend/layer_quantizer.hpp
@@ -699,5 +699,52 @@ using QuantI8_I8 = frontend::QuantPair<frontend::QuantI8_I8, frontend::QuantI8_I
using FakeQuantI16 = frontend::QuantPair<frontend::FakeQuantI16, frontend::FakeQuantI16>;
using FakeQuantI8 = frontend::QuantPair<frontend::FakeQuantI8, frontend::FakeQuantI16>;

enum class QuantizedDataType {
input,
output,
weights,
bias
};

/**
* @brief Returns a scale factor for specific layer data
* @param layer Layer to be quantized
* @param data_type Type of data to be quantized
* @return scale factor
*/
inline float getScaleFactor(InferenceEngine::CNNLayerPtr layer, QuantizedDataType data_type) {
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
float scale_factor;
if (!quantized) {
scale_factor = 1.0f;
} else {
switch (data_type) {
case QuantizedDataType::input:
scale_factor = quantized->_src_quant.GetScale();
break;
case QuantizedDataType::output:
scale_factor = quantized->_dst_quant.GetScale();
break;
case QuantizedDataType::weights:
scale_factor = quantized->_weights_quant.GetScale();
break;
case QuantizedDataType::bias:
scale_factor = quantized->_bias_quant.GetScale();
break;
default:
THROW_GNA_LAYER_EXCEPTION(layer) << "Unsupported data type for quantization: " << static_cast<int>(data_type);
}
}

auto isZero = [](float p1) {
return std::abs(p1) <= 0.00001f;
};

if (scale_factor < 0.0 || isZero(scale_factor)) {
THROW_GNA_LAYER_EXCEPTION(layer) << "Invalid scale factor: " << scale_factor;
}

return scale_factor;
}

} // namespace GNAPluginNS
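
For reference, a minimal usage sketch of the new helper (hypothetical call site; the real ones follow in gna_graph_compiler.cpp below). For a layer without injected quantization data it returns 1.0f; otherwise it returns the stored factor, throwing a layer exception if that factor is negative or effectively zero:

    // `layer` is an InferenceEngine::CNNLayerPtr; both calls validate the returned factor.
    float output_sf = GNAPluginNS::getScaleFactor(layer, GNAPluginNS::QuantizedDataType::output);
    float weight_sf = GNAPluginNS::getScaleFactor(layer, GNAPluginNS::QuantizedDataType::weights);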
80 changes: 31 additions & 49 deletions inference-engine/src/gna_plugin/gna_graph_compiler.cpp
@@ -409,13 +409,9 @@ void GNAGraphCompiler::finalizeConvolution1DPrimitive(InferenceEngine::CNNLayerP
uint32_t num_bytes_per_weight = convolution._weights->getTensorDesc().getPrecision().size();
uint32_t num_bytes_per_bias = biasPrecision.size();

float weight_scale_factor = 1.0f;
float output_scale_factor = 1.0f;
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(convolution);
if (quantized != nullptr) {
weight_scale_factor = quantized->_weights_quant.GetScale();
output_scale_factor = quantized->_dst_quant.GetScale();
}
float weight_scale_factor = getScaleFactor(layer, QuantizedDataType::weights);
float output_scale_factor = getScaleFactor(layer, QuantizedDataType::output);

auto& currentComponent = dnnComponents.addComponent(convolution.name, "convolution");
dnn->InitConvolutional1DComponent(currentComponent,
num_columns_in,
@@ -586,13 +582,8 @@ void GNAGraphCompiler::finalizeConvolution2DPrimitive(InferenceEngine::CNNLayerP
in_height, in_width, in_channels,
convolution._kernel_y, convolution._kernel_x, filter_n, convolution._stride_y, convolution._stride_x, inputPrec);

float weight_scale_factor = 1.0f;
float output_scale_factor = 1.0f;
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(convolution);
if (quantized != nullptr) {
weight_scale_factor = quantized->_weights_quant.GetScale();
output_scale_factor = quantized->_dst_quant.GetScale();
}
float weight_scale_factor = getScaleFactor(layer, QuantizedDataType::weights);
float output_scale_factor = getScaleFactor(layer, QuantizedDataType::output);

auto& currentComponent = dnnComponents.addComponent(convolution.name, "convolution");
dnn->InitConvolutional2DComponent(currentComponent,
@@ -673,9 +664,6 @@ void GNAGraphCompiler::finalizeConvolution2DPrimitive(InferenceEngine::CNNLayerP

void GNAGraphCompiler::PowerPrimitive(InferenceEngine::CNNLayerPtr layer) {
auto& power = dynamic_cast<PowerLayer&>(*layer.get());
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
IE_ASSERT(gnaFlags->sw_fp32 ? (quantized == nullptr) : (quantized != nullptr));

if (power.power < 0.0f || power.power > 2.8f) {
IE_THROW() << "[GNA plugin] unsupported power factor, expected to be in <0, 2.8> range but was " << power.power;
}
@@ -705,6 +693,8 @@ void GNAGraphCompiler::PowerPrimitive(InferenceEngine::CNNLayerPtr layer) {

auto& currentComponent = dnnComponents.addComponent(layer->name, "power");

auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
IE_ASSERT(gnaFlags->sw_fp32 ? (quantized == nullptr) : (quantized != nullptr));
dnn->InitAffineComponent(currentComponent,
num_rows_in + num_padding,
num_columns_in,
@@ -764,8 +754,8 @@ void GNAGraphCompiler::PowerPrimitive(InferenceEngine::CNNLayerPtr layer) {

gna_pwl_segment_t* ptr_pwl_segments_target = nullptr;

float output_pwl_scale_factor = quantized != nullptr ? quantized->_dst_quant.GetScale() : 1.0f;
float input_pwl_scale_factor = quantized != nullptr ? quantized->_src_quant.GetScale() : 1.0f;
float output_pwl_scale_factor = getScaleFactor(layer, QuantizedDataType::output);
float input_pwl_scale_factor = getScaleFactor(layer, QuantizedDataType::input);

if (!gnaFlags->sw_fp32) {
if (gnaFlags->uniformPwlDesign) {
@@ -823,7 +813,6 @@ void GNAGraphCompiler::PowerPrimitive(InferenceEngine::CNNLayerPtr layer) {

void GNAGraphCompiler::PoolingPrimitive(InferenceEngine::CNNLayerPtr layer) {
auto& pooling = dynamic_cast<PoolingLayer&>(*layer.get());
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);

IE_ASSERT(!layer->insData.empty());
IE_ASSERT(!layer->outData.empty());
@@ -883,7 +872,7 @@ void GNAGraphCompiler::PoolingPrimitive(InferenceEngine::CNNLayerPtr layer) {
outputs->getPrecision().size(),
{ pooling._kernel[X_AXIS], pooling._kernel[Y_AXIS] },
{ pooling._stride[X_AXIS], pooling._stride[Y_AXIS] },
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs);

@@ -901,8 +890,6 @@ void GNAGraphCompiler::PoolingPrimitive(InferenceEngine::CNNLayerPtr layer) {
}

void GNAGraphCompiler::CopyPrimitive(InferenceEngine::CNNLayerPtr layer) {
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);

IE_ASSERT(!layer->insData.empty());
IE_ASSERT(!layer->outData.empty());
auto inputs = layer->insData.begin()->lock();
Expand All @@ -928,7 +915,7 @@ void GNAGraphCompiler::CopyPrimitive(InferenceEngine::CNNLayerPtr layer) {
num_columns_out,
inputs->getPrecision().size(),
outputs->getPrecision().size(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::output),
num_rows_out + num_padding_out,
num_columns_out,
ptr_inputs,
@@ -1053,7 +1040,6 @@ void GNAGraphCompiler::CropPrimitive(InferenceEngine::CNNLayerPtr layer) {
<< axis.size() << ".";
}

auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
size_t cropOffset = offset.front() * cropLayer->precision.size();
size_t cropOutputSize = dim.front() * cropLayer->precision.size();
const uint32_t noOfInputsDivisor = gnaFlags->input_low_precision ?
@@ -1111,6 +1097,7 @@

auto& currentComponent = dnnComponents.addComponent(layer->name, "crop");

auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
dnn->InitAffineComponent(currentComponent,
num_rows_in + num_padding,
num_columns_in,
@@ -1119,8 +1106,8 @@
outputs->getPrecision().size(),
quantized == nullptr ? inputs->getPrecision().size() : (gnaFlags->input_low_precision ? 1 : 2),
gnaFlags->input_low_precision ? 1 : 4,
quantized == nullptr ? 1 : quantized->_weights_quant.GetScale(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::weights),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs,
ptr_weights,
@@ -1254,8 +1241,8 @@ void GNAGraphCompiler::EltwisePrimitive(InferenceEngine::CNNLayerPtr layer) {
// TODO: only fp32 and Int16 tested
quantized == nullptr ? inputs2Bytes->getPrecision().size() : (gnaFlags->input_low_precision ? 1 : 2),
quantized == nullptr ? inputs4Bytes->getPrecision().size() : (gnaFlags->input_low_precision ? 1 : 4),
quantized == nullptr ? 1 : quantized->_weights_quant.GetScale(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::weights),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs,
ptr_weights,
@@ -1363,8 +1350,8 @@ void GNAGraphCompiler::GemmPrimitive(InferenceEngine::CNNLayerPtr layer) {
outputs->getPrecision().size(),
quantized == nullptr ? input_2->getPrecision().size() : 2,
quantized == nullptr ? input_2->getPrecision().size() : 4,
quantized == nullptr ? 1 : quantized->_weights_quant.GetScale(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::weights),
getScaleFactor(layer, QuantizedDataType::output),
ptr_input_1,
ptr_outputs,
ptr_input_2,
@@ -1452,8 +1439,8 @@ void GNAGraphCompiler::AffinePrimitive(InferenceEngine::CNNLayerPtr layer, bool
outputs->getPrecision().size(),
weightable._weights->getTensorDesc().getPrecision().size(),
biasPrecisionSize,
quantized == nullptr ? 1 : quantized->_weights_quant.GetScale(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::weights),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs,
ptr_weights,
@@ -1592,8 +1579,6 @@ void GNAGraphCompiler::ConcatAlignFilterPrimitive(InferenceEngine::CNNLayerPtr l
return;
}

auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);

void* ptr_inputs = nullptr;
void* ptr_outputs = nullptr;
void* ptr_weights = nullptr;
@@ -1632,7 +1617,7 @@ void GNAGraphCompiler::ConcatAlignFilterPrimitive(InferenceEngine::CNNLayerPtr l
num_columns_in,
inputs->getPrecision().size(),
inputs->getPrecision().size(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::output),
num_rows_copied,
num_columns_in,
ptr_inputs,
@@ -1669,8 +1654,8 @@ void GNAGraphCompiler::ConcatAlignFilterPrimitive(InferenceEngine::CNNLayerPtr l
outputs->getPrecision().size(),
filterLayer->_weights->getTensorDesc().getPrecision().size(),
biasPrecisionSize,
quantized == nullptr ? 1 : quantized->_weights_quant.GetScale(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::weights),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs,
ptr_weights,
@@ -1726,8 +1711,6 @@ void GNAGraphCompiler::ConvolutionFilterPrimitive(InferenceEngine::CNNLayerPtr l
return;
}

auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);

auto prevLayer = CNNNetPrevLayer(layer.get(), 0);
if (!LayerInfo(prevLayer).isSplit() && !LayerInfo(prevLayer).isSlice()) {
THROW_GNA_EXCEPTION << "Case with Affine Aligning Filter for not Split/Slice layers is not implemented yet!";
@@ -1774,8 +1757,8 @@ void GNAGraphCompiler::ConvolutionFilterPrimitive(InferenceEngine::CNNLayerPtr l
numberOfFilters,
filterWidth,
convolutionStride,
quantized == nullptr ? 1 : quantized->_weights_quant.GetScale(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::weights),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs,
ptr_weights,
@@ -1834,9 +1817,8 @@ void GNAGraphCompiler::PWLPrimitive(InferenceEngine::CNNLayerPtr layer) {

auto inputs = layer->insData.begin()->lock();
auto outputs = *layer->outData.begin();
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
float output_pwl_scale_factor = quantized != nullptr ? quantized->_dst_quant.GetScale() : 1.0f;
float input_pwl_scale_factor = quantized != nullptr ? quantized->_src_quant.GetScale() : 1.0f;
float output_pwl_scale_factor = getScaleFactor(layer, QuantizedDataType::output);
float input_pwl_scale_factor = getScaleFactor(layer, QuantizedDataType::input);

auto orientation = kDnnInterleavedOrientation;

@@ -1903,6 +1885,7 @@
}
auto activation_type = DnnActivation::fromType(it->second);
activation_type.fqParams.set = false;
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
if (quantized != nullptr && quantized->_dst_quant.IsStatsSet()) {
activation_type.fqParams.set = true;
activation_type.fqParams.levels = quantized->_dst_quant.GetLevels();
@@ -2044,7 +2027,6 @@ void GNAGraphCompiler::PermutePrimitive(InferenceEngine::CNNLayerPtr layer) {
return;
}
auto layerOrder = layer->GetParamAsInts("order");
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
if (layer->insData.empty()) {
THROW_GNA_LAYER_EXCEPTION(layer) << "Input layer pointer is unexpectedly absent";
}
@@ -2088,7 +2070,7 @@
squeezedInputOrder[1],
inputs->getPrecision().size(),
outputs->getPrecision().size(),
(quantized == nullptr) ? 1.0f : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs);
}
@@ -2103,7 +2085,7 @@
squeezedInputOrder[1],
inputs->getPrecision().size(),
outputs->getPrecision().size(),
quantized == nullptr ? 1 : quantized->_dst_quant.GetScale(),
getScaleFactor(layer, QuantizedDataType::output),
ptr_inputs,
ptr_outputs);
}
@@ -2595,4 +2577,4 @@ GNAGraphCompiler::transposeMatrix(uint8_t* ptr_matrix, size_t element_size, uint
}
}
return temp_buffer;
}
}
48 changes: 48 additions & 0 deletions inference-engine/tests/unit/gna/gna_get_scale_factor.cpp
@@ -0,0 +1,48 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include <gtest/gtest.h>
// to suppress deprecated definition errors
#define IMPLEMENT_INFERENCE_ENGINE_PLUGIN
#include "legacy/layer_transform.hpp"
#include "frontend/layer_quantizer.hpp"

namespace {

class GnaGetScaleFactorTest : public ::testing::Test {
protected:
void GetScaleFactorAndCheck(float src_scale, float dst_scale, float weights_scale, float bias_scale) const {
InferenceEngine::LayerParams params("fc", "FullyConnected", InferenceEngine::Precision::FP32);
InferenceEngine::CNNLayerPtr layer = std::make_shared<InferenceEngine::CNNLayer>(params);
layer = InferenceEngine::injectData<GNAPluginNS::QuantizedLayerParams>(*layer);
auto quant = InferenceEngine::getInjectedData<GNAPluginNS::QuantizedLayerParams>(*layer);
quant->_src_quant.SetScale(src_scale);
quant->_dst_quant.SetScale(dst_scale);
quant->_weights_quant.SetScale(weights_scale);
quant->_bias_quant.SetScale(bias_scale);
ASSERT_EQ(GNAPluginNS::getScaleFactor(layer, GNAPluginNS::QuantizedDataType::input), src_scale);
ASSERT_EQ(GNAPluginNS::getScaleFactor(layer, GNAPluginNS::QuantizedDataType::output), dst_scale);
ASSERT_EQ(GNAPluginNS::getScaleFactor(layer, GNAPluginNS::QuantizedDataType::weights), weights_scale);
ASSERT_EQ(GNAPluginNS::getScaleFactor(layer, GNAPluginNS::QuantizedDataType::bias), bias_scale);
}
};

TEST_F(GnaGetScaleFactorTest, validSF) {
EXPECT_NO_THROW(GetScaleFactorAndCheck(100, 200, 300, 400));
}

TEST_F(GnaGetScaleFactorTest, invalidSF) {
EXPECT_ANY_THROW(GetScaleFactorAndCheck(0, 200, 300, 400));
EXPECT_ANY_THROW(GetScaleFactorAndCheck(100, 0, 300, 400));
EXPECT_ANY_THROW(GetScaleFactorAndCheck(100, 200, 0, 400));
EXPECT_ANY_THROW(GetScaleFactorAndCheck(100, 200, 300, 0));
EXPECT_ANY_THROW(GetScaleFactorAndCheck(-100, 200, 300, 400));
EXPECT_ANY_THROW(GetScaleFactorAndCheck(100, -200, 300, 400));
EXPECT_ANY_THROW(GetScaleFactorAndCheck(100, 200, -300, 400));
EXPECT_ANY_THROW(GetScaleFactorAndCheck(100, 200, 300, -400));
}
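
// Illustrative extra case, not part of the original commit: getScaleFactor()
// compares |scale_factor| against a 1e-5 epsilon, so a tiny positive value
// such as 1e-6 is treated as zero and rejected as well.
TEST_F(GnaGetScaleFactorTest, nearZeroSF) {
    EXPECT_ANY_THROW(GetScaleFactorAndCheck(1e-6f, 200, 300, 400));
}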

} // namespace
