Skip to content

Commit

Permalink
[LPT] Gemm and FullyConnected 3D improvement #2
Browse files: browse the repository at this point in the history
  • Loading branch information
eshoguli committed Jun 7, 2020
1 parent a738f64 commit cca8b9d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 15 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,16 +7,17 @@
#include <vector>
#include <ie_common.h>
#include <algorithm>
#include "low_precision_transformations/weightable_layer_transformation.hpp"
#include "low_precision_transformations/layer_transformation.hpp"
#include "low_precision_transformations/fully_connected.hpp"

namespace InferenceEngine {
namespace details {

IE_SUPPRESS_DEPRECATED_START

class INFERENCE_ENGINE_API_CLASS(GemmTransformation) : public LayerTransformation {
class INFERENCE_ENGINE_API_CLASS(GemmTransformation) : public FullyConnectedTransformation {
public:
GemmTransformation(const Params& params) : LayerTransformation(params) {}
GemmTransformation(const LayerTransformation::Params& params) : FullyConnectedTransformation(params) {}
~GemmTransformation() override {};
bool canBeTransformed(const TransformationContext& context, const CNNLayer& layer) const override;
void transform(TransformationContext& context, CNNLayer& layer) const override;
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -94,7 +94,7 @@ void FullyConnectedTransformation::transform(TransformationContext& context, CNN
return;
}

if (!CaselessEq<std::string>()(fullyConnected.type, "FullyConnected")) {
if ((!CaselessEq<std::string>()(fullyConnected.type, "FullyConnected")) && (!CaselessEq<std::string>()(fullyConnected.type, "Gemm"))) {
THROW_IE_EXCEPTION << "layer '" << fullyConnected.name << "' is not correct";
}

Expand Down Expand Up @@ -244,6 +244,7 @@ void FullyConnectedTransformation::transform(TransformationContext& context, CNN
updatePrecisions
? CNNNetworkHelper::quantizeWeights(*parentOnWeights, roundQuantizedValues, dataPrecision.precision)
: CNNNetworkHelper::quantizeWeights(*parentOnWeights, roundQuantizedValues);

const std::vector<CNNLayerPtr> constLayers = CNNNetworkHelper::transformFakeQuantizeToConst(
context, parentOnWeights, weights, CNNNetworkHelper::getParent(*parentOnWeights, 0)->name);

Expand Down
40 changes: 29 additions & 11 deletions inference-engine/src/low_precision_transformations/src/gemm.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,9 +40,9 @@ bool GemmTransformation::canBeTransformed(const TransformationContext& context,
}

const size_t inputChannelsCount = CNNNetworkHelper::getInputChannelsCount(gemm);
std::vector<CNNLayerPtr> parents = CNNNetworkHelper::getParents(gemm);

const auto checkDequantizationLayer = [&](const CNNLayer& gemm, const size_t index) -> bool {
std::vector<CNNLayerPtr> parents = CNNNetworkHelper::getParents(gemm);
if (parents.size() <= index) {
return false;
}
Expand All @@ -53,7 +53,7 @@ bool GemmTransformation::canBeTransformed(const TransformationContext& context,

std::vector<float> scales;
std::vector<float> shifts;
LayerTransformation::fillFromDequantizationLayer(*scaleShift, scales, shifts);
fillFromDequantizationLayer(*scaleShift, scales, shifts);

if (scales.size() != inputChannelsCount) {
return false;
Expand All @@ -73,8 +73,24 @@ bool GemmTransformation::canBeTransformed(const TransformationContext& context,
};

if ((CNNNetworkHelper::getParents(gemm).size() != 2ul) ||
(!checkDequantizationLayer(gemm, 0ul)) ||
(!checkDequantizationLayer(gemm, 1ul))) {
(!checkDequantizationLayer(gemm, 0ul))) {
return false;
}

if (parents[1]->type == "FakeQuantize") {
if (!QuantizationDetails::isSupportedLevel(parents[1]->GetParamAsUInt("levels"))) {
return false;
}

const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(*parents[1]);
const DataPrecision dataPrecision = getDataPrecision(*parents[1], quantizationDetails, false, false);
if (dataPrecision.precision == Precision::UNSPECIFIED) {
return false;
}
}

if (((parents[1]->type != "FakeQuantize") && (!checkDequantizationLayer(gemm, 1ul))) ||
((parents[1]->type == "FakeQuantize") && (!CNNNetworkHelper::onConstWeightsPath(*parents[1]) || !CNNNetworkHelper::onWeights(*parents[1])))) {
return false;
}

Expand All @@ -91,27 +107,29 @@ void GemmTransformation::transform(TransformationContext& context, CNNLayer& gem
}

std::vector<CNNLayerPtr> parents = CNNNetworkHelper::getParents(gemm);
if (parents[1]->type == "FakeQuantize") {
FullyConnectedTransformation::transform(context, gemm);
return;
}

std::vector<float> originalDataDequantizationScales1;
std::vector<float> originalDataDequantizationShifts1;
fillFromDequantizationLayer(*parents[0], originalDataDequantizationScales1, originalDataDequantizationShifts1);

std::vector<float> originalDataDequantizationScales2;
std::vector<float> originalDataDequantizationShifts2;
fillFromDequantizationLayer(*parents[1], originalDataDequantizationScales2, originalDataDequantizationShifts2);

const size_t outputChannelsCount = CNNNetworkHelper::getOutputChannelsCount(gemm);
std::vector<float> dequantizationScales(outputChannelsCount);
std::vector<float> dequantizationScales(outputChannelsCount, originalDataDequantizationScales1[0] * originalDataDequantizationScales2[0]);
std::vector<float> dequantizationShifts(outputChannelsCount, 0.f);
for (size_t outputChannel = 0ul; outputChannel < outputChannelsCount; ++outputChannel) {
dequantizationScales[outputChannel] = originalDataDequantizationScales1[0] * originalDataDequantizationScales2[0];
}

CNNNetworkHelper::removeLayer(context.network, parents[0]);
context.removeLayer(*parents[0]);

CNNNetworkHelper::removeLayer(context.network, parents[1]);
context.removeLayer(*parents[1]);
if (parents[1]->type != "FakeQuantize") {
CNNNetworkHelper::removeLayer(context.network, parents[1]);
context.removeLayer(*parents[1]);
}

addDequantizationLayer(context, gemm, dequantizationScales, dequantizationShifts);
}

0 comments on commit cca8b9d

Please sign in to comment.