Skip to content

Commit

Permalink
[LPT] NNCF: support 5D weights in GroupConvolution
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 20, 2023
1 parent 497b788 commit 3f7660d
Show file tree
Hide file tree
Showing 7 changed files with 1,023 additions and 498 deletions.
18 changes: 16 additions & 2 deletions src/common/low_precision_transformations/src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,19 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
Shape newScaleShape = newScalePShape.to_shape();

if (!newScaleShape.empty()) {
// that's all we need: [C, 1, 1, 1] => [C, 1, 1]
newScaleShape.pop_back();
const auto input_shape = convolution->get_input_partial_shape(0);
if (input_shape.size() != newScaleShape.size()) {
newScaleShape.pop_back();
}

OPENVINO_ASSERT(
newScaleShape.empty() || (input_shape.size() == newScaleShape.size()),
"unexpected shape size on weights");

if (!newScaleShape.empty()) {
// that's all we need: [C, 1, 1, 1] => [C, 1, 1]
newScaleShape.pop_back();
}
}

if (reshapeFromWeights != nullptr) {
Expand Down Expand Up @@ -283,6 +294,9 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
const size_t weightsRankValue = weightsPShape.rank().get_length();
Shape zeroPointShape(weightsRankValue, 1ul);
zeroPointShape[0] = static_cast<size_t>(weightsPShape[0].get_length());
if ((reshapeFromWeights == nullptr) && (weightsRankValue == 5ull)) {
zeroPointShape[1] = static_cast<size_t>(weightsPShape[1].get_length());
}

auto zeroPointConstant = fold<opset1::Broadcast>(
subtractFromWeights->input_value(1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,16 @@ bool WeightableLayerTransformation::isQuantizedStatic(const std::shared_ptr<cons
FakeQuantizeDequantization dequantizationOnWeights;
if (reshapeIsRequired) {
const auto reshape = layer->get_input_node_shared_ptr(1);
if (!ov::is_type<opset1::Reshape>(reshape)) {
return false;
}
std::shared_ptr<Node> parent = ov::is_type<opset1::Reshape>(reshape) ?
reshape->get_input_node_shared_ptr(0) :
reshape;

if (ov::is_type<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0))) {
const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0));
const auto fq = ov::as_type_ptr<opset1::FakeQuantize>(parent);
if (fq != nullptr) {
return NetworkHelper::isQuantizeSupported(fq);
}

dequantizationOnWeights = NetworkHelper::getDequantization(reshape, defaultPrecisions, 0);
dequantizationOnWeights = NetworkHelper::getDequantization(parent, defaultPrecisions, 0, true);
} else if (ov::is_type<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1))) {
const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
return NetworkHelper::isQuantizeSupported(fq);
Expand Down
Loading

0 comments on commit 3f7660d

Please sign in to comment.