Skip to content

Commit

Permalink
[Autotvm] Fix autotvm customized template (#5034)
Browse files Browse the repository at this point in the history
* init

* fix template

* tweak naming
  • Loading branch information
icemelon authored Mar 12, 2020
1 parent 681df4f commit 70e11d3
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 54 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
from .task import get_config, create, ConfigSpace, ConfigEntity, \
register_topi_compute, register_topi_schedule, register_customized_task, \
register_topi_compute, register_topi_schedule, template, \
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
ApplyGraphBest as apply_graph_best
from .env import GLOBAL_SCOPE
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_infer_layout(task_name):
return topi.nn.depthwise_conv2d_infer_layout
raise ValueError("Cannot find infer layout for task %s" % task_name)

@autotvm.register_customized_task("layout_transform")
@autotvm.template("layout_transform")
def layout_transform(*args):
"""Autotvm layout transform template."""
cfg = get_config()
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/autotvm/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
of typical tasks of interest.
"""

from .task import Task, create, get_config, args_to_workload, \
register_customized_task
from .task import Task, create, get_config, args_to_workload, template
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
Expand Down
113 changes: 77 additions & 36 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,25 +186,35 @@ def __repr__(self):

TASK_TABLE = {}

class TopiTemplate(object):
"""Topi template that holds the topi compute and schedule function"""
class TaskTemplate(object):
"""
Task template is used to creates a tunable AutoTVM task.
It can be defined by a pair of compute and schedule function using
`_register_task_compute` and `_register_task_schedule`,
or by a customized task creation function that is more flexible using
`_register_customized_task`.
Note that when customized func is registered, compute and schedule function
will be ignored
"""
def __init__(self):
self.compute = None
self.schedule = None
self.customized_func = None
self.fcompute = None
self.fschedule = None
self.fcustomized = None

def __call__(self, *args, **kwargs):
args = deserialize_args(args)
if self.customized_func is None:
if self.fcustomized is None:
return self._default_func(*args, **kwargs)
assert callable(self.customized_func)
return self.customized_func(*args, **kwargs)
assert callable(self.fcustomized)
return self.fcustomized(*args, **kwargs)

def _default_func(self, *args, **kwargs):
assert callable(self.compute) and callable(self.schedule)
out = self.compute(*args, **kwargs)
assert callable(self.fcompute) and callable(self.fschedule)
out = self.fcompute(*args, **kwargs)
arg_bufs = [out] + self.get_inputs(out)
s = self.schedule([out])
s = self.fschedule([out])
return s, arg_bufs

def get_inputs(self, out):
Expand All @@ -218,7 +228,7 @@ def get_inputs(self, out):
queue.extend(t.op.input_tensors)
return inputs

def register_task_compute(name, func=None):
def _register_task_compute(name, func=None):
"""Register compute function to autotvm task
Parameters
Expand All @@ -237,17 +247,17 @@ def register_task_compute(name, func=None):
"""
def _do_reg(f):
if name not in TASK_TABLE:
TASK_TABLE[name] = TopiTemplate()
TASK_TABLE[name] = TaskTemplate()
tmpl = TASK_TABLE[name]
if tmpl.compute is not None:
if tmpl.fcompute is not None:
raise ValueError("Compute is already registered in autoTVM task %s" % name)
tmpl.compute = f
tmpl.fcompute = f
return f
if func:
return _do_reg(func)
return _do_reg

def register_task_schedule(name, func=None):
def _register_task_schedule(name, func=None):
"""Register schedule function to autotvm task
Parameters
Expand All @@ -266,24 +276,19 @@ def register_task_schedule(name, func=None):
"""
def _do_reg(f):
if name not in TASK_TABLE:
TASK_TABLE[name] = TopiTemplate()
TASK_TABLE[name] = TaskTemplate()
tmpl = TASK_TABLE[name]
if tmpl.schedule is not None:
if tmpl.fschedule is not None:
raise ValueError("Schedule is already registered in autoTVM task %s" % name)
tmpl.schedule = f
tmpl.fschedule = f
return f
if func:
return _do_reg(func)
return _do_reg

def register_customized_task(name, func=None):
def _register_customized_task(name, func=None):
"""Register a customized function to AutoTVM task.
In most cases, you can just use register_topi_compute and register_topi_schedule
with the same task name to define an AutoTVM task. However, you can also
create a customized AutoTVM task that defines a tunable template or performs
extra layout transform before invoking compute/schedule function.
Parameters
----------
name: str
Expand All @@ -297,14 +302,45 @@ def register_customized_task(name, func=None):
-------
decorator: callable
A decorator
"""
def _do_reg(f):
if name not in TASK_TABLE:
TASK_TABLE[name] = TaskTemplate()
tmpl = TASK_TABLE[name]
if tmpl.fcustomized is not None:
raise ValueError("Customized func is already registered in autoTVM task %s" % name)
tmpl.fcustomized = f
return f
if func:
return _do_reg(func)
return _do_reg


def template(task_name, func=None):
"""Decorate a function as a tunable schedule template.
Parameters
----------
task_name: str
The task name
func: None or callable
A callable template function.
If it is None, return a decorator.
If is callable, decorate this function.
Returns
-------
func: callable
The decorated function
Examples
--------
The following code is a tunable template for a blocked matrix multiplication
.. code-block:: python
@autotvm.register_customized_task("matmul")
@autotvm.template("matmul")
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand All @@ -331,17 +367,22 @@ def matmul(N, L, M, dtype):
return s, [A, B, C]
"""
def _do_reg(f):
if name not in TASK_TABLE:
TASK_TABLE[name] = TopiTemplate()
tmpl = TASK_TABLE[name]
if tmpl.customized_func is not None:
raise ValueError("Customized func is already registered in autoTVM task %s" % name)
tmpl.customized_func = f
return f
def _decorate(f):
def wrapper(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
workload = args_to_workload(args, task_name)
tgt = _target.Target.current()
cfg = DispatchContext.current.query(tgt, workload)
with ApplyConfig(cfg):
return f(*args, **kwargs)

_register_customized_task(task_name, f)
return wrapper

if func:
return _do_reg(func)
return _do_reg
return _decorate(func)
return _decorate


def create(task_name, args, target, target_host=None):
"""Create a tuning task and initialize its search space
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from tvm import target as _target
from tvm.te import tensor

from .task import args_to_workload, DispatchContext, \
register_task_compute, register_task_schedule, serialize_args
from .task import args_to_workload, serialize_args, DispatchContext, \
_register_task_compute, _register_task_schedule


# Task extractor for relay program
Expand Down Expand Up @@ -142,7 +142,7 @@ def register_topi_compute(task_name, func=None):
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
def _decorate(topi_compute):
@register_task_compute(task_name)
@_register_task_compute(task_name)
def wrapper(*args, **kwargs):
"""wrapper function for topi compute"""
assert not kwargs, "Do not support kwargs in template function call"
Expand Down Expand Up @@ -212,7 +212,7 @@ def register_topi_schedule(task_name, func=None):
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
def _decorate(topi_schedule):
@register_task_schedule(task_name)
@_register_task_schedule(task_name)
def wrapper(outs, *args, **kwargs):
"""wrapper function for topi schedule"""
workload = get_workload(outs)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/integration/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm import autotvm
from tvm.autotvm.tuner import RandomTuner

@autotvm.register_customized_task("testing/conv2d_no_batching")
@autotvm.template("testing/conv2d_no_batching")
def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
"""An example template for testing"""
assert N == 1, "Only consider batch_size = 1 in this template"
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_autotvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def run(self, measure_inputs, build_results):
def get_build_kwargs(self):
return {}

@autotvm.register_customized_task("testing/matmul")
@autotvm.template("testing/matmul")
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand All @@ -64,7 +64,7 @@ def matmul(N, L, M, dtype):

return s, [A, B, C]

@autotvm.register_customized_task("testing/bad_matmul")
@autotvm.template("testing/bad_matmul")
def bad_matmul(N, L, M, dtype):
if 'bad_device' in tvm.target.Target.current().keys:
A = te.placeholder((N, L), name='A', dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_autotvm_dispatch_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

def test_fallback():

@autotvm.register_customized_task("testing/dispatch/fallback")
@autotvm.template("testing/dispatch_fallback")
def simple_template(a, b):
cfg = autotvm.get_config()
assert cfg.is_fallback
Expand Down
2 changes: 1 addition & 1 deletion tutorials/autotvm/tune_conv2d_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
# can be very large (at the level of 10^9 for some input shapes)
#

@autotvm.register_customized_task("tutorial/conv2d_no_batching")
@autotvm.template("tutorial/conv2d_no_batching")
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
assert N == 1, "Only consider batch_size = 1 in this template"

Expand Down
4 changes: 2 additions & 2 deletions tutorials/autotvm/tune_simple_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def matmul_v0(N, L, M, dtype):
# In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.

# Matmul V1: List candidate values
@autotvm.register_customized_task("tutorial/matmul_v1") # 1. use a decorator
@autotvm.template("tutorial/matmul_v1") # 1. use a decorator
def matmul_v1(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand Down Expand Up @@ -183,7 +183,7 @@ def matmul_v1(N, L, M, dtype):
# When the high level API cannot meet your requirement, you can always fall
# back to use low level API.

@autotvm.register_customized_task("tutorial/matmul")
@autotvm.template("tutorial/matmul")
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions tutorials/optimize/opt_matmul_auto_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
#
# We use AutoTVM to search for best configurations in this schedule.

@autotvm.register_customized_task("tutorial/test_gemm")
@autotvm.template("tutorial/auto_tensorcore/test_gemm")
def test_gemm(N, L, M, dtype, layout):
if (layout == "NN"):
shape_a = (N, L)
Expand Down Expand Up @@ -265,7 +265,7 @@ def test_gemm(N, L, M, dtype, layout):
assert(major == 7 and minor == 5 and layout == 'TN')

def tune_and_evaluate(M, N, L, dtype, layout):
task = autotvm.task.create("tutorial/test_gemm", args=(N, L, M, dtype, layout),
task = autotvm.task.create("tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout),
target='cuda')
print(task.config_space)

Expand Down
2 changes: 1 addition & 1 deletion vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def my_clip(x, a_min, a_max):
# init autotvm env to register VTA operator
TaskExtractEnv()

@autotvm.register_customized_task("conv2d_packed.vta")
@autotvm.template("conv2d_packed.vta")
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
A, W = args[:2]
Expand Down

0 comments on commit 70e11d3

Please sign in to comment.