Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed May 1, 2020
1 parent e8ffa89 commit 3a6f409
Showing 1 changed file with 47 additions and 52 deletions.
99 changes: 47 additions & 52 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,17 @@ struct RegionFuncMetadata {
*/
std::vector<std::pair<Var, Expr>> args;

/*! \brief Map from each region output expr node to its output index and TupleGetItem node. */
std::unordered_map<Expr, std::pair<int, TupleGetItem>, ObjectHash, ObjectEqual> out_expr_indices;
/*! \brief Map from each region output expr (compiler end) node to
* the corresponding function output expr.
*/
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> region_func_out;

/*! \brief Map from each input expression to the corresponding input variable of this region.
* This cache is used to make sure a region function will not have duplicated inputs even
* it refers the same expr multuple times.
/*! \brief Map from each region input expression (compiler begin) to
* the corresponding function input variable. This cache is used to make sure
* a region function will not have duplicated inputs even if it refers to
* the same expr multiple times.
*/
std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> in_expr_vars;
std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> region_func_in;
};

/*! \brief This class partitions the expr labeled with begin and end annotations
Expand Down Expand Up @@ -123,11 +126,6 @@ class Partitioner : public MixedModeMutator {
auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
partitioning::compiler_end_op);
regions_sets_[region_set] = f_func;

// Initial region function metadata.
for (auto region : region_set) {
region_func_meta_[region];
}
}
}

Expand Down Expand Up @@ -156,8 +154,8 @@ class Partitioner : public MixedModeMutator {
int index = GetArgIdx(sg, GetRef<Call>(call));
CHECK_NE(index, -1);

if (region_func_meta_[sg].in_expr_vars.count(parent)) {
return region_func_meta_[sg].in_expr_vars[parent];
if (region_func_meta_[sg].region_func_in.count(parent)) {
return region_func_meta_[sg].region_func_in[parent];
} else {
// The type of the created variable is the same as the compiler_begin
// node.
Expand All @@ -172,7 +170,7 @@ class Partitioner : public MixedModeMutator {
region_func_meta_[sg].args.end()) {
region_func_meta_[sg].args.push_back(cand);
}
region_func_meta_[sg].in_expr_vars[parent] = var;
region_func_meta_[sg].region_func_in[parent] = var;
return std::move(var);
}
} else {
Expand All @@ -195,8 +193,7 @@ class Partitioner : public MixedModeMutator {
// If multiple outputs are there, a tuple node is inserted at the end.

if (!region_func_meta_[region].func_call.defined()) {
// First time this region is encountered in the traversal.
// Creating the function.
// First time this region is encountered in the traversal. Creating the function.
CreateFunction(region, call);
}
// Retrieve this particular output of function.
Expand Down Expand Up @@ -266,21 +263,22 @@ class Partitioner : public MixedModeMutator {
}

/*!
* \brief This function is called first time that we encounter a compiler_end
* node to create the function for the subgraph.
* \brief Create a function and its function call for the given region. If the function has
* multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
* will be created to serve output consumers.
*/
void CreateFunction(AnnotatedRegion region, const CallNode* call) {
void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
// Create fields which is a unique list of outputs.
Array<Expr> fields;
int i = 0;
for (auto ret : region->GetOutputs()) {
auto ret_node = Downcast<Call>(ret)->args[0];
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> out_expr_to_idx;
int out_idx = 0;
for (auto region_end_node : region->GetOutputs()) {
auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs.
if (!region_func_meta_[region].out_expr_indices.count(ret_node)) {
if (!out_expr_to_idx.count(ret_node)) {
auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
fields.push_back(ret_expr);
region_func_meta_[region].out_expr_indices[ret_node] = {i, TupleGetItem()};
i++;
out_expr_to_idx[ret_node] = out_idx++;
}
}

Expand Down Expand Up @@ -309,13 +307,13 @@ class Partitioner : public MixedModeMutator {
if (fields.size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}

std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

global_region_func =
Expand All @@ -340,9 +338,24 @@ class Partitioner : public MixedModeMutator {
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);

// The return type of callnode is the same as the type of the compiler_end node.
auto ret = Call(glob_func, param_expr);
region_func_meta_[region].func_call = ret;
// Create a call node for the function.
auto call = Call(glob_func, param_expr);
region_func_meta_[region].func_call = call;

// Create output expr(s) for the function call.
if (out_expr_to_idx.size() == 1) {
// Single output direcly uses the call node as the output expr.
region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
} else {
// Multiple outptus need to create TupleGetItem nodes as output exprs.
for (auto pair : out_expr_to_idx) {
Expr region_out_expr = pair.first; // The arg of a compiler end node of this region.
int idx = pair.second; // Corresponding function output tuple index.
auto tuple_get_item = TupleGetItem(call, idx);
tuple_get_item->checked_type_ = region_out_expr->checked_type_;
region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
}
}
}

/*!
Expand All @@ -352,26 +365,8 @@ class Partitioner : public MixedModeMutator {
*/
Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
Expr arg = Downcast<Call>(end_arg)->args[0];
// Function has one output.
if (region_func_meta_[region].out_expr_indices.size() == 1) {
return region_func_meta_[region].func_call;
}

// Function has multiple outputs.
// Use already made TupleGetItem.
if (region_func_meta_[region].out_expr_indices.count(arg) &&
region_func_meta_[region].out_expr_indices[arg].second.defined()) {
return region_func_meta_[region].out_expr_indices[arg].second;
}
// Create new TupleGetItem.
CHECK(region_func_meta_[region].out_expr_indices.count(arg));
int index = region_func_meta_[region].out_expr_indices[arg].first;

auto func_call = region_func_meta_[region].func_call;
auto tuple_get_item_ = TupleGetItem(func_call, index);
tuple_get_item_->checked_type_ = arg->checked_type_;
region_func_meta_[region].out_expr_indices[arg].second = tuple_get_item_;
return std::move(tuple_get_item_);
CHECK(region_func_meta_[region].region_func_out.count(arg));
return region_func_meta_[region].region_func_out[arg];
}

/*! \brief Map from each region to its metadata of the generated function. */
Expand All @@ -388,7 +383,7 @@ class Partitioner : public MixedModeMutator {
IRModule module_;
};

IRModule Remove(IRModule module) {
IRModule RemoveDefaultAnnotations(IRModule module) {
class DefaultRemover : public ExprRewriter {
public:
DefaultRemover() = default;
Expand Down Expand Up @@ -427,7 +422,7 @@ Pass PartitionGraph() {
// TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
// by treating them as un-annotated, but we don't have it yet. This workaround pass removes
// all "default" annotations and should be deleted in the future.
auto new_m = partitioning::Remove(m);
auto new_m = partitioning::RemoveDefaultAnnotations(m);
return partitioning::Partitioner(new_m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
Expand Down

0 comments on commit 3a6f409

Please sign in to comment.