Skip to content

Commit

Permalink
[RELAY][BACKEND] Enable PlanMemory in the graph runtime. (apache#2120)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Nov 19, 2018
1 parent a43dd3b commit 510c9d5
Show file tree
Hide file tree
Showing 8 changed files with 450 additions and 25 deletions.
2 changes: 2 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const {
/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string RelayPrint(
const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay
} // namespace tvm
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def build(funcs, target, target_host=None):
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
target : tvm.Target
The target to run the code on.
Expand Down
33 changes: 25 additions & 8 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import absolute_import
import json
import attr
from . import _backend
from . import compile_engine
from ..op import Op
from ..expr import Function, GlobalVar, ExprFunctor
Expand Down Expand Up @@ -103,11 +104,12 @@ def __init__(self, mod, target):
self.nodes = []
self.var_map = {}
self.params = {}
self.storage_map = None
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self._name_map = {}

def add_node(self, node, checked_type):
def add_node(self, node, expr):
"""
Add a node to the graph.
Expand All @@ -116,14 +118,21 @@ def add_node(self, node, checked_type):
node: Node
The node to add to the graph.
checked_type: Type
The type of the node.
expr: tvm.relay.Expr
The corresponding expression.
Returns
-------
node_ref: Union[NodeRef, List[NodeRef]]
A reference to the node.
"""
checked_type = expr.checked_type
# setup storage ids
assert expr in self.storage_map
node.attrs["storage_id"] = [
x.value for x in self.storage_map[expr]
]

node_id = len(self.nodes)
self.nodes.append(node)
# Tuple return value, flatten as tuple
Expand Down Expand Up @@ -168,7 +177,7 @@ def visit_constant(self, op):
name = "p%d" % index
self.params[name] = op.data
node = InputNode(name, {})
return self.add_node(node, op.checked_type)
return self.add_node(node, op)

def visit_function(self, _):
raise RuntimeError("function not supported")
Expand Down Expand Up @@ -244,7 +253,7 @@ def visit_call(self, call):
op_name = cached_func.func_name
op_node = OpNode(self._get_unique_name(op_name), {},
op_name, inputs, {})
return self.add_node(op_node, call.checked_type)
return self.add_node(op_node, call)

def _get_json(self):
"""
Expand Down Expand Up @@ -281,8 +290,7 @@ def _get_json(self):
assert node.num_outputs == len(node.attrs["shape"])
shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"]
for i in range(node.num_outputs):
storage_ids.append(i + num_entry)
storage_ids += node.attrs["storage_id"]
num_entry += node.num_outputs
node_row_ptr.append(num_entry)

Expand All @@ -302,6 +310,14 @@ def _get_json(self):

return json.dumps(json_dict, indent=2)

def debug_dump_memory_plan(self, func):
"""Debug function to dump memory plan."""
def _annotate(expr):
if expr in self.storage_map:
return str(self.storage_map[expr])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)

def codegen(self, func):
"""Compile a single function into a graph.
Expand All @@ -321,11 +337,12 @@ def codegen(self, func):
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self.storage_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes.
for param in func.params:
node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node(
node, param.type_annotation)
node, param)

# Then we compile the body into a graph which can depend
# on input variables.
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,29 @@ def register_relay_node(type_key=None):

class RelayNode(NodeBase):
"""Base class of all relay node."""
def astext(self, annotate=None):
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
Returns
-------
text : str
The text format of the expression.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
Note
----
meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big(constat weights),
so it can be helpful to skip printing the meta data section.
"""
return _expr.RelayPrint(self, annotate)
return _expr.RelayPrint(self, show_meta_data, annotate)


@register_relay_node
Expand Down
Loading

0 comments on commit 510c9d5

Please sign in to comment.