Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vzinovie committed Apr 21, 2021
1 parent 01f80cf commit 2e0c7f0
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class TRANSFORMATIONS_API NetworkHelper {
const bool hasZeroPoint,
const bool updatePrecision,
const element::Type deqPrecision = element::f32,
const int outChannelsShapeIndex = 0);
const size_t outChannelsShapeIndex = 0);

static std::shared_ptr<opset1::FakeQuantize> updateFakeQuantize(
std::shared_ptr<opset1::FakeQuantize> fq,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TRANSFORMATIONS_API WeightableLayerTransformation : public LayerTransforma
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;

protected:
void decomposeFakeQuantizeForWeightsPath(const std::shared_ptr<Node>& weightableLayer, int outChannelsShapeIndex = 0) const;
void decomposeFakeQuantizeForWeightsPath(const std::shared_ptr<Node>& weightableLayer, size_t outChannelsShapeIndex = 0ul) const;
static bool isGroup(const std::shared_ptr<Node>& node);
static bool isDepthwise(const std::shared_ptr<Node>& node);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,13 @@ bool ConvolutionBackpropDataTransformation::transform(TransformationContext &con
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(convolutionBackpropData);
const bool haveOutputShape = convolutionBackpropData->get_input_size() == 3;
{
std::shared_ptr<opset1::Subtract> subtract;
if (dequantization.subtract != nullptr) {
std::shared_ptr<ngraph::Node> layer = dequantization.subtract;
ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(layer);

auto optimizedSubtract = NetworkHelper::optimizeSubtract(dequantization.subtract);
subtract = optimizedSubtract ? as_type_ptr<opset1::Subtract>(optimizedSubtract) : dequantization.subtract;
NetworkHelper::optimizeSubtract(dequantization.subtract);
}

std::shared_ptr<opset1::Constant> reducedConstant = as_type_ptr<opset1::Constant>(
dequantization.multiply->input_value(1).get_node_shared_ptr());
std::shared_ptr<Node> newMultiplyAfterConst = std::make_shared<opset1::Constant>(
reducedConstant->get_output_element_type(0),
Shape{ 1 },
reducedConstant->cast_vector<float>()[0]);

const auto copyNode = haveOutputShape ?
convolutionBackpropData->copy_with_new_inputs({
dequantization.multiply->input_value(0),
Expand All @@ -90,7 +81,7 @@ bool ConvolutionBackpropDataTransformation::transform(TransformationContext &con
std::vector<element::Type>{ deqPrecision, deqPrecision },
std::vector<element::Type>{ dequantization.multiply->get_output_element_type(0) },
ngraph::op::TemporaryReplaceOutputType(relaxedConvolutionBackpropData, deqPrecision).get(),
ngraph::op::TemporaryReplaceOutputType(newMultiplyAfterConst, deqPrecision).get());
ngraph::op::TemporaryReplaceOutputType(dequantization.multiplyConstant, deqPrecision).get());

replace_node(convolutionBackpropData, newMultiplyAfter);
convolutionBackpropData = newMultiplyAfter->input_value(0).get_node_shared_ptr();
Expand All @@ -110,7 +101,7 @@ bool ConvolutionBackpropDataTransformation::transform(TransformationContext &con
}

{
decomposeFakeQuantizeForWeightsPath(convolutionBackpropData, 1);
decomposeFakeQuantizeForWeightsPath(convolutionBackpropData, 1ul);

dequantization = NetworkHelper::getDequantization(convolutionBackpropData, 1ul);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
const bool hasZeroPoint,
const bool updatePrecision,
const element::Type deqPrecision,
const int outChannelsShapeIndex) {
const size_t outChannelsShapeIndex) {
using std::make_shared;

const auto outputLow = fq->input_value(3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,11 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
return false;
}

if (!is_type<opset1::ConvolutionBackpropData>(layer)) {
if ( // Check if all dimensions of scale except the first one (which is O-Output channels dimension) are all ones
(shape_size(constOutputShape) != constOutputShape[0]) ||
((constOutputShape[0] != 1ul) && (fqFromWeights->get_output_shape(0)[0] != constOutputShape[0]))) {
return false;
}
} else {
if ( // Check if all dimensions of scale except the second one (which is O-Output channels dimension) are all ones
(shape_size(constOutputShape) != constOutputShape[1]) ||
((constOutputShape[1] != 1ul) && (fqFromWeights->get_output_shape(0)[1] != constOutputShape[1]))) {
const size_t outChannelsShapeIndex = is_type<opset1::ConvolutionBackpropData>(layer) ? 1ul : 0ul;
if ( // Check if all dimensions of scale except the output channels are all ones
(shape_size(constOutputShape) != constOutputShape[outChannelsShapeIndex]) ||
((constOutputShape[outChannelsShapeIndex] != 1ul) &&
(fqFromWeights->get_output_shape(0)[outChannelsShapeIndex] != constOutputShape[outChannelsShapeIndex]))) {
return false;
}
}
Expand Down Expand Up @@ -209,7 +204,7 @@ bool WeightableLayerTransformation::isPrecisionPreserved(std::shared_ptr<Node> l
return false;
}

void WeightableLayerTransformation::decomposeFakeQuantizeForWeightsPath(const std::shared_ptr<Node>& node, const int outChannelsShapeIndex) const {
void WeightableLayerTransformation::decomposeFakeQuantizeForWeightsPath(const std::shared_ptr<Node>& node, const size_t outChannelsShapeIndex) const {
const auto fq = getFakeQuantizeOnWeights(node);
if (fq == nullptr) {
return;
Expand Down

0 comments on commit 2e0c7f0

Please sign in to comment.