[LPT] NNCF GroupConvolution 5D on weights support
eshoguli committed Mar 20, 2023
1 parent 497b788 commit 08573bf
Showing 7 changed files with 1,046 additions and 492 deletions.
25 changes: 23 additions & 2 deletions src/common/low_precision_transformations/src/convolution.cpp
@@ -62,6 +62,8 @@ size_t ConvolutionTransformation::getInputChannels(const std::shared_ptr<ngraph:
bool ConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) {
    auto convolution = m.get_match_root();

    if (!canConvolutionBeTransformed(context, convolution, defaultPrecisions)) {
        const auto weightInput = convolution->get_input_node_shared_ptr(1);
        const auto reshapeFromWeights = ov::as_type_ptr<opset1::Reshape>(weightInput);
@@ -236,9 +238,24 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
    assert(newScalePShape.is_static());
    Shape newScaleShape = newScalePShape.to_shape();

    if (!newScaleShape.empty()) {
        const auto input_shape = convolution->get_input_partial_shape(0);
        // 5D grouped weights carry one dimension more than the 4D input,
        // so drop one trailing dimension before the rank check below
        if (input_shape.size() != newScaleShape.size()) {
            newScaleShape.pop_back();
        }

        OPENVINO_ASSERT(
            newScaleShape.empty() || (input_shape.size() == newScaleShape.size()),
            "unexpected shape size on weights");

        if (!newScaleShape.empty()) {
            // that's all we need: [C, 1, 1, 1] => [C, 1, 1]
            newScaleShape.pop_back();
        }
    }

    if (reshapeFromWeights != nullptr) {
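For illustration, a minimal standalone sketch of the scale-shape alignment above, using std::vector<size_t> in place of ngraph::Shape and hypothetical channel counts; it mirrors the logic of the hunk but is not part of the commit:

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<size_t> alignScaleShape(std::vector<size_t> scaleShape, const size_t inputRank) {
    if (scaleShape.empty())
        return scaleShape;
    // grouped 5D weights have one dimension more than the 4D input,
    // e.g. [G, C, 1, 1, 1] vs. rank 4, so drop one trailing dimension first
    if (inputRank != scaleShape.size())
        scaleShape.pop_back();
    assert(scaleShape.empty() || inputRank == scaleShape.size());
    // that's all we need: [C, 1, 1, 1] => [C, 1, 1]
    if (!scaleShape.empty())
        scaleShape.pop_back();
    return scaleShape;
}

int main() {
    const auto plain = alignScaleShape({64, 1, 1, 1}, 4);       // -> [64, 1, 1]
    const auto grouped = alignScaleShape({2, 32, 1, 1, 1}, 4);  // -> [2, 32, 1]
    std::cout << plain.size() << " " << grouped.size() << std::endl;  // prints "3 3"
}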
@@ -283,6 +300,10 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
    const size_t weightsRankValue = weightsPShape.rank().get_length();
    Shape zeroPointShape(weightsRankValue, 1ul);
    zeroPointShape[0] = static_cast<size_t>(weightsPShape[0].get_length());
    // TODO: support for 5D (grouped) weights is not completed yet
    if ((reshapeFromWeights == nullptr) && (weightsRankValue == 5ull)) {
        // grouped 5D weights: keep the per-group output-channel dimension as well
        zeroPointShape[1] = static_cast<size_t>(weightsPShape[1].get_length());
    }

    auto zeroPointConstant = fold<opset1::Broadcast>(
        subtractFromWeights->input_value(1),
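The corresponding zero-point shape construction, again as a standalone sketch; makeZeroPointShape is a hypothetical helper that mirrors the zeroPointShape logic above, not commit code:

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<size_t> makeZeroPointShape(const std::vector<size_t>& weightsShape, const bool hasReshapeFromWeights) {
    std::vector<size_t> zeroPointShape(weightsShape.size(), 1u);
    zeroPointShape[0] = weightsShape[0];
    // 5D grouped weights keep the per-group output-channel dimension as well
    if (!hasReshapeFromWeights && weightsShape.size() == 5u)
        zeroPointShape[1] = weightsShape[1];
    return zeroPointShape;
}

int main() {
    const auto zp4d = makeZeroPointShape({64, 16, 3, 3}, false);     // -> [64, 1, 1, 1]
    const auto zp5d = makeZeroPointShape({2, 32, 16, 3, 3}, false);  // -> [2, 32, 1, 1, 1]
    std::cout << zp4d[0] << " " << zp5d[1] << std::endl;             // prints "64 32"
}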
@@ -224,11 +224,13 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
    return true;
}

// Note: only "MASTER_" (with a trailing underscore) is defined, "MASTER" is not,
// so the #else branch below is the code path that actually compiles.
#define MASTER_
bool WeightableLayerTransformation::isQuantizedStatic(const std::shared_ptr<const Node>& layer,
                                                      const bool reshapeIsRequired,
                                                      const std::vector<ngraph::element::Type>& defaultPrecisions) {
    FakeQuantizeDequantization dequantizationOnWeights;
    if (reshapeIsRequired) {
#ifdef MASTER
        const auto reshape = layer->get_input_node_shared_ptr(1);
        if (!ov::is_type<opset1::Reshape>(reshape)) {
            return false;
@@ -240,6 +242,20 @@ bool WeightableLayerTransformation::isQuantizedStatic(const std::shared_ptr<cons
        }

        dequantizationOnWeights = NetworkHelper::getDequantization(reshape, defaultPrecisions, 0);
#else
        const auto reshape = layer->get_input_node_shared_ptr(1);
        // the Reshape on weights is optional here: step over it only when present
        std::shared_ptr<Node> parent = ov::is_type<opset1::Reshape>(reshape) ?
            reshape->get_input_node_shared_ptr(0) :
            reshape;

        const auto fq = ov::as_type_ptr<opset1::FakeQuantize>(parent);
        if (fq != nullptr) {
            return NetworkHelper::isQuantizeSupported(fq);
        }

        dequantizationOnWeights = NetworkHelper::getDequantization(parent, defaultPrecisions, 0, true);
#endif
    } else if (ov::is_type<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1))) {
        const std::shared_ptr<opset1::FakeQuantize> fq = ov::as_type_ptr<opset1::FakeQuantize>(layer->get_input_node_shared_ptr(1));
        return NetworkHelper::isQuantizeSupported(fq);
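To show why the #else branch matters for this commit, here is a toy illustration (not OpenVINO API; Node, isQuantizedOldPath, and isQuantizedNewPath are hypothetical): the old path rejects weights without a Reshape, while the new one also accepts a FakeQuantize connected directly, which is the 5D GroupConvolution weights case.

#include <iostream>
#include <memory>
#include <string>

struct Node {
    std::string type;
    std::shared_ptr<Node> input;  // single data input, enough for this sketch
};

// old behavior: a Reshape on the weights input is mandatory
bool isQuantizedOldPath(const std::shared_ptr<Node>& weights) {
    if (weights->type != "Reshape")
        return false;
    return weights->input && weights->input->type == "FakeQuantize";
}

// new behavior: step over the Reshape only when it is present
bool isQuantizedNewPath(const std::shared_ptr<Node>& weights) {
    const auto parent = (weights->type == "Reshape") ? weights->input : weights;
    return parent && parent->type == "FakeQuantize";
}

int main() {
    const auto fq = std::make_shared<Node>(Node{"FakeQuantize", nullptr});
    const auto direct = fq;                                             // 5D weights, no Reshape
    const auto reshaped = std::make_shared<Node>(Node{"Reshape", fq});  // classic 4D case
    std::cout << isQuantizedOldPath(direct) << isQuantizedNewPath(direct) << "\n";      // prints "01"
    std::cout << isQuantizedOldPath(reshaped) << isQuantizedNewPath(reshaped) << "\n";  // prints "11"
}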
