[LPT] NNCF GroupConvolution 5D on weights support (#16336)
* [LPT] NNCF GroupConvolution 5D on weights support

* PullReshapeThroughDequantization rollback
eshoguli authored Mar 23, 2023
1 parent 8a246a8 commit fb24e91
Showing 9 changed files with 1,170 additions and 594 deletions.
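For orientation before the diffs: in LPT, the dequantization on weights is a Multiply by a per-channel scale constant plus an optional Subtract by a zero-point constant. GroupConvolution stores its weights as a 5D tensor [G, C_out/G, C_in/G, kH, kW], so those constants can carry one more dimension than in the plain 4D Convolution case. A minimal sketch of the shapes involved (standalone C++, not the OpenVINO API; the concrete dimension values are illustrative assumptions):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
        // GroupConvolution weights: [G, C_out/G, C_in/G, kH, kW] -- rank 5.
        const std::vector<size_t> weights = {4, 16, 8, 3, 3};  // 4 groups, 64 output channels in total

        // A per-output-channel dequantization constant on those weights keeps the
        // group dim and the per-group output-channel dim, broadcasting over the rest.
        const std::vector<size_t> scale = {weights[0], weights[1], 1, 1, 1};  // [G, C_out/G, 1, 1, 1]

        for (size_t d : scale) std::cout << d << ' ';  // prints: 4 16 1 1 1
        std::cout << '\n';
    }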
16 changes: 14 additions & 2 deletions src/common/low_precision_transformations/src/convolution.cpp
@@ -237,8 +237,15 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
 Shape newScaleShape = newScalePShape.to_shape();

 if (!newScaleShape.empty()) {
-    // that's all we need: [C, 1, 1, 1] => [C, 1, 1]
-    newScaleShape.pop_back();
+    const auto input_shape = convolution->get_input_partial_shape(0);
+    const auto diff = newScaleShape.size() - input_shape.size();
+    OPENVINO_ASSERT(
+        newScaleShape.empty() || ((0 <= diff) && (diff <= 2ull)),
+        "unexpected shape size on weights");
+
+    for (size_t i = 0; i <= diff; ++i) {
+        newScaleShape.pop_back();
+    }
 }

 if (reshapeFromWeights != nullptr) {
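The rewritten block generalizes the old single pop_back: the scale constant on weights may now be up to two ranks larger than the activation input (for example the extra group dimension of 5D GroupConvolution weights), and diff + 1 trailing singleton dimensions are dropped so the result lands at rank input_rank - 1, exactly as before for 4D Convolution. A standalone sketch of that arithmetic (plain C++, not the OpenVINO code; the shape values are illustrative):

    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Mirrors the trimming above: drop diff + 1 trailing dims so the result has
    // rank input_rank - 1, whatever extra dims the weights constant carried.
    std::vector<size_t> trim_scale_shape(std::vector<size_t> scale, size_t input_rank) {
        if (scale.empty()) return scale;
        const size_t diff = scale.size() - input_rank;  // 0 for Convolution, 1 for 5D GroupConvolution weights
        assert(diff <= 2);  // same bound as the OPENVINO_ASSERT above
        for (size_t i = 0; i <= diff; ++i) scale.pop_back();
        return scale;
    }

    int main() {
        // 4D Convolution: [C, 1, 1, 1] against a 4D input -> [C, 1, 1]
        for (size_t d : trim_scale_shape({64, 1, 1, 1}, 4)) std::cout << d << ' ';
        std::cout << '\n';
        // 5D GroupConvolution weights: [G, C/G, 1, 1, 1] against a 4D input -> [G, C/G, 1]
        for (size_t d : trim_scale_shape({4, 16, 1, 1, 1}, 4)) std::cout << d << ' ';
        std::cout << '\n';
    }

Since diff is unsigned here, as in the committed code, the diff <= 2 check also trips when the constant has lower rank than the input and the subtraction wraps around.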
@@ -282,7 +289,12 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph

 const size_t weightsRankValue = weightsPShape.rank().get_length();
 Shape zeroPointShape(weightsRankValue, 1ul);
+// output channel or group
 zeroPointShape[0] = static_cast<size_t>(weightsPShape[0].get_length());
+if ((reshapeFromWeights == nullptr) && (weightsRankValue == 5ull)) {
+    // output channel
+    zeroPointShape[1] = static_cast<size_t>(weightsPShape[1].get_length());
+}

 auto zeroPointConstant = fold<opset1::Broadcast>(
     subtractFromWeights->input_value(1),
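The second hunk applies the same idea to the zero point: the broadcast target shape keeps dim 0 (output channel for 4D weights, group for 5D weights), and additionally dim 1 (per-group output channel) when 5D weights reach the convolution without a preceding reshape. A standalone sketch (plain C++, not the OpenVINO API; the shapes are illustrative assumptions):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Mirrors the zeroPointShape construction: all ones, then keep the
    // channel-carrying dims of the weights shape.
    std::vector<size_t> zero_point_shape(const std::vector<size_t>& weights, bool reshaped) {
        std::vector<size_t> zp(weights.size(), 1);
        zp[0] = weights[0];  // output channel, or group for 5D weights
        if (!reshaped && weights.size() == 5) {
            zp[1] = weights[1];  // per-group output channel
        }
        return zp;
    }

    int main() {
        // 4D Convolution weights [C_out, C_in, kH, kW] -> [C_out, 1, 1, 1]
        for (size_t d : zero_point_shape({64, 32, 3, 3}, false)) std::cout << d << ' ';
        std::cout << '\n';
        // 5D GroupConvolution weights [G, C_out/G, C_in/G, kH, kW] -> [G, C_out/G, 1, 1, 1]
        for (size_t d : zero_point_shape({4, 16, 8, 3, 3}, false)) std::cout << d << ' ';
        std::cout << '\n';
    }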
src/common/low_precision_transformations/src/weightable_layer_transformation.cpp
@@ -230,16 +230,16 @@ bool WeightableLayerTransformation::isQuantizedStatic(const std::shared_ptr<cons
 FakeQuantizeDequantization dequantizationOnWeights;
 if (reshapeIsRequired) {
     const auto reshape = layer->get_input_node_shared_ptr(1);
-    if (!ov::is_type<opset1::Reshape>(reshape)) {
-        return false;
-    }
+    std::shared_ptr<Node> parent = ov::is_type<opset1::Reshape>(reshape) ?
+        reshape->get_input_node_shared_ptr(0) :
+        reshape;

-    if (ov::is_type<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0))) {
-        const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0));
+    const auto fq = ov::as_type_ptr<opset1::FakeQuantize>(parent);
+    if (fq != nullptr) {
         return NetworkHelper::isQuantizeSupported(fq);
     }

-    dequantizationOnWeights = NetworkHelper::getDequantization(reshape, defaultPrecisions, 0);
+    dequantizationOnWeights = NetworkHelper::getDequantization(parent, defaultPrecisions, 0, true);
 } else if (ov::is_type<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1))) {
     const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
     return NetworkHelper::isQuantizeSupported(fq);
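This hunk is what actually admits 5D weights on this path: previously a Reshape on the weights input was mandatory and anything else returned false; now the Reshape is stepped over when present, and otherwise the weights input itself is inspected for a FakeQuantize or a dequantization, matching the PullReshapeThroughDequantization rollback noted in the commit message. A toy sketch of the lookup order (plain C++ with a stand-in node type; the names are illustrative, not the OpenVINO API):

    #include <memory>
    #include <string>

    // Toy stand-in for graph nodes, only for illustration.
    struct Node {
        std::string type;
        std::shared_ptr<Node> input;
    };

    // Mirrors the new logic: step over a Reshape when present, otherwise
    // inspect the weights input itself (e.g. direct 5D GroupConvolution weights).
    std::shared_ptr<Node> weights_parent(const std::shared_ptr<Node>& weights_input) {
        return weights_input->type == "Reshape" ? weights_input->input : weights_input;
    }

    int main() {
        auto fq = std::make_shared<Node>(Node{"FakeQuantize", nullptr});
        auto reshape = std::make_shared<Node>(Node{"Reshape", fq});

        // With a Reshape the FakeQuantize is found behind it; without one
        // (the 5D weights case) the input itself is the candidate.
        const bool ok = weights_parent(reshape)->type == "FakeQuantize" &&
                        weights_parent(fq)->type == "FakeQuantize";
        return ok ? 0 : 1;
    }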