Skip to content

Commit

Permalink
Fix memory leak inside Pruning transformation (#5819)
Browse files Browse the repository at this point in the history
Co-authored-by: Gleb Kazantaev <[email protected]>
  • Loading branch information
Gleb Kazantaev and Gleb Kazantaev authored May 26, 2021
1 parent e754a6b commit 5b291b5
Showing 1 changed file with 33 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,19 @@ class ngraph::pass::mask_propagation::Convolution : public MatcherPass {
auto weights_mask = getMask(m_weights);
// If weights are not a Constant and we didn't set Mask value before we will get nullptr
if (!weights_mask) return false;
auto weights_mask_row = weights_mask.get();

if (auto input_mask = getMask(m_input)) {
auto input_mask_row = input_mask.get();
// Weights input channel is connected to the convolution input channel dimension
// so we update weights mask to be aligned with input shape.
weights_mask->add_callback([input_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1/* weights input channel */) = input_mask->at(1 /* input data channel */);
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1/* weights input channel */) = input_mask_row->at(1 /* input data channel */);
return true;
}, input_mask);

input_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask->at(1);
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(1);
return true;
}, weights_mask);

Expand All @@ -65,14 +67,15 @@ class ngraph::pass::mask_propagation::Convolution : public MatcherPass {

// Create output mask that describes which channel dimensions will be removed
auto conv_mask = std::make_shared<Mask>(m_weights.get_shape().size());
auto conv_mask_row = conv_mask.get();

conv_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask->at(0/*weights output channel dim */);
conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(0/*weights output channel dim */);
return true;
}, weights_mask);

weights_mask->add_callback([conv_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = conv_mask->at(1);
weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = conv_mask_row->at(1);
return true;
}, conv_mask);

Expand Down Expand Up @@ -112,23 +115,25 @@ class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass {

auto input_mask = getMask(m_input);
if (!input_mask) return false;
auto input_mask_row = input_mask.get();

auto weights_mask = getMask(m_weights);
if (!weights_mask) {
// TODO: only if weights are constant
weights_mask = std::make_shared<Mask>(weights_shape.size());
setMask(m_weights, weights_mask);
}
auto weights_mask_row = weights_mask.get();

// Weights input channel is connected to the convolution input channel dimension
// so we update weights mask to be aligned with input shape.
weights_mask->add_callback([input_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = input_mask->at(1);
weights_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = input_mask_row->at(1);
return true;
}, input_mask);

input_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask->at(0);
input_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(0);
return true;
}, weights_mask);

Expand All @@ -138,14 +143,15 @@ class ngraph::pass::mask_propagation::GroupConvolution : public MatcherPass {

// Update output channels mask dims
auto conv_mask = std::make_shared<Mask>(input_shape.rank().get_length());
auto conv_mask_row = conv_mask.get();

conv_mask->add_callback([weights_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask->at(0);
conv_mask->add_callback([weights_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(1) = weights_mask_row->at(0);
return true;
}, weights_mask);

weights_mask->add_callback([conv_mask](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = conv_mask->at(1);
weights_mask->add_callback([conv_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->at(0) = conv_mask_row->at(1);
return true;
}, conv_mask);

Expand Down Expand Up @@ -195,21 +201,24 @@ class ngraph::pass::mask_propagation::Elementwise : public MatcherPass {
NGRAPH_DEBUG << "No mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
return false;
}
auto input_mask_row = input_mask.get();
auto weights_mask_row = weights_mask.get();

// Merge masks from two inputs
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
auto output_mask_row = output_mask.get();

auto out_mask_callback = [input_mask, weights_mask](Mask::Ptr cur_mask) -> bool {
auto out_mask_callback = [input_mask_row, weights_mask_row](Mask::Ptr cur_mask) -> bool {
auto omask_iter = cur_mask->rbegin();
auto imask_iter = input_mask->rbegin();
auto wmask_iter = weights_mask->rbegin();
auto imask_iter = input_mask_row->rbegin();
auto wmask_iter = weights_mask_row->rbegin();

for (auto & item : *cur_mask) {
item.clear();
}

while (imask_iter != input_mask->rend() &&
wmask_iter != weights_mask->rend()) {
while (imask_iter != input_mask_row->rend() &&
wmask_iter != weights_mask_row->rend()) {
// Merge mask dimension values for both masks
// Example: (MaskValue[1,2,3,4], MaskValue[2,3]) -> MaskValue[2,3]
for (const auto & value : *imask_iter) {
Expand All @@ -227,10 +236,10 @@ class ngraph::pass::mask_propagation::Elementwise : public MatcherPass {
output_mask->add_callback(out_mask_callback, input_mask);
output_mask->add_callback(out_mask_callback, weights_mask);

auto callback = [output_mask](Mask::Ptr cur_mask) -> bool {
auto omask_iter = output_mask->rbegin();
auto callback = [output_mask_row](Mask::Ptr cur_mask) -> bool {
auto omask_iter = output_mask_row->rbegin();
auto cmask_iter = cur_mask->rbegin();
while (omask_iter != output_mask->rend() &&
while (omask_iter != output_mask_row->rend() &&
cmask_iter != cur_mask->rend()) {
// TODO: check
*cmask_iter = *omask_iter;
Expand Down

0 comments on commit 5b291b5

Please sign in to comment.