From 6a7e7f3e8315dff9170e0a689d97c09912167cd3 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 21 Feb 2020 19:14:45 +0000 Subject: [PATCH] call graph for relay --- python/tvm/relay/__init__.py | 4 + python/tvm/relay/call_graph.py | 143 ++++++++ src/relay/pass/call_graph.cc | 339 +++++++++++++++++ src/relay/pass/call_graph.h | 509 ++++++++++++++++++++++++++ tests/python/relay/test_call_graph.py | 150 ++++++++ 5 files changed, 1145 insertions(+) create mode 100644 python/tvm/relay/call_graph.py create mode 100644 src/relay/pass/call_graph.cc create mode 100644 src/relay/pass/call_graph.h create mode 100644 tests/python/relay/test_call_graph.py diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 0df3747a93b1..2ad210e7d109 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -19,6 +19,7 @@ import os from sys import setrecursionlimit from ..api import register_func +from . import call_graph from . import base from . import ty from . import expr @@ -141,3 +142,6 @@ # Feature Feature = feature.Feature + +# CallGraph +CallGraph = call_graph.CallGraph diff --git a/python/tvm/relay/call_graph.py b/python/tvm/relay/call_graph.py new file mode 100644 index 000000000000..104ccda8d585 --- /dev/null +++ b/python/tvm/relay/call_graph.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import +"""Call graph used in Relay.""" + +from tvm.ir import IRModule +from .base import Object +from .expr import GlobalVar +from . import _analysis + + +class CallGraph(Object): + """Class to represent a call graph.""" + + def __init__(self, module): + """Construct a call graph. + + Parameters + ---------- + module : tvm.ir.IRModule + The IR module used to create a call graph + + Returns + ------- + call_graph: CallGraph + A constructed call graph. + """ + self.__init_handle_by_constructor__(_analysis.CallGraph, module) + + @property + def module(self): + """Return the contained Relay IR module. + + Parameters + ---------- + None + + Returns + ------- + ret : tvm.ir.IRModule + The contained IRModule + """ + return _analysis.GetModule(self) + + def ref_count(self, var): + """Return the number of references to the global var + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : int + The number reference to the global var + """ + var = self._get_global_var(var) + return _analysis.GetRefCountGlobalVar(self, var) + + def global_call_count(self, var): + """Return the number of global function calls from a given global var. + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : int + The number of global function calls from the given var. + """ + var = self._get_global_var(var) + return _analysis.GetGlobalVarCallCount(self, var) + + def is_recursive(self, var): + """Return the number of global function calls from a given global var. + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : Boolean + If the function corresponding to var is recurisve. + """ + var = self._get_global_var(var) + return _analysis.IsRecursive(self, var) + + def _get_global_var(self, var): + """Return the global var using a given name or GlobalVar. + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : tvm.relay.GlobalVar + The global var. + """ + if isinstance(var, str): + mod = self.module + var = mod.get_global_var(var) + + if isinstance(var, GlobalVar): + return var + else: + raise TypeError("var should be either a string or GlobalVar") + + def __str__(self): + """Print the call graph in the topological order.""" + return _analysis.PrintCallGraph(self) + + def __getitem__(self, var): + """Lookup a call graph of a global function by name or by variable. + + Parameters + ---------- + var: Union[String, tvm.relay.GlobalVar] + The name or global variable. + + Returns + ------- + ret : String + The call graph represented in string. + """ + var = self._get_global_var(var) + return _analysis.GetCallGraphGlobalVar(self, var) diff --git a/src/relay/pass/call_graph.cc b/src/relay/pass/call_graph.cc new file mode 100644 index 000000000000..5a4b6a91c04a --- /dev/null +++ b/src/relay/pass/call_graph.cc @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/pass/call_graph.cc + * \brief Implementation of APIs to handle the call graph of a Relay module. + */ + +#include "call_graph.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +CallGraph::CallGraph(IRModule module) { + auto n = make_object(); + n->module = std::move(module); + auto gvar_funcs = n->module->functions; + for (const auto& it : gvar_funcs) { + if (const auto* fn = it.second.as()) { + auto func = GetRef(fn); + // Add the global function to gradually build up the call graph. + n->AddToCallGraph(it.first, func); + } + } + data_ = std::move(n); +} + +void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { + CHECK(func.defined() && gv.defined()); + // Add the current global function as an entry to the call grpah. + CallGraphEntryNode* cg_node = LookupGlobalVar(gv); + + // Only GlobalVar nodes need to be handled in a function. It indicates that + // the global function of a callee is called by the function that is being + // processed. An edge will be added from the current global function, cg_node, + // to the node that contains the found callee GlobalVarNode. + // + // This is the major overhead for constructing a call graph because the + // post-order visitor will visit each AST node of the current function to + // figure out the dependencies between functions. + PostOrderVisit(func, [&](const Expr& expr) { + if (const GlobalVarNode* gvn = expr.as()) { + auto callee = GetRef(gvn); + cg_node->AddCalledGlobal(LookupGlobalVar(callee)); + } + }); +} + +const CallGraphEntryNode* CallGraphNode::operator[](const GlobalVar& gv) const { + const_iterator cit = call_graph_.find(gv); + CHECK(cit != call_graph_.end()) + << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + return cit->second.get(); +} + +CallGraphEntryNode* CallGraphNode::operator[](const GlobalVar& gv) { + const_iterator cit = call_graph_.find(gv); + CHECK(cit != call_graph_.end()) + << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + return cit->second.get(); +} + +// Query the existence of a GlobalVar in the call graph. It creates an entry if +// there is no such a node available. +CallGraphEntryNode* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) { + CHECK(gv.defined()); + + // This inserts an element to the call graph if it is not there yet. + auto& call_graph_node = call_graph_[gv]; + if (call_graph_node) return call_graph_node.get(); + + CHECK(module->ContainGlobalVar(gv->name_hint)) + << "GlobalVar " << gv->name_hint << " not found in the current ir module"; + + // Create the node for the inserted entry. + call_graph_node = std::unique_ptr(new CallGraphEntryNode(gv)); + return call_graph_node.get(); +} + +void CallGraphNode::Print(std::ostream& os) const { + // Print the call graph in the topological order. + std::vector nodes = TopologicalOrder(); + for (const auto* cgn : nodes) { + cgn->Print(os); + } +} + +GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node, + bool update_call_graph) { + CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1)) + << "Cannot remove global var " << cg_node->GetNameHint() + << " from call graph, because it still calls " + << cg_node->size() << " other global functions"; + + if (update_call_graph) { + // Update the call graph by removing all edges that point to the node + // `cg_node`. + for (auto& it : *this) { + it.second->RemoveAllCallTo(cg_node); + } + } + GlobalVar gv = cg_node->GetGlobalVar(); + call_graph_.erase(gv); + // Update the IR module. + module->Remove(gv); + return gv; +} + +std::vector CallGraphNode::GetEntryGlobals() const { + std::vector ret; + // An entry function in Relay is a function that never called by other + // functions or only called by itself. + for (const auto& it : *this) { + if (it.second->GetRefCount() == 0 || it.second->IsRecursiveEntry()) { + ret.push_back(it.second.get()); + } + } + return ret; +} + +std::vector CallGraphNode::TopologicalOrder() const { + std::vector ret; + // Collect all entry nodes. + std::vector entries = GetEntryGlobals(); + CallGraphEntryNode::CallGraphEntryNodeSet visited; + + for (const auto& it : entries) { + // Keep tracking the nodes that have been visited. + auto topo = it->TopologicalOrder(&visited); + // Preprend the collected items. The intermeidate nodes that are shared by + // multiple entries are guaranteed to be collected when visiting the + // previous entries. Therefore, topological order remains. + ret.insert(ret.begin(), topo.begin(), topo.end()); + } + + // Find out the missing global functions if there are any to help debugging. + if (ret.size() != module->functions.size()) { + for (auto it : module->functions) { + if (visited.find((*this)[it.first]) == visited.end()) { + LOG(WARNING) << "Missing global:" << it.first->name_hint + << " with # refs = " << (*this)[it.first]->GetRefCount(); + } + } + LOG(FATAL) << "Expected " << module->functions.size() + << " globals, but received " + << ret.size(); + } + + return ret; +} + +// A BSF traverser is used to collect the nodes in a CallGraphEntryNode. The nodes +// that are visited by previous CallGraphEntryNode entries can be memoized. This +// helps us to make sure no entry will be visited multiple times when collecting +// the nodes for an entir call graph. +std::vector CallGraphEntryNode::TopologicalOrder( + CallGraphEntryNodeSet* visited) const { + std::vector ret; + std::vector current_nodes; + if (visited->find(this) == visited->end()) { + visited->emplace(this); + current_nodes.emplace_back(const_cast(this)); + } + + std::vector next_nodes; + while (!current_nodes.empty()) { + for (const auto& node : current_nodes) { + ret.push_back(node); + // Iterate through the called entries. + for (auto git = node->begin(); git != node->end(); ++git) { + if (visited->find(git->second) == visited->end()) { + next_nodes.push_back(git->second); + visited->emplace(git->second); + } + } + } + // Update the current level and clean the next level. + current_nodes = next_nodes; + next_nodes.clear(); + } + return ret; +} + +void CallGraphEntryNode::CleanCallGraphEntries() { + while (!called_globals_.empty()) { + // Decrement the reference counter + called_globals_.back().second->DecRef(); + called_globals_.pop_back(); + } +} + +inline void CallGraphEntryNode::AddCalledGlobal(CallGraphEntryNode* cg_node) { + called_globals_.emplace_back(global_, cg_node); + // Increment the reference to indicate that another call site is found for + // the callee in `cg_node`. + cg_node->IncRef(); + // Mark the global function as recursive if it calls itself. + if (global_ == cg_node->GetGlobalVar()) { + cg_node->is_recursive_ = true; + } +} + +// Remove an edge from the current global function to the callee. +void CallGraphEntryNode::RemoveCallTo(const GlobalVar& callee) { + for (auto it = begin();; ++it) { + CHECK(it != end()) << "Cannot find global function " + << callee->name_hint << " to remove!"; + if (it->second->GetGlobalVar() == callee) { + // Only remove one occurrence of the call site. + it->second->DecRef(); + *it = called_globals_.back(); + called_globals_.pop_back(); + return; + } + } +} + +// Remove all edges from the current global function to the callee. +void CallGraphEntryNode::RemoveAllCallTo(CallGraphEntryNode* callee) { + for (uint32_t i = 0, e = size(); i != e;) { + if (called_globals_[i].second == callee) { + callee->DecRef(); + called_globals_[i] = called_globals_.back(); + called_globals_.pop_back(); + --e; + } else { + ++i; + } + } + // Make sure all references to the callee are removed. + CHECK_EQ(callee->GetRefCount(), 0U) + << "All references to " << callee->GetNameHint() + << " should have been removed"; +} + +void CallGraphEntryNode::Print(std::ostream& os) const { + if (!global_.defined()) { + os << "GlobalVar is not defined\n"; + return; + } + + os << "Call graph node: " << global_->name_hint; + os << " at: " << this << ", #refs = " << GetRefCount() << "\n"; + + for (const auto& it : *this) { + os << " call site: <" << it.first->name_hint << "> calls "; + os << it.second->GetNameHint() << "\n"; + } + os << "\n"; +} + +std::ostream& operator<<(std::ostream& os, const CallGraph& cg) { + cg->Print(os); + return os; +} + +std::ostream& operator<<(std::ostream& os, const CallGraphEntryNode& cgn) { + cgn.Print(os); + return os; +} + +TVM_REGISTER_NODE_TYPE(CallGraphNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + CHECK(node); + p->stream << "CallGraph: \n" << GetRef(node); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.CallGraph") +.set_body_typed([](IRModule module) { + return CallGraph(module); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph") +.set_body_typed([](CallGraph call_graph) { + std::stringstream ss; + ss << call_graph; + return ss.str(); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetModule") +.set_body_typed([](CallGraph call_graph) { + return call_graph->GetModule(); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetCallGraphGlobalVar") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + std::stringstream ss; + ss << *entry_node; + return ss.str(); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetRefCountGlobalVar") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->GetRefCount()); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetGlobalVarCallCount") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->size()); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.IsRecursive") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return entry_node->IsRecursive(); +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/call_graph.h b/src/relay/pass/call_graph.h new file mode 100644 index 000000000000..7e1f23db4b80 --- /dev/null +++ b/src/relay/pass/call_graph.h @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/pass/call_graph.h + * \brief Define data structures for the call graph of a IRModule. It borrows + * the idea how LLVM constructs CallGraph. + * + * https://llvm.org/doxygen/CallGraph_8h_source.html + */ + +#ifndef TVM_RELAY_PASS_CALL_GRAPH_H_ +#define TVM_RELAY_PASS_CALL_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +class CallGraphEntryNode; +class CallGraph; + +class CallGraphNode : public Object { + using CallGraphMap = + std::unordered_map, ObjectHash, + ObjectEqual>; + // Create iterator alias for a CallGraphNode object. + using iterator = CallGraphMap::iterator; + using const_iterator = CallGraphMap::const_iterator; + + public: + /*! \brief The IR module for creating a CallGraphNode. */ + IRModule module; + + /*! \brief Default constructor. */ + CallGraphNode() {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("module", &module); + } + + /*! + * \brief Print the call graph. + * + * \param os The stream for printing. + */ + void Print(std::ostream& os) const; + + /*! \return The begin iterator. */ + iterator begin() { + return call_graph_.begin(); + } + /*! \return The end iterator. */ + iterator end() { + return call_graph_.end(); + } + /*! \return The begin iterator. */ + const_iterator begin() const { + return call_graph_.begin(); + } + /*! \return The end iterator. */ + const_iterator end() const { + return call_graph_.end(); + } + + /*! + * \brief Get an element from the CallGraphNode using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const GlobalVar& gv) const; + /*! + * \brief Get an element from the CallGraphNode using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const GlobalVar& gv); + /*! + * \brief Get an element from the CallGraphNode using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const std::string& gvar_name) const { + return (*this)[module->GetGlobalVar(gvar_name)]; + } + /*! + * \brief Get an element from the CallGraphNode using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const std::string& gvar_name) { + return (*this)[module->GetGlobalVar(gvar_name)]; + } + + /*! \brief Return the IR module. */ + IRModule GetModule() const { + return module; + } + + /*! + * \brief Get the entries/root nodes of CallGraphNode. + * + * Entry functions are never referenced by other functions. + * Note these functions can be recursive as well. + * + * \return The list of CallGraphEntryNode that represent entry nodes. + */ + std::vector GetEntryGlobals() const; + + /*! + * \brief Remove a GlobalVar in a given CallGraphEntryNode from the current + * IR module. + * + * \param cg_node The CallGraphEntryNode that contains a global function to be + * removed. + * \param update_call_graph Indicate if we will update the CallGraph as well + * since updating is costly. We are only able to remove a leaf function + * when update_call_graph is disabled because the edges pointing to + * functions being removed are not updated. + * + * \return The GlobalVar removed from the current module. + */ + GlobalVar RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node, + bool update_call_graph = false); + + /*! + * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for + * the GlobalVar if it doesn't exist. + * + * \param gv The GlobalVar for query. + * + * \return The queried entry. + */ + CallGraphEntryNode* LookupGlobalVar(const GlobalVar& gv); + + /*! + * \brief Get the entries from the CallGraphNode in the topological order. + * + * This is useful for various module-level optimizations/analysis. For example, + * inlining requires the correct order of the functions being processed, i.e. + * callee should be always handled before callers. + * + * \return The list of collected entries that are sorted in the topological order. + */ + std::vector TopologicalOrder() const; + + static constexpr const char* _type_key = "relay.CallGraph"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object); + + private: + /*! + * \brief Create a CallGraphEntryNode for a global function and add it to the + * CallGraphNode. + * + * \param gv The global var. + * \param func The global function corresponding to `gv`. + */ + void AddToCallGraph(const GlobalVar& gv, const Function& func); + + /*! \brief A record contains GlobalVar to CallGraphEntryNode mapping. */ + CallGraphMap call_graph_; + + friend CallGraph; +}; + +/*! + * \brief The class that represents the call graph of a Relay IR module. It also + * provides a variety of utility functions for users to query, view, and update + * a call graph. + */ +class CallGraph : public ObjectRef { + using CallGraphMap = + std::unordered_map, ObjectHash, + ObjectEqual>; + // Create iterator alias for a CallGraph object. + using iterator = CallGraphMap::iterator; + using const_iterator = CallGraphMap::const_iterator; + + public: + /*! + * \brief Construct a CallGraph from a IR module. + * + * \param module The IR module + */ + explicit CallGraph(IRModule module); + + /*! + * \brief Construct from an object pointer. + * \param n The object pointer. + */ + explicit CallGraph(ObjectPtr n) : ObjectRef(n) {} + + /*! \return The begin iterator. */ + iterator begin() { + auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + iterator end() { + auto* n = operator->(); + CHECK(n); + return n->end(); + } + /*! \return The begin iterator. */ + const_iterator begin() const { + const auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + const_iterator end() const { + const auto* n = operator->(); + CHECK(n); + return n->end(); + } + + /*! + * \brief Get an element from the CallGraph using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const GlobalVar& gv) const { + const auto* n = operator->(); + CHECK(n); + return (*n)[gv]; + } + /*! + * \brief Get an element from the CallGraph using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const GlobalVar& gv) { + auto* n = operator->(); + CHECK(n); + return (*n)[gv]; + } + /*! + * \brief Get an element from the CallGraph using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const std::string& gvar_name) const { + const auto* n = operator->(); + CHECK(n); + return (*n)[gvar_name]; + } + /*! + * \brief Get an element from the CallGraph using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const std::string& gvar_name) { + auto* n = operator->(); + CHECK(n); + return (*n)[gvar_name]; + } + + /*! \return mutable pointers to the node. */ + CallGraphNode* operator->() const { + auto* ptr = get_mutable(); + CHECK(ptr != nullptr); + return static_cast(ptr); + } + + private: + /*! \brief Overload the << operator to print a call graph. */ + friend std::ostream& operator<<(std::ostream& os, const CallGraph&); +}; + +/*! + * \brief A node in the call graph. It maintains the edges from a caller to + * all callees. + */ +class CallGraphEntryNode { + public: + using CallGraphEntry = std::pair; + using CallGraphEntryVector = std::vector; + using CallGraphEntryNodeSet = std::unordered_set; + // Create iterator alias for a CallGraphEntryNode object. + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + /*! + * \brief Construct from a GlobalVar. + * + * \param gv The GlobalVar to create a CallGraphEntryNode. + */ + explicit CallGraphEntryNode(const GlobalVar& gv) : global_(gv) {} + /*! + * \brief Delete copy constructor. + */ + CallGraphEntryNode(const CallGraphEntryNode&) = delete; + /*! \brief Delete assignment. */ + CallGraphEntryNode& operator=(const CallGraphEntryNode&) = delete; + + /*! \return The begin iterator */ + iterator begin() { + return called_globals_.begin(); + } + /*! \return The end iterator */ + iterator end() { + return called_globals_.end(); + } + /*! \return The const begin iterator */ + const_iterator begin() const { + return called_globals_.begin(); + } + /*! \return The const end iterator */ + const_iterator end() const { + return called_globals_.end(); + } + + /*! + * \brief Return if the list of called nodes is empty. + * + * \return true if the list is empty. Otherwise, false. + */ + bool empty() const { + return called_globals_.empty(); + } + + /*! + * \brief Return the size of the list that represents the nodes are called by + * the current node. + * + * \return The number of called nodes. + */ + uint32_t size() const { + return static_cast(called_globals_.size()); + } + + /*! + * \brief Fetch the i-th CallGraphEntryNode from the list of nodes that are called + * by the current function. + * + * \param i The index. + * + * \return The fetched CallGraphEntryNode. + */ + CallGraphEntryNode* operator[](size_t i) const { + CHECK_LT(i, called_globals_.size()) << "Invalid Index"; + return called_globals_[i].second; + } + + /*! + * \brief Print the call graph that is stemmed from the current CallGraphEntryNode. + * + * \param os The stream for printing. + */ + void Print(std::ostream& os) const; + + /*! + * \brief Return the number of times the global function is referenced. + * + * \return The count. + */ + uint32_t GetRefCount() const { + return ref_cnt_; + } + + /*! + * \brief Return the GlobalVar stored in the current CallGraphEntryNode. + * + * \return The GlobalVar. + */ + GlobalVar GetGlobalVar() const { + return global_; + } + + /*! + * \brief Return the name hint of the GlobalVar stored in the CallGraphEntryNode. + * + * \return The name hint of the global function. + */ + std::string GetNameHint() const { + return global_->name_hint; + } + + /*! + * \brief Return if the global function corresponding to the current + * CallGraphEntryNode is a recursive function. + * + * \return true if it is recursive. Otherwise, false. + */ + bool IsRecursive() const { + return is_recursive_; + } + + /*! + * \brief Return if the global function corresponding to the current + * CallGraphEntryNode is both a recursive function and an entry function. This type + * of function only has one reference which is called by itself. + * + * \return true if it is both a recursive function and an entry. Otherwise, false. + */ + bool IsRecursiveEntry() const { + return GetRefCount() == 1 && IsRecursive(); + } + + /*! + * \brief Return the topological order of the CallGraphEntryNode. + * + * \param visited A set of CallGraphEntryNode objects that have been visited. + * + * \return The list of CallGraphEntryNode that is represented in topological order. + */ + std::vector TopologicalOrder( + CallGraphEntryNodeSet* visited = new CallGraphEntryNodeSet()) const; + + /*! + * \brief Remove all edges from the current CallGraphEntryNode to any global + * function it calls. + */ + void CleanCallGraphEntries(); + + /*! + * \brief Add a node to the list of nodes that are being called by the current + * global function. + * + * \param cg_node The CallGraphEntryNode that will be added to the call list. + */ + void AddCalledGlobal(CallGraphEntryNode* cg_node); + + /*! + * \brief Remove a call edge to the global function from the current + * function. + * + * \param callee The function that is being called. + */ + void RemoveCallTo(const GlobalVar& callee); + + /*! + * \brief Remove all the edges that represent that calls to the global function + * stored in a given CallGraphEntryNode. + * + * \param callee The function that is being called. + */ + void RemoveAllCallTo(CallGraphEntryNode* callee); + + private: + /*! \brief Decrement the reference counter by 1. */ + void DecRef() { + CHECK_GT(ref_cnt_, 0); + --ref_cnt_; + } + /*! \brief Increment the reference counter by 1. */ + void IncRef() { ++ref_cnt_; } + + /*! + * \brief Mark if the global function stored in the CallGraphEntryNode is + * recursive function. + */ + bool is_recursive_{false}; + /*! \brief Count the number of times the global function is referenced. */ + uint32_t ref_cnt_{0}; + /*! \brief The GlobalVar stored in the current CallGraphEntryNode. */ + GlobalVar global_; + /*! \brief The list of entries called by the current CallGraphEntryNode. */ + CallGraphEntryVector called_globals_; + + friend class CallGraph; + /*! \brief Overload the << operator to print a call graph node. */ + friend std::ostream& operator<<(std::ostream& os, const CallGraphEntryNode&); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_CALL_GRAPH_H_ diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py new file mode 100644 index 000000000000..4d82c5c2ce22 --- /dev/null +++ b/tests/python/relay/test_call_graph.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +import pytest +import tvm +from tvm import relay + + +def test_callgraph_construct(): + mod = tvm.IRModule({}) + x = relay.var("x", shape=(2, 3)) + y = relay.var("y", shape=(2, 3)) + mod["g1"] = relay.Function([x, y], x + y) + call_graph = relay.CallGraph(mod) + assert "g1" in str(call_graph) + assert relay.alpha_equal(mod, call_graph.module) + + +def test_print_element(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + mod["g0"] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + mod["g1"] = relay.Function([x1, y1], x1 - y1) + call_graph = relay.CallGraph(mod) + + assert "#refs = 0" in str(call_graph["g0"]) + assert "#refs = 0" in str(call_graph["g1"]) + + +def test_global_call_count(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + g0 = relay.GlobalVar("g0") + mod[g0] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + g1 = relay.GlobalVar("g1") + mod[g1] = relay.Function([x1, y1], g0(x1, y1)) + call_graph = relay.CallGraph(mod) + + p0 = relay.var("p0", shape=(2, 3)) + p1 = relay.var("p1", shape=(2, 3)) + func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) + mod["main"] = func + call_graph = relay.CallGraph(mod) + + assert call_graph.global_call_count(g0) == 0 + assert call_graph.global_call_count(g1) == 1 + assert call_graph.global_call_count("main") == 2 + + +def test_ref_count(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + g0 = relay.GlobalVar("g0") + mod[g0] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + g1 = relay.GlobalVar("g1") + mod[g1] = relay.Function([x1, y1], x1 - y1) + call_graph = relay.CallGraph(mod) + + p0 = relay.var("p0", shape=(2, 3)) + p1 = relay.var("p1", shape=(2, 3)) + func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) + mod["main"] = func + call_graph = relay.CallGraph(mod) + + assert call_graph.ref_count(g0) == 1 + assert call_graph.ref_count(g1) == 1 + assert call_graph.ref_count("main") == 0 + + +def test_nested_ref(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + g0 = relay.GlobalVar("g0") + mod[g0] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + g1 = relay.GlobalVar("g1") + mod[g1] = relay.Function([x1, y1], g0(x1, y1)) + call_graph = relay.CallGraph(mod) + + p0 = relay.var("p0", shape=(2, 3)) + p1 = relay.var("p1", shape=(2, 3)) + func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) + mod["main"] = func + call_graph = relay.CallGraph(mod) + + assert call_graph.ref_count(g0) == 2 + assert call_graph.ref_count(g1) == 1 + assert call_graph.ref_count("main") == 0 + + +def test_recursive_func(): + mod = tvm.IRModule({}) + + x = relay.var('x', shape=[], dtype='int32') + fn0 = relay.Function([x], x) + gx = relay.GlobalVar("gx") + mod[gx] = fn0 + + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + sb = relay.ScopeBuilder() + with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))): + sb.ret(i) + with sb.else_scope(): + one_less = relay.subtract(i, relay.const(1, dtype='int32')) + global_call = gx(i) + rec_call = relay.Call(sum_up, [one_less]) + global_call + sb.ret(relay.add(rec_call, i)) + func = relay.Function([i], + sb.get(), + ret_type=relay.TensorType([], 'int32')) + func = func.set_attribute("Compiler", tvm.tir.StringImm("a")) + mod[sum_up] = func + iarg = relay.var('i', shape=[], dtype='int32') + mod["main"] = relay.Function([iarg], sum_up(iarg)) + call_graph = relay.CallGraph(mod) + + assert call_graph.is_recursive(sum_up) + assert call_graph.ref_count(sum_up) == 2 + assert call_graph.ref_count(gx) == 1 + assert call_graph.ref_count("main") == 0 + + +if __name__ == "__main__": + pytest.main()