[RUNTIME][COMPILER] Formal compiler pipeline, runtime wrapper module (apache#21)

* [RUNTIME][COMPILER] Formal compiler pipeline, runtime wrapper module

* more detailed comments
tqchen committed May 29, 2018
1 parent 79ceb9f commit a2ab3d8
Showing 13 changed files with 340 additions and 130 deletions.
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/compiler/__init__.py
@@ -4,7 +4,7 @@
 import tvm
 
 from . import build_module
-from . build_module import build, precompute_prune, _run_graph
+from . build_module import build, optimize, build_config
 
 from .. import symbol as _symbol
 from .. import graph as _graph
126 changes: 115 additions & 11 deletions nnvm/python/nnvm/compiler/build_module.py
@@ -3,10 +3,74 @@
 from __future__ import absolute_import as _abs
 
 import tvm
-from . import graph_attr, graph_pass
+from . import graph_attr, graph_util
 from .. import graph as _graph
 from .. import runtime
 
+# List of optimization passes and the opt level at which each is switched on
+OPT_PASS_LEVEL = {
+    "SimplifyBatchNormInference": 2,
+    "PrecomputePrune": 2,
+    "OpFusion": 1
+}
+
+class BuildConfig(object):
+    """Configuration scope to set a build config option.
+
+    Parameters
+    ----------
+    kwargs
+        Keyword arguments of configurations to set.
+    """
+    current = None
+    defaults = {
+        "opt_level": 2,
+    }
+    def __init__(self, **kwargs):
+        self._old_scope = None
+        for k, _ in kwargs.items():
+            if k not in BuildConfig.defaults:
+                raise ValueError(
+                    "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
+        self._attr = kwargs
+
+    def __getattr__(self, name):
+        if name not in self._attr:
+            return BuildConfig.defaults[name]
+        return self._attr[name]
+
+    def __enter__(self):
+        # pylint: disable=protected-access
+        self._old_scope = BuildConfig.current
+        attr = BuildConfig.current._attr.copy()
+        attr.update(self._attr)
+        self._attr = attr
+        BuildConfig.current = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        assert self._old_scope
+        BuildConfig.current = self._old_scope
+
+
+BuildConfig.current = BuildConfig()
+
+def build_config(**kwargs):
+    """Configure the build behavior by setting config variables.
+
+    Parameters
+    ----------
+    opt_level: int, default=2
+        Optimization level. See OPT_PASS_LEVEL for the level of each pass.
+
+    Returns
+    -------
+    config: BuildConfig
+        The build configuration
+    """
+    return BuildConfig(**kwargs)
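
A quick usage sketch (editorial addition, not part of this diff; the exp op and the shape are illustrative and assume a registered compute/schedule for that op):

import nnvm.symbol as sym
import nnvm.compiler

x = sym.Variable("x")
net = sym.exp(x)
with nnvm.compiler.build_config(opt_level=0):
    # inside the scope BuildConfig.current.opt_level == 0, so passes gated
    # at level 1 (OpFusion) and level 2 (SimplifyBatchNormInference,
    # PrecomputePrune) stay off
    graph, lib, params = nnvm.compiler.build(net, "llvm", {"x": (2, 4)})
# __exit__ restores the previous BuildConfig.current when the scope closes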


 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
     f = tvm.lower(sch, inputs, name=func_name)
@@ -19,23 +83,45 @@ def _build(funcs, target):
     return tvm.build(funcs, target=target)


-def optimize(graph):
-    """Perform graph optimization
+def _update_shape_dtype(shape, dtype, params):
+    """Update shape and dtype given params information"""
+    if not params:
+        return shape, dtype
+    shape = shape.copy()
+    shape.update({k : v.shape for k, v in params.items()})
+    if isinstance(dtype, str):
+        for k, v in params.items():
+            if v.dtype != dtype:
+                raise ValueError(
+                    "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
+    else:
+        dtype = dtype.copy()
+        dtype.update({k : str(v.dtype) for k, v in params.items()})
+    return shape, dtype
+
+
+def optimize(graph, shape, dtype="float32"):
+    """Perform target and parameter invariant graph optimization.
 
     Parameters
     ----------
     graph : Graph
-        The graph to be used in lowering.
+        The graph to be optimized.
 
     Returns
     -------
     graph : Graph
-        The optimized execution graph.
+        The optimized graph.
     """
     # pylint: disable=unused-argument
+    cfg = BuildConfig.current
+    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
+        graph = graph_attr.set_shape_inputs(graph, shape)
+        graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
     return graph
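
A hypothetical illustration of the helper above (the names data and weight are invented for this sketch): _update_shape_dtype folds each parameter's shape and dtype into the user-supplied dicts before optimization runs, and raises ValueError when a parameter's dtype contradicts a scalar dtype string:

import numpy as np
import tvm

shape = {"data": (1, 16)}
params = {"weight": tvm.nd.array(np.zeros((16, 4), dtype="float32"))}
shape, dtype = _update_shape_dtype(shape, "float32", params)
# shape is now {"data": (1, 16), "weight": (16, 4)} and dtype stays "float32";
# a "weight" stored as int8 would raise ValueError instead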


-def build(graph, target, shape, dtype="float32"):
+def build(graph, target, shape, dtype="float32", params=None):
     """Build graph into runtime library.
 
     This is the final step of graph compilation.
@@ -54,27 +140,45 @@ def build(graph, target, shape, dtype="float32"):
     dtype : str or dict of str to str
         The input types to the graph
 
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for pre-compute
+        folding optimization.
+
     Returns
     -------
     graph : Graph
         The final execution graph.
 
     libmod : tvm.Module
         The module that comes with the execution graph
+
+    params : dict of str to NDArray
+        The updated parameters of graph if params is passed.
+        This can be different from the params passed in.
     """
     if not isinstance(target, str):
         raise TypeError("require target to be str")
     if not isinstance(shape, dict):
         raise TypeError("require shape to be dict")
 
+    cfg = BuildConfig.current
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Apply optimization
+    graph = optimize(graph, shape, dtype)
+    # Precompute prune
+    if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]:
+        graph, params = precompute_prune(graph, params)
+        shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Operator Fusion and generation
     graph = graph_attr.set_shape_inputs(graph, shape)
     graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
+    graph._set_json_attr("opt_level", cfg.opt_level, "int")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuse")
     libmod = graph_attr._move_out_module(graph, "module")
-    return graph, libmod
+    return graph, libmod, params
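
An end-to-end sketch of the new three-value contract (editorial addition; the elemwise_add graph and shapes are illustrative and assume the op has a registered compute/schedule): params marks inputs that never change at inference time, and the returned dict may differ from the one passed in once PrecomputePrune has folded parameter-only subgraphs:

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler

x = sym.Variable("x")
w = sym.Variable("w")
net = sym.elemwise_add(x, w)
params = {"w": tvm.nd.array(np.ones((2, 4), dtype="float32"))}
# the shape and dtype of "w" come from params via _update_shape_dtype
graph, lib, params = nnvm.compiler.build(
    net, "llvm", {"x": (2, 4)}, params=params)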


 def _run_graph(graph, params):
@@ -98,9 +202,9 @@ def _run_graph(graph, params):
     dtype = {k : v.dtype for k, v in params.items()}
     target = "llvm"
     ctx = tvm.cpu(0)
-    _, oshape = graph_pass.infer_shape(graph, **shape)
-    _, odtype = graph_pass.infer_dtype(graph, **dtype)
-    graph, libmod = build(graph, target, shape, dtype)
+    _, oshape = graph_util.infer_shape(graph, **shape)
+    _, odtype = graph_util.infer_dtype(graph, **dtype)
+    graph, libmod, _ = build(graph, target, shape, dtype)
     m = runtime.create(graph, libmod, ctx)
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
     for k, v in params.items():
78 changes: 0 additions & 78 deletions nnvm/python/nnvm/compiler/graph_pass.py
@@ -6,81 +6,3 @@
 - Composable API: break graph transformation pass as segments of small transformations.
 """
 from __future__ import absolute_import as _abs
-
-import tvm
-from . import graph_attr
-
-
-def infer_shape(graph, **shape):
-    """Infer the shape given the shape of inputs.
-
-    Parameters
-    ----------
-    graph : Graph
-        The graph to perform shape inference from
-
-    Returns
-    -------
-    in_shape : list of tuple
-        Shape of inputs
-
-    out_shape: list of tuple
-        Shape of outputs
-    """
-    graph = graph_attr.set_shape_inputs(graph, shape)
-    graph = graph.apply("InferShape")
-    shape = graph.json_attr("shape")
-    index = graph.index
-    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
-    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
-    return input_shape, output_shape
-
-
-def infer_dtype(graph, **dtype):
-    """Infer the type given the types of inputs.
-
-    Parameters
-    ----------
-    graph : Graph
-        The graph to perform type inference from
-
-    Returns
-    -------
-    in_dtype : list of tuple
-        Dtype of inputs
-
-    out_dtype: list of tuple
-        Dtype of outputs
-    """
-    graph = graph_attr.set_dtype_inputs(graph, dtype)
-    graph = graph.apply("InferType")
-    dtype = graph.json_attr("dtype")
-    index = graph.index
-    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
-                   for x in index.input_names]
-    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
-                    for x in index.output_entries]
-    return input_dtype, output_dtype
-
-
-_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
-
-def check_graph_equal(grapha, graphb):
-    """Check if two graphs have equal structure.
-
-    Parameters
-    ----------
-    grapha : Graph
-        The first graph
-
-    graphb : Graph
-        The second graph
-
-    Raises
-    ------
-    ValueError
-        ValueError is raised with an error message when the graphs are not equal
-    """
-    err = _deep_compare(grapha, graphb)
-    if err:
-        raise ValueError("Graph compare error: " + err)
80 changes: 80 additions & 0 deletions nnvm/python/nnvm/compiler/graph_util.py
@@ -0,0 +1,80 @@
+# pylint: disable=invalid-name
+"""Utility functions to get information from graph."""
+from __future__ import absolute_import as _abs
+
+import tvm
+from . import graph_attr
+
+def infer_shape(graph, **shape):
+    """Infer the shape given the shape of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform shape inference from
+
+    Returns
+    -------
+    in_shape : list of tuple
+        Shape of inputs
+
+    out_shape: list of tuple
+        Shape of outputs
+    """
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    shape = graph.json_attr("shape")
+    index = graph.index
+    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
+    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
+    return input_shape, output_shape
+
+
+def infer_dtype(graph, **dtype):
+    """Infer the type given the types of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform type inference from
+
+    Returns
+    -------
+    in_dtype : list of tuple
+        Dtype of inputs
+
+    out_dtype: list of tuple
+        Dtype of outputs
+    """
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
+    graph = graph.apply("InferType")
+    dtype = graph.json_attr("dtype")
+    index = graph.index
+    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                   for x in index.input_names]
+    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                    for x in index.output_entries]
+    return input_dtype, output_dtype
+
+
+_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
+
+def check_graph_equal(grapha, graphb):
+    """Check if two graphs have equal structure.
+
+    Parameters
+    ----------
+    grapha : Graph
+        The first graph
+
+    graphb : Graph
+        The second graph
+
+    Raises
+    ------
+    ValueError
+        ValueError is raised with an error message when the graphs are not equal
+    """
+    err = _deep_compare(grapha, graphb)
+    if err:
+        raise ValueError("Graph compare error: " + err)
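
These helpers move here verbatim from graph_pass.py; a usage sketch (editorial addition; op and shapes are illustrative):

import nnvm.symbol as sym
import nnvm.graph as graph
from nnvm.compiler import graph_util

g = graph.create(sym.exp(sym.Variable("x")))
in_shape, out_shape = graph_util.infer_shape(g, x=(2, 4))     # out_shape: [[2, 4]]
in_dtype, out_dtype = graph_util.infer_dtype(g, x="float32")  # out_dtype: ["float32"]
# structural comparison via DeepCompare; raises ValueError on mismatch
graph_util.check_graph_equal(g, graph.create(sym.exp(sym.Variable("x"))))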
