Skip to content

Commit

Permalink
[RELAY] Added a AnnotatedRegion utility class (#5030)
Browse files Browse the repository at this point in the history
* [RELAY] Added an AnnotatedRegionSet utility class

In many of the passes involved in graph partitioning,
we need to extract and manipulate annotated regions.
This class simplifies the extraction of regions from a relay
expression containing region begin and end annotations
as well as providing utility functions to query these
regions and merge them.

Co-authored-by: Ramana Radhakrishnan  <[email protected]>

Change-Id: Ia912fea0b99f64b6a7197aa6da2347e58f469fbb

* Rename fix

* Update MakeRegions

* Fix __init__

* Indentation

* Code style

* Remove 'Region' from docs

* Overload [] to get region

* Use src/dest for MergeRegions

* Simplify merge

* Tidy const loop vars
  • Loading branch information
mbaret authored Mar 26, 2020
1 parent 314f31b commit b5ec071
Show file tree
Hide file tree
Showing 5 changed files with 705 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/tvm/relay/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
# Analysis passes
from .analysis import *

# Annotations
from .annotated_regions import AnnotatedRegionSet

# Call graph
from . import call_graph
from .call_graph import CallGraph
Expand Down
62 changes: 62 additions & 0 deletions python/tvm/relay/analysis/annotated_regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Regions used in Relay."""

from tvm.runtime import Object
from . import _ffi_api


class AnnotatedRegionSet(Object):
"""Class to represent a relay expression split into regions."""

def __init__(self, expr, region_begin_op, region_end_op):
"""Construct regions from an expression.
Parameters
----------
expr : tvm.relay.Expr
The expression from which to construct the regions.
region_begin_op : tvm.relay.Op
The region begin annotation.
region_end_op : tvm.relay.Op
The region end annotation.
"""
self.__init_handle_by_constructor__(_ffi_api.AnnotatedRegionSet,
expr,
region_begin_op,
region_end_op)

def __len__(self):
return len(self.regions)

def get_region(self, expr):
"""Get the region an expression belongs to.
Parameters
----------
expr : tvm.relay.Expr
The expression.
Returns
-------
region
The region containing the expression.
None if not found.
"""
return _ffi_api.GetRegion(self, expr)
233 changes: 233 additions & 0 deletions src/relay/analysis/annotated_region_set.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "annotated_region_set.h"

#include <tvm/relay/expr.h>
#include <tvm/ir/error.h>

#include <algorithm>
#include <unordered_map>
#include <vector>


namespace tvm {
namespace relay {

AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
for (auto candidate : regions_) {
if (candidate->nodes.find(expr) != candidate->nodes.end()) {
return candidate;
}
}
return AnnotatedRegion(nullptr);
}

void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
AnnotatedRegion dest) {
if (dest == src) {
return;
}

// Merge src to dest and erase src.
dest->nodes.insert(src->nodes.begin(), src->nodes.end());
for (const auto& input : src->ins) {
dest->ins.push_back(input);
}
for (const auto& output : src->outs) {
dest->outs.push_back(output);
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input);
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
if (it != src->outs.end()) {
dest->outs.remove(*it);
dest->ins.remove(input);
}
}
regions_.erase(src);
}

void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) {
auto region2 = GetRegion(expr);
if (region2.defined()) {
MergeRegions(region, region2);
} else {
region->nodes.insert(expr);
}
}

AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
auto ret = regions_.emplace(AnnotatedRegion());
(*ret.first)->id = region_id_++;
return *ret.first;
}

class AnnotatedRegionSet::Creator : public ExprVisitor {
public:
Creator(const Op& region_begin_op, const Op& region_end_op) :
begin_op_(region_begin_op), end_op_(region_end_op) {}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}

void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();

if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
// Propagate region to arguments
auto region = region_set_->GetRegion(GetRef<Call>(call));
if (region.defined()) {
for (auto arg : call->args) {
region_set_->AddToRegion(region, arg);
}
}
} else if (call->op == begin_op_) {
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);

auto region = region_set_->GetRegion(GetRef<Call>(call));
if (!region.defined()) {
throw Error(ErrorBuilder()
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
region->ins.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);

// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
region = region_set_->MakeRegion();
region->nodes.insert(call->args[0]);
}
region->nodes.insert(GetRef<Call>(call));
region->outs.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}

void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
for (auto field : op->fields) {
region_set_->AddToRegion(region, field);
}
}
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const TupleGetItemNode* g) {
auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g));
if (region.defined()) {
region_set_->AddToRegion(region, g->tuple);
}
ExprVisitor::VisitExpr_(g);
}

void VisitExpr_(const FunctionNode* op) {
auto region = region_set_->GetRegion(GetRef<Function>(op));
if (region.defined()) {
for (auto param : op->params) {
region_set_->AddToRegion(region, param);
}
}
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const LetNode* op) {
auto region = region_set_->GetRegion(GetRef<Let>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->var);
region_set_->AddToRegion(region, op->value);
region_set_->AddToRegion(region, op->body);
}
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const IfNode* op) {
auto region = region_set_->GetRegion(GetRef<If>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->cond);
region_set_->AddToRegion(region, op->true_branch);
region_set_->AddToRegion(region, op->false_branch);
}
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefCreateNode* op) {
auto region = region_set_->GetRegion(GetRef<RefCreate>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->value);
}
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefReadNode* op) {
auto region = region_set_->GetRegion(GetRef<RefRead>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefWriteNode* op) {
auto region = region_set_->GetRegion(GetRef<RefWrite>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
ExprVisitor::VisitExpr_(op);
}

private:
/*! \brief The region set being constructed.*/
AnnotatedRegionSet region_set_;
/*! \brief Region 'begin' annotation operator. */
const Op begin_op_;
/*! \brief Region 'end' annotation operator. */
const Op end_op_;
};

AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) {
return Creator(begin, end).Create(expr);
}

TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode);
TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode);

TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet")
.set_body_typed([](Expr expr, Op begin, Op end) {
return AnnotatedRegionSet::Create(expr, begin, end);
});

TVM_REGISTER_GLOBAL("relay.analysis.GetRegion")
.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) {
return region_set->GetRegion(expr);
});


} // namespace relay
} // namespace tvm
Loading

0 comments on commit b5ec071

Please sign in to comment.