Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unreferenced variables from ProgramDesc in prune() #7890

Merged
merged 3 commits into from
Jan 29, 2018
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions paddle/framework/prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include <glog/logging.h>
Expand Down Expand Up @@ -102,6 +103,32 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
*op_field->Add() = input.blocks(block_id).ops(i);
}
}

// remove the VarDescs in BlockDesc that are not referenced in
// the pruned OpDescs
std::unordered_map<std::string, proto::VarDesc> var_map;
auto* var_field = output->mutable_blocks(block_id)->mutable_vars();
for (const auto& var : *var_field) {
var_map[var.name()] = var;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, there is no need for an extra map. We can directly use var map of input ProgramDesc. We can use input.FindVar(name)->Proto() to get the proto::VarDesc.

Copy link
Contributor Author

@kexinzhao kexinzhao Jan 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two different classes one is framework::proto::xxxDesc, one is framework::xxxDesc. The latter is basically a wrapper about the former.

For this prune function, the input is proto::ProgramDesc, so we only have limited functionality provided by the protobuf library itself. After you compile the framework.proto, you will get framework.pb.h + cc, which lists all the available functions for proto::ProgramDesc and etc.

So I don't think we can use input.FindVar() here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later, when we are going to implement out transpiler function for inference optimization on the c++ side after loading a inference desc from file or buffer, we can use your suggestion, as in that case, we can take a framework::ProgramDesc as input.

@Xreki what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this prune function, the input is proto::ProgramDesc, so we only have limited functionality provided by the protobuf library itself.

You are right. Sorry for didn't notice that.


var_field->Clear();
for (const auto& op : *op_field) {
// add VarDescs of all input arguments for each OpDesc
auto& input_field = op.inputs();
for (auto& input : input_field) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input -> input_var? Because the first argument of this function is named input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (auto& arg : input.arguments()) {
*var_field->Add() = var_map[arg];
}
}
// add VarDescs of all output arguments for each OpDesc
auto& output_field = op.outputs();
for (auto& output : output_field) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output -> output_var? Because the second argument of this function is named output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (auto& arg : output.arguments()) {
*var_field->Add() = var_map[arg];
}
}
}
}

// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
Expand Down