Skip to content

Commit

Permalink
[RELAY] Re-wrote the Graph Partitioner to support multiple outputs
Browse files Browse the repository at this point in the history
    *removed the expected use-case as we are taking broken-down PR approach
    *code style fixes
    *some trivial one liners
  • Loading branch information
manupak committed Mar 26, 2020
1 parent 32676f9 commit a4653fd
Showing 1 changed file with 18 additions and 33 deletions.
51 changes: 18 additions & 33 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ namespace partitioning {

// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op &compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op &compiler_end_op = Op::Get("annotation.compiler_end");
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

/*!
* \brief The checker that verifies if a Relay program is annotated correctly
Expand Down Expand Up @@ -120,26 +120,11 @@ class AnnotationChecker : public ExprVisitor {
* as TupleGetItemNode index.
* 6) Therefore, functions will be created for all annotated regions. The name for each
* global function is created using "Region" id and the compiler name.
*
* Expected Usecase :
* This pass is intended to run as the last pass in a series of passes as follows :
* 1) Annotate Supported Single Ops - annotated each single op with supported backends.
* We use supported_begin and supported_end annotations.
* 2) Annotate Supported Composite Ops - annotate each composite op (that consist of
* multiple single ops).
* We use supported_begin and supported_end
* annotations.
* 3) Deconflict Pass - Make sure each op is annotated by only a single backend.
* In other words, each Annotated Region will be disjoint.
* We promote supported_* annotations to compiler_* annotations.
* 4) Merge Supported Pass - Merge the disjoint compiler_* Annotated regions belonging
* to same backend.
* 5) *Partition Graph* - Convert Disjoint Annotated Regions into Functions.
*/

class Partitioner : public ExprMutator {
public:
explicit Partitioner(const IRModule &module) : module_(module) {
explicit Partitioner(const IRModule& module) : module_(module) {
for (auto f : module->functions) {
GlobalVar f_var = f.first;
BaseFunc f_func = f.second;
Expand Down Expand Up @@ -282,7 +267,7 @@ class Partitioner : public ExprMutator {

if (region->GetOutputs().size() == 1) {
// If there is only a single output; no need to add a tuplegetitem node
return Call(glob_func, param_expr);
return ret;
} else {
// Add a tuplegetitem node to select this output out of many
auto tuple_get_item_ = TupleGetItem(ret, index);
Expand Down Expand Up @@ -388,7 +373,7 @@ class Partitioner : public ExprMutator {

IRModule Partition() {
auto glob_funcs = module_->functions;
for (const auto &pair : glob_funcs) {
for (const auto& pair : glob_funcs) {
if (auto *fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = Function(func->params,
Expand All @@ -405,9 +390,9 @@ class Partitioner : public ExprMutator {
private:
/*!
* \brief Get the region an expression belongs to
* if its in a region.
* if its in a region.
*/
AnnotatedRegion GetRegion(const Expr &e) {
AnnotatedRegion GetRegion(const Expr& e) {
for (auto sg_set_it : regions_sets_) {
auto sg_set = sg_set_it.first;
AnnotatedRegion sg = sg_set->GetRegion(e);
Expand All @@ -420,9 +405,9 @@ class Partitioner : public ExprMutator {

/*!
* \brief Get the function an expression belongs to
* if its in a region.
* if its in a region.
*/
BaseFunc GetFunc(const Expr &e) {
BaseFunc GetFunc(const Expr& e) {
for (auto sg_set_it : regions_sets_) {
auto sg_set = sg_set_it.first;
auto func = sg_set_it.second;
Expand All @@ -436,10 +421,10 @@ class Partitioner : public ExprMutator {
}

/*!
* \brief Get the index of the argument;
* this is to be used as tuplegetitem idx
*/
int GetArgIdx(AnnotatedRegion sg, const Expr &arg) {
* \brief Get the index of the argument;
* this is to be used as tuplegetitem idx
*/
int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetInputs()) {
if (arg == arg_) {
Expand All @@ -452,9 +437,9 @@ class Partitioner : public ExprMutator {

/*!
* \brief Get the index of the return(output);
* this is to be used as tuplegetitem idx
* this is to be used as tuplegetitem idx
*/
int GetRetIdx(AnnotatedRegion sg, const Expr &arg) {
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetOutputs()) {
if (arg == arg_) {
Expand All @@ -467,20 +452,20 @@ class Partitioner : public ExprMutator {

/*!
* \brief This map maintains the already created function calls.
* This is required in the multi-output scenario, to link rest of the outputs to call
* This is required in the multi-output scenario, to link rest of the outputs to call
*/
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;

/*!
* \brief This map maintains arguments (of region) visits through visitor patterns.
* Those arguement var and expression will be used to when creating the function.
* Those arguement var and expression will be used to when creating the function.
*/
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>,
ObjectHash, ObjectEqual> region_args;

/*!
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it belongs to
* This map maintains the mapping between regionsets and the function it belongs to
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
IRModule module_;
Expand Down

0 comments on commit a4653fd

Please sign in to comment.