Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][AUTOTVM] Extract tuning tasks from Relay programs #2181

Merged
merged 14 commits into from
Dec 24, 2018
1 change: 1 addition & 0 deletions python/tvm/autotvm/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@

from .topi_integration import register_topi_compute, register_topi_schedule
from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
from .relay_integration import extract_from_program, extract_from_multiple_program
231 changes: 29 additions & 202 deletions python/tvm/autotvm/task/nnvm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,208 +7,13 @@
import logging


from ... import tensor, placeholder, create_schedule, target as _target
from ... import target as _target

from ..util import get_const_tuple
from .task import create, register
from .task import create
from .topi_integration import TaskExtractEnv

logger = logging.getLogger('autotvm')

def serialize_args(args):
"""serialize arguments of a topi function to a hashable tuple.

Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tensor.Tensor):
ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
else:
ret.append(t)
return tuple(ret)


def deserialize_args(args):
"""The inverse function of :code:`serialize_args`.

Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tuple) and t[0] == 'TENSOR':
ret.append(placeholder(shape=t[1], dtype=t[2]))
else:
ret.append(t)
return ret


# Task extractor for nnvm graph
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None

def __init__(self):
import topi
import nnvm

# NOTE: To add more symbols, you only need to change the following lists
# nnvm symbol -> topi compute
self.symbol2topi = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
nnvm.sym.dense: [topi.nn.dense],
}

# topi compute -> autotvm task name
self.topi_to_task = {
topi.nn.conv2d: "topi_nn_conv2d",
topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
}

self.topi_to_schedule = {
topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
topi.generic.schedule_conv2d_nhwc],
topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
topi.generic.schedule_depthwise_conv2d_nhwc],
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
}

self._register_tracing()
self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())

def _register_tracing(self):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for topi_compute in self.topi_to_task:
def _local_scope(compute_func):
"""start a scope to hold the local function in for loop"""

@compute_func.register("tracing", )
def _tracing_topi_compute(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
"Please modify it to use only positional args."

if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)

return compute_func.fdefault(*args)
_local_scope(topi_compute)

# register topi schedule for "tracing" target
for topi_compute in self.topi_to_task:
for topi_schedule in self.topi_to_schedule[topi_compute]:
def _local_scope_(schedule_func):
"""start a scope to hold the local function in for loop"""

@schedule_func.register("tracing", )
def _tracing_topi_compute(outs):
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return create_schedule([x.op for x in outs])
_local_scope_(topi_schedule)

def _register_topi_task(self):
"""register tuning wrapper for topi function"""
import topi

# Tuning wrapper for topi functions
@register("topi_nn_conv2d")
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
assert layout == 'NCHW', "only support NCHW currently"
C = topi.nn.conv2d(*args, **kwargs)
s = topi.generic.schedule_conv2d_nchw([C])
return s, [A, W, C]

@register("topi_nn_depthwise_conv2d_nchw")
def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_depthwise_conv2d_nchw([C])
return s, [A, W, C]

@register("topi_nn_group_conv2d_nchw")
def _topi_nn_group_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.group_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_group_conv2d_nchw([C])
return s, [A, W, C]

@register("topi_nn_conv2d_transpose_nchw")
def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
s = topi.generic.schedule_conv2d_transpose_nchw([C])
return s, [A, W, C]

@register("topi_nn_dense")
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias = args
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
return s, [data, weight, bias, C]
return s, [data, weight, C]

def reset(self, wanted_topi_funcs):
"""Reset task collections

Parameters
----------
wanted_topi_funcs: List of function
The topi function to be extracted
"""
self.task_collection = []
self.wanted_topi_funcs = wanted_topi_funcs

def get_tasks(self):
"""Get collected tasks

Returns
-------
tasks: List of tuple(name, args)
A list of tasks extracted from the nnvm graph
"""
return self.task_collection

@staticmethod
def get():
"""Get the single instance of TaskExtractEnv

Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
return TaskExtractEnv.current


def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
""" Extract tuning tasks from a nnvm graph.
Expand Down Expand Up @@ -237,13 +42,24 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
collected tasks
"""
import nnvm.compiler
import nnvm
import topi

env = TaskExtractEnv.get()

#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
SYMBOL2TOPI = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
nnvm.sym.dense: [topi.nn.dense],
}

topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
if sym_name in SYMBOL2TOPI:
topi_funcs.extend(SYMBOL2TOPI[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

Expand Down Expand Up @@ -297,13 +113,24 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
collected tasks
"""
import nnvm.compiler
import nnvm
import topi

env = TaskExtractEnv.get()

#NOTE: To add more symbols, you only need to change the following lists
#nnvm symbol -> topi compute
SYMBOL2TOPI = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw],
nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
nnvm.sym.dense: [topi.nn.dense],
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we combine these two SYMBOL2TOPI ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps; @merrymercy should we do this? There have also been discuss board questions from people who didn't know to add conv2d_transpose to the tutorial example when tuning other networks.

Copy link
Member

@merrymercy merrymercy Dec 14, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yzhliu is referring to combine L52 and L123.
This is not related to the problem in the forum.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we don't combine them currently seems to be because the imports here are tricky to do outside of a function. Should we just delete the extract_from_multiple_graphs function? I do not see it used often.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am using it. Maybe we can create another function for these lists


topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
if sym_name in SYMBOL2TOPI:
topi_funcs.extend(SYMBOL2TOPI[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

Expand Down
Loading