[COMPILER] GraphHash based cache system, allow dump and query duplicated functions. (apache#30)
tqchen committed May 29, 2018
1 parent 300ae30 commit 343c19a
Showing 19 changed files with 856 additions and 149 deletions.
4 changes: 2 additions & 2 deletions nnvm/include/nnvm/graph.h
@@ -63,11 +63,11 @@ class Graph {
* \return The indexed graph.
* \sa IndexedGraph
*/
-const IndexedGraph& indexed_graph();
+const IndexedGraph& indexed_graph() const;

private:
// internal structure of indexed graph
-std::shared_ptr<const IndexedGraph> indexed_graph_;
+mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
};

/*!
11 changes: 11 additions & 0 deletions nnvm/include/nnvm/pass_functions.h
@@ -41,6 +41,17 @@ inline std::string SaveJSON(Graph graph) {
return ret.GetAttr<std::string>("json");
}


/*!
* \brief Print graph ir
* \param graph The graph to be printed
* \return The graph ir string.
*/
inline std::string PrintGraphIR(Graph graph) {
Graph ret = ApplyPass(std::move(graph), "PrintGraphIR");
return ret.GetAttr<std::string>("graphir");
}

/*!
* \brief Add control flow dependencies between nodes.
*
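
The PrintGraphIR helper added above surfaces in Python as the graph IR text that the compile engine dump prints later in this commit (key.graph.ir()). A minimal, illustrative sketch of inspecting a graph's IR; the symbol below is made up for the example and is not part of the diff:

# Illustrative sketch (not part of the diff): print the textual IR of a small graph.
import nnvm.symbol as sym
import nnvm.graph as nnvm_graph

x = sym.Variable("x")
g = nnvm_graph.create(sym.relu(x))
# Graph.ir() is backed by the "PrintGraphIR" pass shown above.
print(g.ir())
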
2 changes: 2 additions & 0 deletions nnvm/python/nnvm/compiler/__init__.py
@@ -5,6 +5,7 @@

from . import build_module
from . build_module import build, optimize, build_config
from . compile_engine import engine, graph_key

from .. import symbol as _symbol
from .. import graph as _graph
@@ -14,5 +15,6 @@

from .. import top as _top


tvm.register_extension(_symbol.Symbol, _symbol.Symbol)
tvm.register_extension(_graph.Graph, _graph.Graph)
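
The compile_engine module imported above also exposes graph_key for constructing cache keys by hand. A hedged sketch, where the symbol, placeholder shape, and "llvm" target are illustrative assumptions:

# Illustrative sketch (not part of the diff): build a GraphKey from a graph,
# its input placeholders, and a compilation target.
import tvm
import nnvm.symbol as sym
import nnvm.graph as nnvm_graph
from nnvm.compiler import graph_key

x = sym.Variable("x")
g = nnvm_graph.create(sym.relu(x))
key = graph_key(g, [tvm.placeholder((1, 16), name="x")], "llvm")
print(key.graph.ir())
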
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/compiler/build_module.py
@@ -184,7 +184,7 @@ def build(graph, target, shape, dtype="float32", params=None):
graph._set_json_attr("target", target, "str")
graph._set_json_attr("opt_level", cfg.opt_level, "int")
graph = graph.apply("InferShape").apply("InferType")
-graph = graph.apply("GraphFusePartition").apply("GraphFuse")
+graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module")
return graph, libmod, params
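
A hedged sketch of the build call that now runs GraphFusePartition followed by GraphFuseCompile; the network, input shape, and "llvm" target are illustrative:

# Illustrative sketch (not part of the diff): compile a small graph end to end.
import nnvm.symbol as sym
import nnvm.graph as nnvm_graph
import nnvm.compiler

x = sym.Variable("x")
g = nnvm_graph.create(sym.relu(x))
graph, lib, params = nnvm.compiler.build(g, target="llvm", shape={"x": (1, 16)})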

99 changes: 99 additions & 0 deletions nnvm/python/nnvm/compiler/compile_engine.py
@@ -0,0 +1,99 @@
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine"""
import tvm

_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache")
_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem")
_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem")
_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph")
_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey")

@tvm.register_node
class GraphKey(tvm.node.NodeBase):
"""Key of a graph compilation context"""
@property
def graph(self):
return _graph_key_get_graph(self)


@tvm.register_node
class GraphCacheEntry(tvm.node.NodeBase):
"""CacheEntry of compilation into a TVM Function"""
pass


@tvm.register_node
class GraphFunc(tvm.node.NodeBase):
"""Compiled result of a graph into a TVM Function"""
pass


class Engine(object):
"""Global singleton compilation engine."""
def items(self):
"""List the available cache key value pairs.
Returns
-------
item_list : list of (GraphKey, GraphCacheEntry)
The existing cache items
"""
res = _list_cache_items()
assert len(res) % 2 == 0
return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)]

def clear_cache(self):
"""Clear the existing cached functions."""
_clear_cache()

def __setitem__(self, key, value):
"""Clear the existing cached functions."""
if isinstance(value, GraphCacheEntry):
_set_cache_item(key, value.graph_func)
else:
_set_cache_item(key, value)

def __getitem__(self, key):
"""Clear the existing cached functions."""
return _get_cache_item(key)

def dump(self):
"""Return a string representation of engine dump
Returns
-------
dump : str
The dumped string representation
"""
items = self.items()
res = "====================================\n"
res += "CompilerEngine dump, %d items cached\n" % len(items)
for key, value in items:
res += "------------------------------------\n"
res += "target={}\n".format(key.target)
res += "inputs={}\n".format(key.inputs)
res += "use_count={}\n".format(value.use_count)
res += "func_name={}\n".format(value.graph_func.func_name)
res += key.graph.ir() + "\n"
res += "===================================\n"
return res

engine = Engine()


def graph_key(graph, inputs, target):
"""Construct a new graph key.
Parameters
----------
graph : Graph
The computation graph structure
inputs : list of Tensor(placeholder)
The input requirement to the graph.
target : str
The target of compilation.
"""
return _make_graph_key(graph, inputs, target)
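
A hedged usage sketch for the cache interface defined above, assuming at least one graph has already been compiled so that the cache holds entries:

# Illustrative sketch (not part of the diff): inspect and manage the global engine cache.
from nnvm.compiler import engine

for key, entry in engine.items():
    print(key.target, entry.use_count, entry.graph_func.func_name)
    cached_func = engine[key]   # look up the cached function by its GraphKey

print(engine.dump())            # textual dump, including each cached graph's IR
engine.clear_cache()            # drop every cached function
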
2 changes: 2 additions & 0 deletions nnvm/python/nnvm/testing/__init__.py
@@ -1 +1,3 @@
"""Utilities for testcase"""

from .config import ctx_list
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/testing/config.py
@@ -2,7 +2,7 @@
import os
import tvm

-def test_ctx_list():
+def ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("NNVM_TEST_TARGETS", "")
device_list = (device_list.split(",") if device_list
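
A hedged sketch of how a testcase consumes the renamed helper; the device set depends on what the test run exports in NNVM_TEST_TARGETS:

# Illustrative sketch (not part of the diff): iterate over the targets enabled
# via NNVM_TEST_TARGETS (or the library default when it is unset).
from nnvm.testing import ctx_list

for target, ctx in ctx_list():
    print("testing on", target, ctx)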
