diff --git a/inference-engine/src/low_precision_transformations/include/low_precision_transformations/gemm.hpp b/inference-engine/src/low_precision_transformations/include/low_precision_transformations/gemm.hpp
index ba915b89e7aa27..ac5bd226f5cfad 100644
--- a/inference-engine/src/low_precision_transformations/include/low_precision_transformations/gemm.hpp
+++ b/inference-engine/src/low_precision_transformations/include/low_precision_transformations/gemm.hpp
@@ -7,16 +7,17 @@
 #include
 #include
 #include
-#include "low_precision_transformations/weightable_layer_transformation.hpp"
+#include "low_precision_transformations/layer_transformation.hpp"
+#include "low_precision_transformations/fully_connected.hpp"

 namespace InferenceEngine {
 namespace details {

 IE_SUPPRESS_DEPRECATED_START

-class INFERENCE_ENGINE_API_CLASS(GemmTransformation) : public LayerTransformation {
+class INFERENCE_ENGINE_API_CLASS(GemmTransformation) : public FullyConnectedTransformation {
 public:
-    GemmTransformation(const Params& params) : LayerTransformation(params) {}
+    GemmTransformation(const LayerTransformation::Params& params) : FullyConnectedTransformation(params) {}
     ~GemmTransformation() override {};
     bool canBeTransformed(const TransformationContext& context, const CNNLayer& layer) const override;
     void transform(TransformationContext& context, CNNLayer& layer) const override;
diff --git a/inference-engine/src/low_precision_transformations/src/fully_connected.cpp b/inference-engine/src/low_precision_transformations/src/fully_connected.cpp
index 4ca8da778e88d4..f1fb2ae14520f0 100644
--- a/inference-engine/src/low_precision_transformations/src/fully_connected.cpp
+++ b/inference-engine/src/low_precision_transformations/src/fully_connected.cpp
@@ -94,7 +94,7 @@ void FullyConnectedTransformation::transform(TransformationContext& context, CNN
         return;
     }

-    if (!CaselessEq<std::string>()(fullyConnected.type, "FullyConnected")) {
+    if ((!CaselessEq<std::string>()(fullyConnected.type, "FullyConnected")) && (!CaselessEq<std::string>()(fullyConnected.type, "Gemm"))) {
         THROW_IE_EXCEPTION << "layer '" << fullyConnected.name << "' is not correct";
     }

@@ -244,6 +244,7 @@ void FullyConnectedTransformation::transform(TransformationContext& context, CNN
     updatePrecisions ?
         CNNNetworkHelper::quantizeWeights(*parentOnWeights, roundQuantizedValues, dataPrecision.precision) :
         CNNNetworkHelper::quantizeWeights(*parentOnWeights, roundQuantizedValues);
+
     const std::vector<CNNLayerPtr> constLayers = CNNNetworkHelper::transformFakeQuantizeToConst(
         context, parentOnWeights, weights, CNNNetworkHelper::getParent(*parentOnWeights, 0)->name);
diff --git a/inference-engine/src/low_precision_transformations/src/gemm.cpp b/inference-engine/src/low_precision_transformations/src/gemm.cpp
index d60729e441cb44..a24c00139e47e2 100644
--- a/inference-engine/src/low_precision_transformations/src/gemm.cpp
+++ b/inference-engine/src/low_precision_transformations/src/gemm.cpp
@@ -40,9 +40,9 @@ bool GemmTransformation::canBeTransformed(const TransformationContext& context,
     }

     const size_t inputChannelsCount = CNNNetworkHelper::getInputChannelsCount(gemm);
+    std::vector<CNNLayerPtr> parents = CNNNetworkHelper::getParents(gemm);

     const auto checkDequantizationLayer = [&](const CNNLayer& gemm, const size_t index) -> bool {
-        std::vector<CNNLayerPtr> parents = CNNNetworkHelper::getParents(gemm);
         if (parents.size() <= index) {
             return false;
         }
@@ -53,7 +53,7 @@ bool GemmTransformation::canBeTransformed(const TransformationContext& context,

         std::vector<float> scales;
         std::vector<float> shifts;
-        LayerTransformation::fillFromDequantizationLayer(*scaleShift, scales, shifts);
+        fillFromDequantizationLayer(*scaleShift, scales, shifts);

         if (scales.size() != inputChannelsCount) {
             return false;
@@ -73,8 +73,24 @@ bool GemmTransformation::canBeTransformed(const TransformationContext& context,
     };

     if ((CNNNetworkHelper::getParents(gemm).size() != 2ul) ||
-        (!checkDequantizationLayer(gemm, 0ul)) ||
-        (!checkDequantizationLayer(gemm, 1ul))) {
+        (!checkDequantizationLayer(gemm, 0ul))) {
+        return false;
+    }
+
+    if (parents[1]->type == "FakeQuantize") {
+        if (!QuantizationDetails::isSupportedLevel(parents[1]->GetParamAsUInt("levels"))) {
+            return false;
+        }
+
+        const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(*parents[1]);
+        const DataPrecision dataPrecision = getDataPrecision(*parents[1], quantizationDetails, false, false);
+        if (dataPrecision.precision == Precision::UNSPECIFIED) {
+            return false;
+        }
+    }
+
+    if (((parents[1]->type != "FakeQuantize") && (!checkDequantizationLayer(gemm, 1ul))) ||
+        ((parents[1]->type == "FakeQuantize") && (!CNNNetworkHelper::onConstWeightsPath(*parents[1]) || !CNNNetworkHelper::onWeights(*parents[1])))) {
         return false;
     }

@@ -91,27 +107,29 @@ void GemmTransformation::transform(TransformationContext& context, CNNLayer& gem
     }

     std::vector<CNNLayerPtr> parents = CNNNetworkHelper::getParents(gemm);
+    if (parents[1]->type == "FakeQuantize") {
+        FullyConnectedTransformation::transform(context, gemm);
+        return;
+    }

     std::vector<float> originalDataDequantizationScales1;
     std::vector<float> originalDataDequantizationShifts1;
     fillFromDequantizationLayer(*parents[0], originalDataDequantizationScales1, originalDataDequantizationShifts1);
-
     std::vector<float> originalDataDequantizationScales2;
     std::vector<float> originalDataDequantizationShifts2;
     fillFromDequantizationLayer(*parents[1], originalDataDequantizationScales2, originalDataDequantizationShifts2);

     const size_t outputChannelsCount = CNNNetworkHelper::getOutputChannelsCount(gemm);
-    std::vector<float> dequantizationScales(outputChannelsCount);
+    std::vector<float> dequantizationScales(outputChannelsCount, originalDataDequantizationScales1[0] * originalDataDequantizationScales2[0]);
     std::vector<float> dequantizationShifts(outputChannelsCount, 0.f);
-    for (size_t outputChannel = 0ul; outputChannel < outputChannelsCount; ++outputChannel) {
-        dequantizationScales[outputChannel] = originalDataDequantizationScales1[0] * originalDataDequantizationScales2[0];
-    }

     CNNNetworkHelper::removeLayer(context.network, parents[0]);
     context.removeLayer(*parents[0]);
-    CNNNetworkHelper::removeLayer(context.network, parents[1]);
-    context.removeLayer(*parents[1]);
+    if (parents[1]->type != "FakeQuantize") {
+        CNNNetworkHelper::removeLayer(context.network, parents[1]);
+        context.removeLayer(*parents[1]);
+    }

     addDequantizationLayer(context, gemm, dequantizationScales, dequantizationShifts);
 }