diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py index e5185eaf42b4..a1833c3c08b2 100644 --- a/python/tvm/relay/analysis/__init__.py +++ b/python/tvm/relay/analysis/__init__.py @@ -19,6 +19,9 @@ # Analysis passes from .analysis import * +# Annotations +from .annotated_regions import AnnotatedRegionSet + # Call graph from . import call_graph from .call_graph import CallGraph diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py new file mode 100644 index 000000000000..fc8e85ac8743 --- /dev/null +++ b/python/tvm/relay/analysis/annotated_regions.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import +"""Regions used in Relay.""" + +from tvm.runtime import Object +from . import _ffi_api + + +class AnnotatedRegionSet(Object): + """Class to represent a relay expression split into regions.""" + + def __init__(self, expr, region_begin_op, region_end_op): + """Construct regions from an expression. + + Parameters + ---------- + expr : tvm.relay.Expr + The expression from which to construct the regions. + region_begin_op : tvm.relay.Op + The region begin annotation. + region_end_op : tvm.relay.Op + The region end annotation. + + """ + self.__init_handle_by_constructor__(_ffi_api.AnnotatedRegionSet, + expr, + region_begin_op, + region_end_op) + + def __len__(self): + return len(self.regions) + + def get_region(self, expr): + """Get the region an expression belongs to. + + Parameters + ---------- + expr : tvm.relay.Expr + The expression. + + Returns + ------- + region + The region containing the expression. + None if not found. + """ + return _ffi_api.GetRegion(self, expr) diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc new file mode 100644 index 000000000000..f8e951bac780 --- /dev/null +++ b/src/relay/analysis/annotated_region_set.cc @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "annotated_region_set.h" + +#include +#include + +#include +#include +#include + + +namespace tvm { +namespace relay { + +AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { + for (auto candidate : regions_) { + if (candidate->nodes.find(expr) != candidate->nodes.end()) { + return candidate; + } + } + return AnnotatedRegion(nullptr); +} + +void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, + AnnotatedRegion dest) { + if (dest == src) { + return; + } + + // Merge src to dest and erase src. + dest->nodes.insert(src->nodes.begin(), src->nodes.end()); + for (const auto& input : src->ins) { + dest->ins.push_back(input); + } + for (const auto& output : src->outs) { + dest->outs.push_back(output); + } + // if any of the outputs of src are inputs of dest, they become internal nodes + // so remove them from outs + for (const auto& input : dest->ins) { + auto call = Downcast(input); + auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]); + if (it != src->outs.end()) { + dest->outs.remove(*it); + dest->ins.remove(input); + } + } + regions_.erase(src); +} + +void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) { + auto region2 = GetRegion(expr); + if (region2.defined()) { + MergeRegions(region, region2); + } else { + region->nodes.insert(expr); + } +} + +AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() { + auto ret = regions_.emplace(AnnotatedRegion()); + (*ret.first)->id = region_id_++; + return *ret.first; +} + +class AnnotatedRegionSet::Creator : public ExprVisitor { + public: + Creator(const Op& region_begin_op, const Op& region_end_op) : + begin_op_(region_begin_op), end_op_(region_end_op) {} + + AnnotatedRegionSet Create(const Expr& expr) { + VisitExpr(expr); + return std::move(region_set_); + } + + void VisitExpr_(const CallNode* call) { + auto op_node = call->op.as(); + + if (op_node == nullptr || call->attrs.as() == nullptr) { + // Propagate region to arguments + auto region = region_set_->GetRegion(GetRef(call)); + if (region.defined()) { + for (auto arg : call->args) { + region_set_->AddToRegion(region, arg); + } + } + } else if (call->op == begin_op_) { + // The annotation node is inserted on edge so it must have only one argument. + CHECK_EQ(call->args.size(), 1U); + + auto region = region_set_->GetRegion(GetRef(call)); + if (!region.defined()) { + throw Error(ErrorBuilder() + << "Cannot find the corresponding region for start annotation:\n" + << AsText(GetRef(call), false)); + } + region->ins.push_back(GetRef(call)); + } else { + CHECK_EQ(call->op, end_op_); + // The annotation node is inserted on edge so it must have only one argument. + CHECK_EQ(call->args.size(), 1U); + + // Check if the argument already belongs to a region + auto region = region_set_->GetRegion(call->args[0]); + if (!region.defined()) { + region = region_set_->MakeRegion(); + region->nodes.insert(call->args[0]); + } + region->nodes.insert(GetRef(call)); + region->outs.push_back(GetRef(call)); + } + ExprVisitor::VisitExpr_(call); + } + + void VisitExpr_(const TupleNode* op) { + auto region = region_set_->GetRegion(GetRef(op)); + if (region.defined()) { + for (auto field : op->fields) { + region_set_->AddToRegion(region, field); + } + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const TupleGetItemNode* g) { + auto region = region_set_->GetRegion(GetRef(g)); + if (region.defined()) { + region_set_->AddToRegion(region, g->tuple); + } + ExprVisitor::VisitExpr_(g); + } + + void VisitExpr_(const FunctionNode* op) { + auto region = region_set_->GetRegion(GetRef(op)); + if (region.defined()) { + for (auto param : op->params) { + region_set_->AddToRegion(region, param); + } + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const LetNode* op) { + auto region = region_set_->GetRegion(GetRef(op)); + if (region.defined()) { + region_set_->AddToRegion(region, op->var); + region_set_->AddToRegion(region, op->value); + region_set_->AddToRegion(region, op->body); + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const IfNode* op) { + auto region = region_set_->GetRegion(GetRef(op)); + if (region.defined()) { + region_set_->AddToRegion(region, op->cond); + region_set_->AddToRegion(region, op->true_branch); + region_set_->AddToRegion(region, op->false_branch); + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const RefCreateNode* op) { + auto region = region_set_->GetRegion(GetRef(op)); + if (region.defined()) { + region_set_->AddToRegion(region, op->value); + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const RefReadNode* op) { + auto region = region_set_->GetRegion(GetRef(op)); + if (region.defined()) { + region_set_->AddToRegion(region, op->ref); + } + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const RefWriteNode* op) { + auto region = region_set_->GetRegion(GetRef(op)); + if (region.defined()) { + region_set_->AddToRegion(region, op->ref); + } + ExprVisitor::VisitExpr_(op); + } + + private: + /*! \brief The region set being constructed.*/ + AnnotatedRegionSet region_set_; + /*! \brief Region 'begin' annotation operator. */ + const Op begin_op_; + /*! \brief Region 'end' annotation operator. */ + const Op end_op_; +}; + +AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end) { + return Creator(begin, end).Create(expr); +} + +TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); +TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode); + +TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet") +.set_body_typed([](Expr expr, Op begin, Op end) { + return AnnotatedRegionSet::Create(expr, begin, end); +}); + +TVM_REGISTER_GLOBAL("relay.analysis.GetRegion") +.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { + return region_set->GetRegion(expr); +}); + + +} // namespace relay +} // namespace tvm diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h new file mode 100644 index 000000000000..c5db2cc3d202 --- /dev/null +++ b/src/relay/analysis/annotated_region_set.h @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/pass/annotated_region_set.h + * \brief Define data structures to extract and manipulate regions from + * a relay function. Regions are denoted by region_begin and region_end + * annotations that exist on all the input and output edges of the region. + */ + +#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ +#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +class AnnotatedRegion; +class AnnotatedRegionSet; + +class AnnotatedRegionNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("id", &id); + Array nodes_array(nodes.begin(), nodes.end()); + v->Visit("nodes", &nodes_array); + Array args_array(ins.begin(), ins.end()); + v->Visit("args", &args_array); + Array rets_array(outs.begin(), outs.end()); + v->Visit("rets", &rets_array); + } + + /*! \brief Get the region ID. */ + int GetID() const { + return id; + } + + /*! \brief Get the region's inputs. */ + std::list GetInputs() const { + return ins; + } + + /*! \brief Get the region's outputs. */ + std::list GetOutputs() const { + return outs; + } + + /*! \brief Get the region's nodes. */ + std::unordered_set GetNodes() const { + return nodes; + } + + static constexpr const char* _type_key = "relay.AnnotatedRegion"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object); + + protected: + /*! \brief The region ID. */ + int id{-1}; + /*! \brief The inputs to this region. */ + std::list ins; + /*! \brief The outputs of this region */ + std::list outs; + /*! \brief Nodes in this region. */ + std::unordered_set nodes; + + friend class AnnotatedRegionSet; + friend class AnnotatedRegionSetNode; +}; + +/*! + * \brief An object to hold the properties of a region as used by the + * AnnotatedRegionSet class. This should be considered read-only. +*/ +class AnnotatedRegion : public ObjectRef { + public: + AnnotatedRegion() { + auto n = make_object(); + data_ = std::move(n); + } + + /*! + * \brief Construct from an object pointer. + * \param n The object pointer. + */ + explicit AnnotatedRegion(ObjectPtr n) : ObjectRef(n) {} + + /*! \return Mutable pointers to the node. */ + AnnotatedRegionNode* operator->() const { + auto* ptr = get_mutable(); + CHECK(ptr != nullptr); + return static_cast(ptr); + } +}; + +class AnnotatedRegionSetNode : public Object { + using UnorderedRegionSet = + std::unordered_set; + // Create iterator alias for a RegionSet object. + using iterator = UnorderedRegionSet::iterator; + using const_iterator = UnorderedRegionSet::const_iterator; + + public: + /*! \brief Default constructor. */ + AnnotatedRegionSetNode() = default; + + /*! \return The begin iterator */ + iterator begin() { + return regions_.begin(); + } + /*! \return The end iterator */ + iterator end() { + return regions_.end(); + } + /*! \return The const begin iterator */ + const_iterator begin() const { + return regions_.begin(); + } + /*! \return The const end iterator */ + const_iterator end() const { + return regions_.end(); + } + + /*! + * \brief Get the region that an expression belongs to. + * + * \param expr Which expr to get the region for. + * + * \return A pointer to the region, nullptr if the expression + * doesn't belong to a region. + */ + AnnotatedRegion GetRegion(const Expr& expr) const; + + /*! + * \brief Merge src region into dest region. + * + * \param src The region to merge - will be erased. + * \param dest The region into which src will be merged. + */ + void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest); + + void VisitAttrs(AttrVisitor* v) { + Array regions_array(regions_.begin(), regions_.end()); + v->Visit("regions", ®ions_array); + } + + static constexpr const char* _type_key = "relay.AnnotatedRegionSet"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionSetNode, Object); + + private: + /*! + * \brief Add an expression to a region. + * + * \param region The region to add the expression to. + * \param expr The expression. + */ + void AddToRegion(AnnotatedRegion region, const Expr& expr); + + /*! + * \brief Make a new region. + * + * \return The new region. + */ + AnnotatedRegion MakeRegion(); + + std::unordered_set regions_; + /*! \brief The next region ID to assign. */ + int region_id_{0}; + + friend class AnnotatedRegionSet; +}; + +/*! + * \brief A class to hold a set of regions produced from a relay expression + * that contains 'region_begin' and 'region_end' style annotations. The + * regions should be disjoint. The class provides both a method to construct + * the region set of a given relay expression as well as additional methods + * to update and query regions. + */ +class AnnotatedRegionSet : public ObjectRef { + using UnorderedRegionSet = + std::unordered_set; + // Create iterator alias for a RegionSet object. + using iterator = UnorderedRegionSet::iterator; + using const_iterator = UnorderedRegionSet::const_iterator; + + public: + AnnotatedRegionSet() { + auto n = make_object(); + data_ = std::move(n); + } + + /*! + * \brief Construct from an object pointer. + * + * \param n The object pointer. + */ + explicit AnnotatedRegionSet(ObjectPtr n) : ObjectRef(n) {} + + /*! \return The begin iterator. */ + iterator begin() { + auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + iterator end() { + auto* n = operator->(); + CHECK(n); + return n->end(); + } + /*! \return The begin iterator. */ + const_iterator begin() const { + const auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + const_iterator end() const { + const auto *n = operator->(); + CHECK(n); + return n->end(); + } + + /*! \return mutable pointers to the node. */ + AnnotatedRegionSetNode* operator->() const { + auto* ptr = get_mutable(); + CHECK(ptr != nullptr); + return static_cast(ptr); + } + + /*! \return The region an expression belongs to. */ + AnnotatedRegion operator[](const Expr& expr) { + const auto *n = operator->(); + CHECK(n); + return n->GetRegion(expr); + } + + /*! \brief Create a RegionSet from a relay expression. + * + * \param expr The relay expr from which to construct the set. + * \param begin Region begin annotation operator. + * \param end Region end annotation operator. + * + * \return The created RegionSet for the expression. + */ + static AnnotatedRegionSet Create(const Expr& expr, + const Op& begin, + const Op& end); + + private: + /*! \brief Helper class to construct a RegionSet from an expr.*/ + class Creator; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ diff --git a/tests/python/relay/test_annotated_regions.py b/tests/python/relay/test_annotated_regions.py new file mode 100644 index 000000000000..a24639867091 --- /dev/null +++ b/tests/python/relay/test_annotated_regions.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +from tvm import relay +from tvm.relay.op.annotation import compiler_begin, compiler_end + + +def check_region(region_set, args, nodes, rets): + region = region_set.get_region(args[0]) + assert region + assert set(args) == set(region.args) + assert set(nodes) == set(region.nodes) + assert set(rets) == set(region.rets) + + +def test_region_set_creator_diamond(): + data = relay.var('data', shape=(10, 10)) + cb_1 = compiler_begin(data, 'test_target') + O_1 = relay.abs(cb_1) + ce_1 = compiler_end(O_1, 'test_target') + ce_2 = compiler_end(O_1, 'test_target') + cb_2 = compiler_begin(ce_1, 'test_target') + O_2 = relay.nn.relu(cb_2) + ce_3 = compiler_end(O_2, 'test_target') + cb_d = compiler_begin(ce_2, "default") + X = relay.tanh(cb_d) + ce_d = compiler_end(X, 'default') + cb_3 = compiler_begin(ce_3, 'test_target') + cb_4 = compiler_begin(ce_d, 'test_target') + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, 'test_target') + diamond = relay.Function([data], ce_4) + + region_set = relay.analysis.AnnotatedRegionSet(diamond, + relay.op.get("annotation.compiler_begin"), + relay.op.get("annotation.compiler_end")) + assert len(region_set) == 4 + check_region( + region_set, + [cb_1], + [cb_1, O_1, ce_1, ce_2], + [ce_1, ce_2], + ) + check_region( + region_set, + [cb_2], + [cb_2, O_2, ce_3], + [ce_3], + ) + check_region( + region_set, + [cb_d], + [cb_d, X, ce_d], + [ce_d], + ) + check_region( + region_set, + [cb_3, cb_4], + [cb_3, cb_4, O_3, ce_4], + [ce_4], + ) + + +def test_region_set_creator_merged(): + data = relay.var('data', shape=(10, 10)) + cb_1 = compiler_begin(data, 'test_target') + O_1 = relay.abs(cb_1) + ce_2 = compiler_end(O_1, 'test_target') + O_2 = relay.nn.relu(O_1) + ce_3 = compiler_end(O_2, 'test_target') + cb_d = compiler_begin(ce_2, "default") + X = relay.tanh(cb_d) + ce_d = compiler_end(X, 'default') + cb_3 = compiler_begin(ce_3, 'test_target') + cb_4 = compiler_begin(ce_d, 'test_target') + O_3 = relay.add(cb_3, cb_4) + ce_4 = compiler_end(O_3, 'test_target') + merged = relay.Function([data], ce_4) + + region_set = relay.analysis.AnnotatedRegionSet(merged, + relay.op.get("annotation.compiler_begin"), + relay.op.get("annotation.compiler_end")) + assert len(region_set) == 3 + check_region( + region_set, + [cb_1], + [cb_1, O_1, O_2, ce_2, ce_3], + [ce_2, ce_3], + ) + check_region( + region_set, + [cb_d], + [cb_d, X, ce_d], + [ce_d], + ) + check_region( + region_set, + [cb_3, cb_4], + [cb_3, cb_4, O_3, ce_4], + [ce_4], + ) + + +if __name__ == "__main__": + test_region_set_creator_diamond() + test_region_set_creator_merged() +