[LPT] Check if layer after concat isQuantized and require per-tensor quantize
vzinovie committed May 6, 2021
1 parent d1321f9 commit 1ea24a4
Showing 2 changed files with 10 additions and 2 deletions.
@@ -27,8 +27,9 @@ bool ConcatMultiChannelsTransformation::isMultiChannel(const std::vector<std::sh
     for (const std::shared_ptr<ngraph::opset1::Concat>& concat : concatLayers) {
         const std::vector<std::shared_ptr<ngraph::Node>> children = getChildrenRecursivelyExceptPrecisionPreserved(concat);
         for (const std::shared_ptr<ngraph::Node>& child : children) {
-            if (is_type<ngraph::opset1::Convolution>(child.get()) ||
-                is_type<ngraph::opset1::ConvolutionBackpropData>(child.get())) {
+            if ((is_type<ngraph::opset1::Convolution>(child.get()) ||
+                 is_type<ngraph::opset1::ConvolutionBackpropData>(child.get())) &&
+                 this->layerTransformationsManager->isQuantized(child)) {
                 return false;
             }
         }
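To make the effect of the first hunk concrete, here is a minimal self-contained sketch of the revised predicate. The struct Node and its two fields are hypothetical stand-ins, not real ngraph types; the point is only that a convolution-like child now blocks multi-channel handling only when it will itself be quantized.

#include <memory>
#include <vector>

// Hypothetical stand-in for an ngraph node, reduced to the two properties
// the predicate consults (not a real OpenVINO type).
struct Node {
    bool isConvolutionLike;  // Convolution or ConvolutionBackpropData
    bool quantized;          // what layerTransformationsManager->isQuantized(child) would report
};

// Mirrors the changed loop body: a concat keeps multi-channel handling
// unless some convolution-like child will actually be quantized.
bool isMultiChannel(const std::vector<std::shared_ptr<Node>>& children) {
    for (const std::shared_ptr<Node>& child : children) {
        if (child->isConvolutionLike && child->quantized) {
            return false;  // fall back to per-tensor quantization for the concat
        }
    }
    return true;
}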
@@ -42,6 +42,13 @@ void ConvolutionBackpropDataTransformation::registerMatcherIn(GraphRewrite &pass
 }
 
 bool ConvolutionBackpropDataTransformation::isQuantized(std::shared_ptr<Node> layer) const noexcept {
+    if (deconvolutionSpecificChannelsRatio) {
+        size_t inputChannels = layer->get_input_shape(0)[1];
+        size_t outputChannels = layer->get_output_shape(0)[1];
+        if (inputChannels % 4 != 0 || outputChannels % 16 != 0) {
+            return false;
+        }
+    }
     return WeightableLayerTransformation::isQuantized(layer, false);
 }

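As a quick sanity check of the new channels-ratio gate, the sketch below replays the added condition in isolation. The helper name channelsRatioSupported is hypothetical; in the real code the check sits inline in isQuantized and reads index 1 of the input and output shapes (the channel dimension for NCHW layouts).

#include <cstddef>
#include <iostream>

// Hypothetical helper mirroring the added condition: with
// deconvolutionSpecificChannelsRatio enabled, a deconvolution is treated as
// quantizable only when inputChannels is a multiple of 4 and outputChannels
// a multiple of 16.
bool channelsRatioSupported(std::size_t inputChannels, std::size_t outputChannels) {
    return (inputChannels % 4 == 0) && (outputChannels % 16 == 0);
}

int main() {
    std::cout << channelsRatioSupported(64, 16) << '\n';  // 1: both constraints hold
    std::cout << channelsRatioSupported(64, 10) << '\n';  // 0: 10 % 16 != 0
    std::cout << channelsRatioSupported(3, 16) << '\n';   // 0: 3 % 4 != 0
}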
