Skip to content

Commit

Permalink
[LPT] IntervalsAlignmentAttribute extending
Browse files Browse the repository at this point in the history
Affected models: cocosnet
  • Loading branch information
eshoguli committed Jun 7, 2021
1 parent 4f5fd79 commit ed331c2
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class LP_TRANSFORMATIONS_API IntervalsAlignmentSharedValue : public SharedValue<
Interval combinedInterval;
Interval minInterval;
size_t minLevels;
// precisions preferred by the affected quantization operations in order to avoid zero points
std::set<element::Type> preferablePrecisions;

// TODO: debug only
std::string minLevelsOperation;
Expand All @@ -54,6 +56,8 @@ class LP_TRANSFORMATIONS_API IntervalsAlignmentAttribute : public SharedValueAtt
size_t levels,
const IntervalsAlignmentSharedValue::Interval minInterval,
size_t minLevels);

// original quantization levels of the subgraph
size_t levels;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,18 @@ FakeQuantizeDecompositionTransformation::FakeQuantizeDecompositionTransformation

namespace fq_decomposition {

// Get precision details. The result depends on:
// 1. FakeQuantize operation parameters (QuantizationDetails::getDetails & LayerTransformation::getPrecisionDetails)
// 2. Precisions on the port
// 3. TODO: the third dependency was left undocumented
DataPrecision getDataPrecision(std::shared_ptr<opset1::FakeQuantize> layer) {
const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(layer);
auto attribute = getAttributeFromOutput<std::shared_ptr<PrecisionsAttribute>>(layer->output(0));
if (attribute == nullptr) {
auto precisionsAttribute = getAttributeFromOutput<std::shared_ptr<PrecisionsAttribute>>(layer->output(0));
if (precisionsAttribute == nullptr) {
// TODO: explore this case in more detail:
// 1. we should not be here
assert(true);

// 2. not possible to get optimal precision by decomposed FakeQuantize
LayerTransformation::PrecisionDetails precisionDetailsAtOutputIntervals = LayerTransformation::getPrecisionDetails(quantizationDetails);
return DataPrecision(
Expand All @@ -52,7 +58,7 @@ DataPrecision getDataPrecision(std::shared_ptr<opset1::FakeQuantize> layer) {
precisionDetailsAtOutputIntervals.hasZeroPoint);
}

const auto& precisions = attribute->get()->sharedValue->precisions;
const auto& precisions = precisionsAttribute->get()->sharedValue->precisions;

ngraph::element::Type precision;
bool hasZeroPoint;
Expand All @@ -67,8 +73,11 @@ DataPrecision getDataPrecision(std::shared_ptr<opset1::FakeQuantize> layer) {
precision = precisionDetailsAtOutputIntervals.precision;
hasZeroPoint = precisionDetailsAtOutputIntervals.hasZeroPoint;
}
attribute->get()->sharedValue->precisions = { precision };

// update shared attribute to affect all operations in subgraph
precisionsAttribute->get()->sharedValue->precisions = { precision };
} else {
// use the only available precision
precision = *precisions.begin();
LayerTransformation::PrecisionDetails precisionDetailsAtOutputIntervals = LayerTransformation::getPrecisionDetails(quantizationDetails);
hasZeroPoint = precisionDetailsAtOutputIntervals.precision != precision;
Expand Down Expand Up @@ -171,39 +180,36 @@ bool FakeQuantizeDecompositionTransformation::transform(TransformationContext& c
// precisionDetailsAtOutputIntervals.precision == preferedPrecision ? precisionDetailsAtOutputIntervals.hasZeroPoint : true);
//}

DataPrecision dataPrecision = fq_decomposition::getDataPrecision(layer);


std::shared_ptr<IntervalsAlignmentAttribute> intervalsAlignment;
if ((layer->get_friendly_name() == "Concat_1515/fq_input_0") || (layer->get_friendly_name() == "Concat_1515/fq_input_1")) {
std::cout << layer->get_friendly_name() << std::endl;
}

std::shared_ptr<ngraph::VariantWrapper<std::shared_ptr<QuantizationAlignmentAttribute>>> alignmentValue;
std::shared_ptr<QuantizationAlignmentAttribute> quantizationAlignment;
for (const auto& input : layer->output(0).get_target_inputs()) {
alignmentValue = low_precision::getAttribute<std::shared_ptr<QuantizationAlignmentAttribute>>(input.get_node()->shared_from_this());
if ((alignmentValue != nullptr) && (alignmentValue->get()->sharedValue->value)) {
break;
const auto alignmentValueWrapper = low_precision::getAttribute<std::shared_ptr<QuantizationAlignmentAttribute>>(input.get_node()->shared_from_this());
if (alignmentValueWrapper != nullptr) {
quantizationAlignment = alignmentValueWrapper->get();
if (quantizationAlignment->sharedValue->value) {
break;
}
}
}

if ((alignmentValue != nullptr) && alignmentValue->get()->sharedValue->value) {
//auto& rt = layer->get_rt_info();
//auto it = rt.find(ngraph::VariantWrapper<IntervalsAlignmentAttributePtr>::type_info.name);
//if (it != rt.end()) {
// auto attributeWrapper = std::dynamic_pointer_cast<ngraph::VariantWrapper<IntervalsAlignmentAttributePtr>>(it->second);
// const std::shared_ptr<IntervalsAlignmentAttribute> attribute = attributeWrapper->get();
// intervalsAlignment = attribute->hasToBeAligned ? attribute : nullptr;
//}

std::shared_ptr<IntervalsAlignmentAttribute> intervalsAlignment;
{
auto intervalsAlignmentWrapper = low_precision::getAttribute<std::shared_ptr<IntervalsAlignmentAttribute>>(layer);
if (intervalsAlignmentWrapper != nullptr) {
intervalsAlignment = intervalsAlignmentWrapper->get();
}
}

if (intervalsAlignment != nullptr) {
if ((quantizationAlignment != nullptr) && (quantizationAlignment->sharedValue->value) && (intervalsAlignment != nullptr)) {
if (intervalsAlignment->sharedValue->minLevels <= 2ul) {
return false;
}

DataPrecision dataPrecision = fq_decomposition::getDataPrecision(layer);

// const auto& combinedInterval = intervalsAlignment->sharedValue->combinedInterval;
// const float maxOutputInterval = combinedInterval.high - combinedInterval.low;
// // FQ -> SUB_quantization -> MUL_quantization -[INT8]-> SUB_dequantization -> MUL_dequantization ->
Expand Down Expand Up @@ -272,22 +278,29 @@ bool FakeQuantizeDecompositionTransformation::transform(TransformationContext& c
//ngraph::copy_runtime_info(sourceNodes, targetNodes);
NetworkHelper::copyInfo(sourceNodes, targetNodes);
} else {
//if (preferedPrecision == element::undefined) {
// if (dataPrecision.precision == element::undefined) {
// dataPrecision = getDataPrecision(layer, quantizationDetails, false);
// if (dataPrecision.precision == element::undefined) {
// return false;
// }
// }
//} else {
// dataPrecision = DataPrecision();;
//}
DataPrecision dataPrecision;

if (intervalsAlignment != nullptr) {
const auto& preferablePrecisions = intervalsAlignment->sharedValue->preferablePrecisions;
DataPrecision dataPrecision;
if (preferablePrecisions.find(ngraph::element::u8) == preferablePrecisions.end()) {
dataPrecision = fq_decomposition::getDataPrecision(layer);
} else {
dataPrecision = DataPrecision(
ngraph::element::u8,
DataPrecision::getMinValue(ngraph::element::u8, quantizationDetails.levels),
DataPrecision::getMaxValue(ngraph::element::u8, quantizationDetails.levels),
LayerTransformation::getPrecisionDetails(quantizationDetails).precision != ngraph::element::u8);
}
}

if (dataPrecision.precision == element::undefined) {
const auto precisionsAttribute = getAttributeFromOutput<PrecisionsAttributePtr>(layer);
const auto precisions = precisionsAttribute == nullptr ?
PrecisionsAttribute::defaultPrecisions :
precisionsAttribute->get()->sharedValue->precisions;

// the one place where the operation precision is used independently of the attributes (getDataPrecision)
dataPrecision = getDataPrecision(layer, quantizationDetails, precisions);
if (dataPrecision.precision == element::undefined) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_function(std::shared_ptr<
markupAndDecompose.register_pass<low_precision::PropagatePrecisions>();
markupAndDecompose.register_pass<low_precision::AlignQuantizationIntervals>();
markupAndDecompose.register_pass<low_precision::AlignQuantizationParameters>();
markupAndDecompose.run_passes(f);
}

{
ngraph::pass::Manager markupAndDecompose(passConfig);
markupAndDecompose.register_pass<ngraph::pass::low_precision::FakeQuantizeDecompositionTransformation>(params);
markupAndDecompose.run_passes(f);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<IntervalsAlignmentAttributePtr>:
assert(!std::isinf(resultSharedValue->combinedInterval.low));
assert(!std::isinf(resultSharedValue->combinedInterval.high));

resultSharedValue->preferablePrecisions.insert(sharedValue->preferablePrecisions.begin(), sharedValue->preferablePrecisions.end());

const auto resultSize = abs(resultSharedValue->minInterval.high - resultSharedValue->minInterval.low);
const auto size = abs(sharedValue->minInterval.high - sharedValue->minInterval.low);
if (resultSize > size) {
Expand Down Expand Up @@ -181,6 +183,10 @@ std::shared_ptr<VariantWrapper<std::shared_ptr<IntervalsAlignmentAttribute>>> Va
fakeQuantize->get_levels()));
rtInfo[ngraph::VariantWrapper<IntervalsAlignmentAttributePtr>::type_info.name] = attribute;

const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fakeQuantize);
LayerTransformation::PrecisionDetails preferablePrecision = LayerTransformation::getPrecisionDetails(quantizationDetails);
attribute->get()->sharedValue->preferablePrecisions.insert(preferablePrecision.precision);

attribute->get()->sharedValue->minLevelsOperation = node->get_friendly_name();

return attribute;
Expand Down Expand Up @@ -212,6 +218,8 @@ void VariantWrapper<IntervalsAlignmentAttributePtr>::merge(
assert(!std::isinf(resultSharedValue->combinedInterval.low));
assert(!std::isinf(resultSharedValue->combinedInterval.high));

resultSharedValue->preferablePrecisions.insert(sharedValue->preferablePrecisions.begin(), sharedValue->preferablePrecisions.end());

const auto resultSize = abs(resultSharedValue->minInterval.high - resultSharedValue->minInterval.low);
const auto size = abs(sharedValue->minInterval.high - sharedValue->minInterval.low);
if (resultSize > size) {
Expand Down

0 comments on commit ed331c2

Please sign in to comment.