From e443b87aa9eefb693899fa0707e10576746249f7 Mon Sep 17 00:00:00 2001 From: Gleb Kazantaev Date: Wed, 26 May 2021 12:34:50 +0300 Subject: [PATCH] Fix memory leak inside Pruning transformation (#5819) Co-authored-by: Gleb Kazantaev --- .../src/pruning/propagate_masks.cpp | 57 +++++++++++-------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp b/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp index 6016a162162462..ac7a8e8b6859c7 100644 --- a/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp +++ b/inference-engine/src/offline_transformations/src/pruning/propagate_masks.cpp @@ -44,17 +44,19 @@ class ngraph::pass::mask_propagation::Convolution : public MatcherPass { auto weights_mask = getMask(m_weights); // If weights are not a Constant and we didn't set Mask value before we will get nullptr if (!weights_mask) return false; + auto weights_mask_row = weights_mask.get(); if (auto input_mask = getMask(m_input)) { + auto input_mask_row = input_mask.get(); // Weights input channel is connected to the convolution input channel dimension // so we update weights mask to be aligned with input shape. - weights_mask->add_callback([input_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(1/* weights input channel */) = input_mask->at(1 /* input data channel */); + weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(1/* weights input channel */) = input_mask_row->at(1 /* input data channel */); return true; }, input_mask); - input_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(1) = weights_mask->at(1); + input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(1) = weights_mask_row->at(1); return true; }, weights_mask); @@ -65,14 +67,15 @@ class ngraph::pass::mask_propagation::Convolution : public MatcherPass { // Create output mask that describes which channel dimensions will be removed auto conv_mask = std::make_shared(m_weights.get_shape().size()); + auto conv_mask_row = conv_mask.get(); - conv_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(1) = weights_mask->at(0/*weights output channel dim */); + conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(1) = weights_mask_row->at(0/*weights output channel dim */); return true; }, weights_mask); - weights_mask->add_callback([conv_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(0) = conv_mask->at(1); + weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(0) = conv_mask_row->at(1); return true; }, conv_mask); @@ -112,6 +115,7 @@ class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass { auto input_mask = getMask(m_input); if (!input_mask) return false; + auto input_mask_row = input_mask.get(); auto weights_mask = getMask(m_weights); if (!weights_mask) { @@ -119,16 +123,17 @@ class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass { weights_mask = std::make_shared(weights_shape.size()); setMask(m_weights, weights_mask); } + auto weights_mask_row = weights_mask.get(); // Weights input channel is connected to the convolution input channel dimension // so we update weights mask to be aligned with input shape. - weights_mask->add_callback([input_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(0) = input_mask->at(1); + weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(0) = input_mask_row->at(1); return true; }, input_mask); - input_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(1) = weights_mask->at(0); + input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(1) = weights_mask_row->at(0); return true; }, weights_mask); @@ -138,14 +143,15 @@ class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass { // Update output channels mask dims auto conv_mask = std::make_shared(input_shape.rank().get_length()); + auto conv_mask_row = conv_mask.get(); - conv_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(1) = weights_mask->at(0); + conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(1) = weights_mask_row->at(0); return true; }, weights_mask); - weights_mask->add_callback([conv_mask](Mask::Ptr cur_mask) -> bool { - cur_mask->at(0) = conv_mask->at(1); + weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool { + cur_mask->at(0) = conv_mask_row->at(1); return true; }, conv_mask); @@ -195,21 +201,24 @@ class ngraph::pass::mask_propagation::Elementwise : public MatcherPass { NGRAPH_DEBUG << "No mask for: " << m_output.get_node()->get_friendly_name() << std::endl; return false; } + auto input_mask_row = input_mask.get(); + auto weights_mask_row = weights_mask.get(); // Merge masks from two inputs auto output_mask = std::make_shared(m_output.get_partial_shape().rank().get_length()); + auto output_mask_row = output_mask.get(); - auto out_mask_callback = [input_mask, weights_mask](Mask::Ptr cur_mask) -> bool { + auto out_mask_callback = [input_mask_row, weights_mask_row](Mask::Ptr cur_mask) -> bool { auto omask_iter = cur_mask->rbegin(); - auto imask_iter = input_mask->rbegin(); - auto wmask_iter = weights_mask->rbegin(); + auto imask_iter = input_mask_row->rbegin(); + auto wmask_iter = weights_mask_row->rbegin(); for (auto & item : *cur_mask) { item.clear(); } - while (imask_iter != input_mask->rend() && - wmask_iter != weights_mask->rend()) { + while (imask_iter != input_mask_row->rend() && + wmask_iter != weights_mask_row->rend()) { // Merge mask dimension values for both masks // Example: (MaskValue[1,2,3,4], MaskValue[2,3]) -> MaskValue[2,3] for (const auto & value : *imask_iter) { @@ -227,10 +236,10 @@ class ngraph::pass::mask_propagation::Elementwise : public MatcherPass { output_mask->add_callback(out_mask_callback, input_mask); output_mask->add_callback(out_mask_callback, weights_mask); - auto callback = [output_mask](Mask::Ptr cur_mask) -> bool { - auto omask_iter = output_mask->rbegin(); + auto callback = [output_mask_row](Mask::Ptr cur_mask) -> bool { + auto omask_iter = output_mask_row->rbegin(); auto cmask_iter = cur_mask->rbegin(); - while (omask_iter != output_mask->rend() && + while (omask_iter != output_mask_row->rend() && cmask_iter != cur_mask->rend()) { // TODO: check *cmask_iter = *omask_iter;