Skip to content

Commit

Permalink
[LPT] ConvolutionTransformation with asymmetric quantization after Sp…
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev authored and yekruglov committed Jun 7, 2021
1 parent 856b215 commit 2b6ea2f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
broadcastShape[1] = subtract->get_output_shape(0)[1];

std::shared_ptr<Node> newShift = fold<opset1::Broadcast>(
subtract->input_value(1).get_node_shared_ptr(),
subtract->input_value(1),
std::make_shared<opset1::Constant>(
element::i64,
Shape{ length },
broadcastShape));

const auto newSubtract = as_type_ptr<opset1::Subtract>(subtract->clone_with_new_inputs({
subtract->input_value(0).get_node_shared_ptr(),
subtract->input_value(0),
newShift }));
NetworkHelper::copyInfo(subtract, newSubtract);
replace_node(subtract, newSubtract);
Expand Down Expand Up @@ -176,7 +176,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
if (is_type<opset1::Convert>(convolution->get_input_node_ptr(0))) {
auto newConvolution = convolution->clone_with_new_inputs({
convolution->get_input_node_ptr(0)->get_input_source_output(0),
convolution->get_input_node_shared_ptr(1) });
convolution->input_value(1)});
replace_node(convolution, newConvolution);
convolution = newConvolution;
}
Expand Down Expand Up @@ -249,7 +249,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
zeroPointShape[0] = weightsShape[0];

auto zeroPointConstant = fold<opset1::Broadcast>(
subtractFromWeights->get_input_node_shared_ptr(1),
subtractFromWeights->input_value(1),
std::make_shared<opset1::Constant>(element::i32, Shape{ zeroPointShape.size() }, zeroPointShape));
NetworkHelper::copyInfo(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant);
replace_node(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant);
Expand All @@ -266,7 +266,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
auto newConvolution = convolution->clone_with_new_inputs({
convolution->get_input_source_output(0),
childNode.get() == convolution.get() ?
convolution->get_input_node_ptr(1)->get_input_node_shared_ptr(0) :
convolution->get_input_node_ptr(1)->input_value(0) :
childNode->copy_with_new_inputs({convertFromWeights->input_value(0), childNode->input_value(1)})});
replace_node(convolution, newConvolution);
convolution = newConvolution;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ SimpleLowPrecisionTransformer getTransformerWithTransformationByName(
transformer.add<ClampTransformation, opset1::Clamp>(params);
return transformer;
}
if (name == "ConvolutionTransformation") {
if (name == "ConvolutionTransformation" || name == "AsymmetricConvolutionTransformation") {
transformer.add<ConvolutionTransformation, opset1::Convolution>(params);
return transformer;
}
Expand Down Expand Up @@ -190,6 +190,7 @@ const std::vector<std::string> transformationNames = {
"AvgPoolTransformation",
"ClampTransformation",
"ConvolutionTransformation",
"AsymmetricConvolutionTransformation",
"DepthToSpaceTransformation",
"FakeQuantizeTransformation",
"InterpolateTransformation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ std::shared_ptr<Node> TransformationsAfterSplitFunction::getLayerByTransformatio
CoordinateDiff{ 0, 0 },
Strides{ 1, 1 });
}
if (transformationName == "AsymmetricConvolutionTransformation") {
const auto dequantizationOnData = makeDequantization(parent, { {element::f32}, { 128.f }, { 0.1f } });
const auto weights = opset1::Constant::create(element::i8, Shape{ 3, 3, 1, 1 }, { 2 });
const auto dequantizationOnWeights = makeDequantization(weights, { {element::f32}, {}, {0.3f} });
return std::make_shared<opset1::Convolution>(
dequantizationOnData,
dequantizationOnWeights,
Strides{ 1, 1 },
CoordinateDiff{ 0, 0 },
CoordinateDiff{ 0, 0 },
Strides{ 1, 1 });
}
if (transformationName == "DepthToSpaceTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return std::make_shared<opset1::DepthToSpace>(dequantization, opset1::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 3);
Expand Down

0 comments on commit 2b6ea2f

Please sign in to comment.