-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RELAY] Added a AnnotatedRegion utility class (#5030)
* [RELAY] Added an AnnotatedRegionSet utility class In many of the passes involved in graph partitioning, we need to extract and manipulate annotated regions. This class simplifies the extraction of regions from a relay expression containing region begin and end annotations as well as providing utility functions to query these regions and merge them. Co-authored-by: Ramana Radhakrishnan <[email protected]> Change-Id: Ia912fea0b99f64b6a7197aa6da2347e58f469fbb * Rename fix * Update MakeRegions * Fix __init__ * Indentation * Code style * Remove 'Region' from docs * Overload [] to get region * Use src/dest for MergeRegions * Simplify merge * Tidy const loop vars
- Loading branch information
Showing
5 changed files
with
705 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import | ||
"""Regions used in Relay.""" | ||
|
||
from tvm.runtime import Object | ||
from . import _ffi_api | ||
|
||
|
||
class AnnotatedRegionSet(Object): | ||
"""Class to represent a relay expression split into regions.""" | ||
|
||
def __init__(self, expr, region_begin_op, region_end_op): | ||
"""Construct regions from an expression. | ||
Parameters | ||
---------- | ||
expr : tvm.relay.Expr | ||
The expression from which to construct the regions. | ||
region_begin_op : tvm.relay.Op | ||
The region begin annotation. | ||
region_end_op : tvm.relay.Op | ||
The region end annotation. | ||
""" | ||
self.__init_handle_by_constructor__(_ffi_api.AnnotatedRegionSet, | ||
expr, | ||
region_begin_op, | ||
region_end_op) | ||
|
||
def __len__(self): | ||
return len(self.regions) | ||
|
||
def get_region(self, expr): | ||
"""Get the region an expression belongs to. | ||
Parameters | ||
---------- | ||
expr : tvm.relay.Expr | ||
The expression. | ||
Returns | ||
------- | ||
region | ||
The region containing the expression. | ||
None if not found. | ||
""" | ||
return _ffi_api.GetRegion(self, expr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#include "annotated_region_set.h" | ||
|
||
#include <tvm/relay/expr.h> | ||
#include <tvm/ir/error.h> | ||
|
||
#include <algorithm> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { | ||
for (auto candidate : regions_) { | ||
if (candidate->nodes.find(expr) != candidate->nodes.end()) { | ||
return candidate; | ||
} | ||
} | ||
return AnnotatedRegion(nullptr); | ||
} | ||
|
||
void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, | ||
AnnotatedRegion dest) { | ||
if (dest == src) { | ||
return; | ||
} | ||
|
||
// Merge src to dest and erase src. | ||
dest->nodes.insert(src->nodes.begin(), src->nodes.end()); | ||
for (const auto& input : src->ins) { | ||
dest->ins.push_back(input); | ||
} | ||
for (const auto& output : src->outs) { | ||
dest->outs.push_back(output); | ||
} | ||
// if any of the outputs of src are inputs of dest, they become internal nodes | ||
// so remove them from outs | ||
for (const auto& input : dest->ins) { | ||
auto call = Downcast<Call>(input); | ||
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]); | ||
if (it != src->outs.end()) { | ||
dest->outs.remove(*it); | ||
dest->ins.remove(input); | ||
} | ||
} | ||
regions_.erase(src); | ||
} | ||
|
||
void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) { | ||
auto region2 = GetRegion(expr); | ||
if (region2.defined()) { | ||
MergeRegions(region, region2); | ||
} else { | ||
region->nodes.insert(expr); | ||
} | ||
} | ||
|
||
AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() { | ||
auto ret = regions_.emplace(AnnotatedRegion()); | ||
(*ret.first)->id = region_id_++; | ||
return *ret.first; | ||
} | ||
|
||
class AnnotatedRegionSet::Creator : public ExprVisitor { | ||
public: | ||
Creator(const Op& region_begin_op, const Op& region_end_op) : | ||
begin_op_(region_begin_op), end_op_(region_end_op) {} | ||
|
||
AnnotatedRegionSet Create(const Expr& expr) { | ||
VisitExpr(expr); | ||
return std::move(region_set_); | ||
} | ||
|
||
void VisitExpr_(const CallNode* call) { | ||
auto op_node = call->op.as<OpNode>(); | ||
|
||
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { | ||
// Propagate region to arguments | ||
auto region = region_set_->GetRegion(GetRef<Call>(call)); | ||
if (region.defined()) { | ||
for (auto arg : call->args) { | ||
region_set_->AddToRegion(region, arg); | ||
} | ||
} | ||
} else if (call->op == begin_op_) { | ||
// The annotation node is inserted on edge so it must have only one argument. | ||
CHECK_EQ(call->args.size(), 1U); | ||
|
||
auto region = region_set_->GetRegion(GetRef<Call>(call)); | ||
if (!region.defined()) { | ||
throw Error(ErrorBuilder() | ||
<< "Cannot find the corresponding region for start annotation:\n" | ||
<< AsText(GetRef<Call>(call), false)); | ||
} | ||
region->ins.push_back(GetRef<Call>(call)); | ||
} else { | ||
CHECK_EQ(call->op, end_op_); | ||
// The annotation node is inserted on edge so it must have only one argument. | ||
CHECK_EQ(call->args.size(), 1U); | ||
|
||
// Check if the argument already belongs to a region | ||
auto region = region_set_->GetRegion(call->args[0]); | ||
if (!region.defined()) { | ||
region = region_set_->MakeRegion(); | ||
region->nodes.insert(call->args[0]); | ||
} | ||
region->nodes.insert(GetRef<Call>(call)); | ||
region->outs.push_back(GetRef<Call>(call)); | ||
} | ||
ExprVisitor::VisitExpr_(call); | ||
} | ||
|
||
void VisitExpr_(const TupleNode* op) { | ||
auto region = region_set_->GetRegion(GetRef<Tuple>(op)); | ||
if (region.defined()) { | ||
for (auto field : op->fields) { | ||
region_set_->AddToRegion(region, field); | ||
} | ||
} | ||
ExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
void VisitExpr_(const TupleGetItemNode* g) { | ||
auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g)); | ||
if (region.defined()) { | ||
region_set_->AddToRegion(region, g->tuple); | ||
} | ||
ExprVisitor::VisitExpr_(g); | ||
} | ||
|
||
void VisitExpr_(const FunctionNode* op) { | ||
auto region = region_set_->GetRegion(GetRef<Function>(op)); | ||
if (region.defined()) { | ||
for (auto param : op->params) { | ||
region_set_->AddToRegion(region, param); | ||
} | ||
} | ||
ExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
void VisitExpr_(const LetNode* op) { | ||
auto region = region_set_->GetRegion(GetRef<Let>(op)); | ||
if (region.defined()) { | ||
region_set_->AddToRegion(region, op->var); | ||
region_set_->AddToRegion(region, op->value); | ||
region_set_->AddToRegion(region, op->body); | ||
} | ||
ExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
void VisitExpr_(const IfNode* op) { | ||
auto region = region_set_->GetRegion(GetRef<If>(op)); | ||
if (region.defined()) { | ||
region_set_->AddToRegion(region, op->cond); | ||
region_set_->AddToRegion(region, op->true_branch); | ||
region_set_->AddToRegion(region, op->false_branch); | ||
} | ||
ExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
void VisitExpr_(const RefCreateNode* op) { | ||
auto region = region_set_->GetRegion(GetRef<RefCreate>(op)); | ||
if (region.defined()) { | ||
region_set_->AddToRegion(region, op->value); | ||
} | ||
ExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
void VisitExpr_(const RefReadNode* op) { | ||
auto region = region_set_->GetRegion(GetRef<RefRead>(op)); | ||
if (region.defined()) { | ||
region_set_->AddToRegion(region, op->ref); | ||
} | ||
ExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
void VisitExpr_(const RefWriteNode* op) { | ||
auto region = region_set_->GetRegion(GetRef<RefWrite>(op)); | ||
if (region.defined()) { | ||
region_set_->AddToRegion(region, op->ref); | ||
} | ||
ExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
private: | ||
/*! \brief The region set being constructed.*/ | ||
AnnotatedRegionSet region_set_; | ||
/*! \brief Region 'begin' annotation operator. */ | ||
const Op begin_op_; | ||
/*! \brief Region 'end' annotation operator. */ | ||
const Op end_op_; | ||
}; | ||
|
||
AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) { | ||
return Creator(begin, end).Create(expr); | ||
} | ||
|
||
TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); | ||
TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode); | ||
|
||
TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet") | ||
.set_body_typed([](Expr expr, Op begin, Op end) { | ||
return AnnotatedRegionSet::Create(expr, begin, end); | ||
}); | ||
|
||
TVM_REGISTER_GLOBAL("relay.analysis.GetRegion") | ||
.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { | ||
return region_set->GetRegion(expr); | ||
}); | ||
|
||
|
||
} // namespace relay | ||
} // namespace tvm |
Oops, something went wrong.