Skip to content

Commit

Permalink
[LPT] NNCF GroupConvolution 5D on weights support
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 20, 2023
1 parent 4ffecce commit 964118e
Show file tree
Hide file tree
Showing 10 changed files with 1,173 additions and 602 deletions.
16 changes: 14 additions & 2 deletions src/common/low_precision_transformations/src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,15 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
Shape newScaleShape = newScalePShape.to_shape();

if (!newScaleShape.empty()) {
// that's all we need: [C, 1, 1, 1] => [C, 1, 1]
newScaleShape.pop_back();
const auto input_shape = convolution->get_input_partial_shape(0);
const auto diff = newScaleShape.size() - input_shape.size();
OPENVINO_ASSERT(
newScaleShape.empty() || ((0 <= diff) && (diff <= 2ull)),
"unexpected shape size on weights");

for (size_t i = 0; i <= diff; ++i) {
newScaleShape.pop_back();
}
}

if (reshapeFromWeights != nullptr) {
Expand Down Expand Up @@ -282,7 +289,12 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph

const size_t weightsRankValue = weightsPShape.rank().get_length();
Shape zeroPointShape(weightsRankValue, 1ul);
// output channel or group
zeroPointShape[0] = static_cast<size_t>(weightsPShape[0].get_length());
if ((reshapeFromWeights == nullptr) && (weightsRankValue == 5ull)) {
// output channel
zeroPointShape[1] = static_cast<size_t>(weightsPShape[1].get_length());
}

auto zeroPointConstant = fold<opset1::Broadcast>(
subtractFromWeights->input_value(1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,6 @@ ngraph::pass::low_precision::PullReshapeThroughDequantization::PullReshapeThroug
const auto& opsMap = m.get_pattern_value_map();
auto reshape = opsMap.at(reshapeWrapper).get_node_shared_ptr();

auto child = reshape->get_output_target_inputs(0).begin()->get_node();
if (ov::is_type<opset1::GroupConvolution>(child)) {
return false;
}

while (reshape != nullptr) {
const auto parent = reshape->get_input_node_shared_ptr(0);
if (ov::is_type<opset1::Multiply>(parent) || ov::is_type<opset1::Subtract>(parent)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,16 @@ bool WeightableLayerTransformation::isQuantizedStatic(const std::shared_ptr<cons
FakeQuantizeDequantization dequantizationOnWeights;
if (reshapeIsRequired) {
const auto reshape = layer->get_input_node_shared_ptr(1);
if (!ov::is_type<opset1::Reshape>(reshape)) {
return false;
}
std::shared_ptr<Node> parent = ov::is_type<opset1::Reshape>(reshape) ?
reshape->get_input_node_shared_ptr(0) :
reshape;

if (ov::is_type<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0))) {
const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(reshape->get_input_node_shared_ptr(0));
const auto fq = ov::as_type_ptr<opset1::FakeQuantize>(parent);
if (fq != nullptr) {
return NetworkHelper::isQuantizeSupported(fq);
}

dequantizationOnWeights = NetworkHelper::getDequantization(reshape, defaultPrecisions, 0);
dequantizationOnWeights = NetworkHelper::getDequantization(parent, defaultPrecisions, 0, true);
} else if (ov::is_type<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1))) {
const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
return NetworkHelper::isQuantizeSupported(fq);
Expand Down
Loading

0 comments on commit 964118e

Please sign in to comment.