[relay][pass] Annotation for heterogeneous compilation (#2361)
zhiics authored and icemelon committed Jan 11, 2019
1 parent 30a5a60 commit 09236bf
Showing 24 changed files with 1,494 additions and 68 deletions.
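To orient the diff below: this commit adds an on_device annotation operator, a device_copy operator, and passes that rewrite an annotated program for heterogeneous execution. A minimal usage sketch, assuming the Python wrapper added under relay.annotation accepts a TVMContext (the wrapper itself is not shown in this excerpt):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3, 4))
    y = relay.var("y", shape=(3, 4))
    add = relay.add(x, y)
    # Pin the add to the GPU; un-annotated operators fall back to the
    # fallback device given to the annotation rewrite pass.
    add_gpu = relay.annotation.on_device(add, tvm.gpu(0))
    func = relay.Function([x, y], relay.subtract(add_gpu, y))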
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -148,6 +148,8 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like
tvm.relay.slice_like
tvm.relay.device_copy
tvm.relay.annotation.on_device


Level 1 Definitions
31 changes: 31 additions & 0 deletions include/tvm/relay/attrs/annotation.h
@@ -0,0 +1,31 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/annotation.h
* \brief Attribute for annotation operators.
*/
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
int device_type;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe(
"The virutal device/context type that an expression is annotated with.")
.set_default(0);
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
36 changes: 36 additions & 0 deletions include/tvm/relay/attrs/device_copy.h
@@ -0,0 +1,36 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/device_copy.h
* \brief Attribute for the device copy operator.
*/
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
int dst_dev_type;
int src_dev_type;

TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
TVM_ATTR_FIELD(src_dev_type)
.describe(
"The virutal device/context type where the op copies data from.")
.set_default(0);
TVM_ATTR_FIELD(dst_dev_type)
.describe(
"The virutal device/context type where the op copies data to.")
.set_default(0);
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_DEVICE_COPY_H_
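These attributes parameterize the device_copy operator exported as tvm.relay.device_copy (listed in the relay_op.rst change above). In normal use the annotation rewrite pass inserts these copies at device boundaries; a hand-written sketch, assuming positional (data, source, destination) arguments — an assumption about the Python wrapper, which is not shown in this excerpt:

    from tvm import relay

    data = relay.var("data", shape=(3, 4))
    # Copy from device type 1 (CPU) to device type 2 (GPU), per
    # TVM's DLDeviceType numbering.
    copied = relay.device_copy(data, 1, 2)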
1 change: 0 additions & 1 deletion include/tvm/relay/attrs/transform.h
@@ -164,7 +164,6 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
}
};


struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;
15 changes: 15 additions & 0 deletions include/tvm/relay/pass.h
@@ -188,6 +188,21 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Rewrite the annotated program.
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
* \return The updated program.
*/
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);

/*!
* \brief Collect the device mapping information of each expression.
* \param expr The expression.
* \return The device mapping.
*/
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
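A sketch of driving these two passes from Python, assuming they are registered under relay.ir_pass with the usual snake_case names (only the C++ declarations appear in this excerpt, so the binding names are an assumption):

    from tvm.relay import ir_pass

    # Rewrite the annotated program; device type 1 (CPU) is the
    # fallback for operators without an on_device annotation.
    func = ir_pass.rewrite_annotated_ops(func, 1)
    # Query the expression -> device type mapping the pass produced.
    device_map = ir_pass.collect_device_info(func)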
1 change: 1 addition & 0 deletions python/tvm/__init__.py
@@ -25,6 +25,7 @@
from .ndarray import vpi, rocm, opengl, ext_dev

from ._ffi.runtime_ctypes import TypeCode
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import *
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
@@ -18,6 +18,7 @@
from .op.tensor import *
from .op.transform import *
from . import nn
from . import annotation
from . import vision
from . import image
from . import frontend
4 changes: 4 additions & 0 deletions python/tvm/relay/annotation.py
@@ -0,0 +1,4 @@
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Annotation related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.annotation import *
11 changes: 6 additions & 5 deletions python/tvm/relay/backend/_backend.py
@@ -52,20 +52,21 @@ def build(funcs, target, target_host=None):
Parameters
----------
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
A list of lowered functions or dictionary mapping from targets to
lowered functions.
target : tvm.Target
The target to run the code on.
target_host : tvm.Target
The host target.
Returns
-------
module : tvm.Module
The runtime module.
"""
if target_host == "":
target_host = None
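Concretely, the two accepted shapes of funcs per the updated docstring (function names here are illustrative):

    # Homogeneous compilation: a flat list of lowered functions.
    funcs = [lowered_add, lowered_sub]

    # Heterogeneous compilation: target string -> lowered functions.
    funcs = {"llvm": [lowered_sub], "cuda": [lowered_add]}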
83 changes: 68 additions & 15 deletions python/tvm/relay/backend/graph_runtime_codegen.py
@@ -20,13 +20,15 @@

from __future__ import absolute_import
import json
from collections import defaultdict
import attr
from . import _backend
from . import compile_engine
from ..op import Op
from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType
from ... import target as _target


@attr.s
@@ -105,9 +107,9 @@ def __init__(self, mod, target):
self.nodes = []
self.var_map = {}
self.params = {}
self.storage_map = None
self.storage_device_map = None
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self.lowered_funcs = defaultdict(set)
self._name_map = {}

def add_node(self, node, expr):
@@ -129,10 +131,20 @@ def add_node(self, node, expr):
"""
checked_type = expr.checked_type
# setup storage ids
assert expr in self.storage_map
node.attrs["storage_id"] = [
x.value for x in self.storage_map[expr]
]
assert expr in self.storage_device_map
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
node.attrs["storage_id"] = [x.value for x in storage_device_info[0]]
device_types = [x.value for x in storage_device_info[1]]
num_unknown_devices = device_types.count(0)
if num_unknown_devices != 0 and num_unknown_devices != len(device_types):
raise RuntimeError("The graph contains not annotated nodes for "
"heterogeneous execution. All nodes must be "
"annotated.")

# Add the `device_index` attribute when the graph is annotated.
if num_unknown_devices == 0:
node.attrs["device_index"] = device_types

node_id = len(self.nodes)
self.nodes.append(node)
@@ -232,9 +244,25 @@ def visit_call(self, call):
"TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)")

cached_func = self.compile_engine.lower(func, self.target)
assert call in self.storage_device_map
device_types = self.storage_device_map[call][1]
call_dev_type = device_types[0].value
if isinstance(self.target, (str, _target.Target)):
# homogeneous execution.
cached_func = self.compile_engine.lower(func, self.target)
self.target = {0: str(self.target)}
elif isinstance(self.target, dict):
# heterogeneous execution.
if call_dev_type not in self.target:
raise Exception("No target is provided for device " +
"{0}".format(call_dev_type))
cached_func = self.compile_engine.lower(func,
self.target[call_dev_type])
else:
raise ValueError("self.target must be the type of str," +
"tvm.target.Target, or dict of int to str")
for loweredf in cached_func.funcs:
self.lowered_funcs.add(loweredf)
self.lowered_funcs[self.target[call_dev_type]].add(loweredf)

inputs = []
# flatten tuple in the call.
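The branching above accepts two forms of self.target; illustratively, using TVM's device-type numbering (1 = CPU, 2 = GPU):

    # Homogeneous execution: one target for every call.
    target = "llvm"

    # Heterogeneous execution: device type -> target string.
    target = {1: "llvm", 2: "cuda"}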
@@ -284,20 +312,25 @@ def _get_json(self):
num_entry = 0
shapes = []
storage_ids = []
device_types = []
dltypes = []
node_row_ptr = [0]
for node in self.nodes:
assert node.num_outputs == len(node.attrs["shape"])
shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"]
storage_ids += node.attrs["storage_id"]
if "device_index" in node.attrs:
device_types += node.attrs["device_index"]
num_entry += node.num_outputs
node_row_ptr.append(num_entry)

# Compute "attrs" field.
attrs = {}
attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids]
if device_types:
attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes]

json_dict = {
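When the graph is annotated, the attrs dictionary assembled above gains a device_index entry alongside the existing fields; a hypothetical fragment for a three-entry graph:

    attrs = {
        "shape": ["list_shape", [[3, 4], [3, 4], [3, 4]]],
        "storage_id": ["list_int", [0, 1, 2]],
        "device_index": ["list_int", [1, 2, 2]],  # per-entry device type
        "dltype": ["list_str", ["float32", "float32", "float32"]],
    }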
@@ -313,11 +346,24 @@ def debug_dump_memory_plan(self, func):
def debug_dump_memory_plan(self, func):
"""Debug function to dump memory plan."""
def _annotate(expr):
if expr in self.storage_map:
return str(self.storage_map[expr])
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[0])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)

def debug_dump_device_annotation(self, func):
"""Debug function to dump device annotation result."""
def _annotate(expr):
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[1])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)


def codegen(self, func):
"""Compile a single function into a graph.
Expand All @@ -331,24 +377,31 @@ def codegen(self, func):
graph_json : str
The graph json that can be consumed by runtime.
lowered_funcs : List[tvm.LoweredFunc]
lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
The lowered functions.
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self.storage_map = _backend.GraphPlanMemory(func)
self.storage_device_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes.
for param in func.params:
node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node(
node, param)
self.var_map[param] = self.add_node(node, param)

# Then we compile the body into a graph which can depend
# on input variables.
self.heads = self.visit(func.body)
graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)

# Return the lowered functions as a list for homogeneous compilation.
# Otherwise, for heterogeneous compilation, a dictionary mapping each
# target to a list of lowered functions is returned. Both forms are
# acceptable to tvm.build.
if not isinstance(self.target, dict):
lowered_funcs = list(list(self.lowered_funcs.values())[0])
else:
lowered_funcs = {k: list(v) for k, v in self.lowered_funcs.items()}
return graph_json, lowered_funcs, self.params

def _get_unique_name(self, name):
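Downstream, the returned triple feeds module creation; a hedged sketch in which graph_codegen and target are placeholder names:

    graph_json, lowered_funcs, params = graph_codegen.codegen(func)
    # Per the _backend.build docstring above, both the list form and
    # the dict form of lowered_funcs are acceptable to tvm.build.
    mod = tvm.build(lowered_funcs, target=target, target_host="llvm")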