Skip to content

Commit

Permalink
[LPT] Minor fixes; review from openvinotoolkit#5313
Browse files Browse the repository at this point in the history
  • Loading branch information
vzinovie committed Apr 21, 2021
1 parent b6fbe88 commit 2edd806
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace low_precision {
class TRANSFORMATIONS_API ConvolutionTransformation : public WeightableLayerTransformation {
public:
ConvolutionTransformation(const Params& params);
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) const override;
bool isQuantized(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TRANSFORMATIONS_API WeightableLayerTransformation : public LayerTransforma
public:
WeightableLayerTransformation(const Params& params);
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool canConvolutionBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const;
bool isQuantized(std::shared_ptr<Node> layer, bool reshapeIsRequired) const noexcept;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ bool ConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) const n
bool ConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
auto convolution = m.get_match_root();

if (!canBeTransformed(context, convolution)) {
if (!canConvolutionBeTransformed(context, convolution)) {
return false;
}

Expand Down Expand Up @@ -281,63 +281,6 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
}
return true;
}

bool ConvolutionTransformation::canBeTransformed(const TransformationContext &context,
std::shared_ptr<Node> layer) const {
if (!WeightableLayerTransformation::canBeTransformed(context, layer)) {
return false;
}

FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(layer);
if (!canSubtractBeHandled(layer, dequantization)) {
return false;
}

if (updatePrecisions && !NetworkHelper::checkZeroPoint(dequantization.subtract)) {
return false;
}

if (updatePrecisions && !dequantization.empty() && !dequantization.isLowPrecision()) {
return false;
}

std::shared_ptr<opset1::Reshape> reshapeFromWeights = as_type_ptr<opset1::Reshape>(layer->get_input_node_shared_ptr(1));
dequantization = reshapeFromWeights == nullptr ?
NetworkHelper::getDequantization(layer, 1ul) :
NetworkHelper::getDequantization(reshapeFromWeights);

const auto fqOnWeights = getFakeQuantizeOnWeights(layer);
if (dequantization.empty()) {
const auto dataPrecision = getDataPrecisionOnWeights(layer);
if ((!supportAsymmetricQuantization) && dataPrecision.hasZeroPoint) {
return false;
}
if (updatePrecisions && !NetworkHelper::checkZeroPoint(fqOnWeights, dataPrecision)) {
const std::shared_ptr<ngraph::Node> resultConstant = NetworkHelper::fold_fake_quantize(fqOnWeights);
if (as_type_ptr<opset1::Constant>(resultConstant)) {
replace_node(fqOnWeights, resultConstant);
}
return false;
}
} else {
if (updatePrecisions && !NetworkHelper::checkZeroPoint(dequantization.subtract)) {
const auto resultDequantization = NetworkHelper::foldDequantization(dequantization.multiply, 0, true);
if (resultDequantization.empty() && reshapeFromWeights) {
const auto foldedReshape = fold<opset1::Reshape>(
reshapeFromWeights->get_input_node_shared_ptr(0),
reshapeFromWeights->get_input_node_shared_ptr(1),
reshapeFromWeights->get_special_zero());
if (is_type<opset1::Constant>(foldedReshape)) {
replace_node(reshapeFromWeights, foldedReshape);
}
}
return false;
}
}

return true;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -1544,7 +1544,13 @@ bool NetworkHelper::checkZeroPoint(const std::shared_ptr<Node>& node, const Data
}
auto subtractConst = as_type_ptr<opset1::Constant>(node->get_input_node_shared_ptr(1));
if (!subtractConst) {
if (is_type<opset1::Convert>(subtractConst)) {
return false;
}
subtractConst = as_type_ptr<opset1::Constant>(node->get_input_node_shared_ptr(1)->get_input_node_shared_ptr(0));
if (subtractConst == nullptr) {
return false;
}
}
const auto subtractValues = subtractConst->cast_vector<float>();
if (std::any_of(subtractValues.begin(), subtractValues.end(), [min, max] (const float& val) {
Expand All @@ -1559,9 +1565,14 @@ bool NetworkHelper::checkZeroPoint(const std::shared_ptr<Node>& node, const Data
max = dataPrecision.max + 0.5f;
const auto quantizationDetails = QuantizationDetails::getDetails(as_type_ptr<opset1::FakeQuantize>(node));
for (size_t i = 0; i < quantizationDetails.outputIntervalsCount; ++i) {
const float shift =
(dataPrecision.min * quantizationDetails.outputHighValues[i] - dataPrecision.max * quantizationDetails.outputLowValues[i]) /
(quantizationDetails.outputHighValues[i] - quantizationDetails.outputLowValues[i]);
float shift;
if (quantizationDetails.outputHighValues[i] != quantizationDetails.outputLowValues[i]) {
shift = (dataPrecision.min * quantizationDetails.outputHighValues[i] -
dataPrecision.max * quantizationDetails.outputLowValues[i]) /
(quantizationDetails.outputHighValues[i] - quantizationDetails.outputLowValues[i]);
} else {
shift = 0.f;
}
if (shift < min || shift > max) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,61 @@ namespace low_precision {

WeightableLayerTransformation::WeightableLayerTransformation(const Params& params) : LayerTransformation(params) {}

bool WeightableLayerTransformation::canConvolutionBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
if (!WeightableLayerTransformation::canBeTransformed(context, layer)) {
return false;
}

FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(layer);
if (!canSubtractBeHandled(layer, dequantization)) {
return false;
}

if (updatePrecisions && !NetworkHelper::checkZeroPoint(dequantization.subtract)) {
return false;
}

if (updatePrecisions && !dequantization.empty() && !dequantization.isLowPrecision()) {
return false;
}

std::shared_ptr<opset1::Reshape> reshapeFromWeights = as_type_ptr<opset1::Reshape>(layer->get_input_node_shared_ptr(1));
dequantization = reshapeFromWeights == nullptr ?
NetworkHelper::getDequantization(layer, 1ul) :
NetworkHelper::getDequantization(reshapeFromWeights);

if (dequantization.empty()) {
const auto fqOnWeights = getFakeQuantizeOnWeights(layer);
const auto dataPrecision = getDataPrecisionOnWeights(layer);
if ((!supportAsymmetricQuantization) && dataPrecision.hasZeroPoint) {
return false;
}
if (!NetworkHelper::checkZeroPoint(fqOnWeights, dataPrecision)) {
const std::shared_ptr<ngraph::Node> resultConstant = NetworkHelper::fold_fake_quantize(fqOnWeights);
if (as_type_ptr<opset1::Constant>(resultConstant)) {
replace_node(fqOnWeights, resultConstant);
}
return false;
}
} else {
if (!NetworkHelper::checkZeroPoint(dequantization.subtract)) {
const auto resultDequantization = NetworkHelper::foldDequantization(dequantization.multiply, 0, true);
if (resultDequantization.empty() && reshapeFromWeights) {
const auto foldedReshape = fold<opset1::Reshape>(
reshapeFromWeights->get_input_node_shared_ptr(0),
reshapeFromWeights->get_input_node_shared_ptr(1),
reshapeFromWeights->get_special_zero());
if (is_type<opset1::Constant>(foldedReshape)) {
replace_node(reshapeFromWeights, foldedReshape);
}
}
return false;
}
}

return true;
}

bool WeightableLayerTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
if (!LayerTransformation::canBeTransformed(context, layer)) {
return false;
Expand Down

0 comments on commit 2edd806

Please sign in to comment.