Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Autotvm] Fix autotvm customized template #5034

Merged
merged 3 commits into from
Mar 12, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/autotvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
from .task import get_config, create, ConfigSpace, ConfigEntity, \
register_topi_compute, register_topi_schedule, register_customized_task, \
register_topi_compute, register_topi_schedule, template, \
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
ApplyGraphBest as apply_graph_best
from .env import GLOBAL_SCOPE
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_infer_layout(task_name):
return topi.nn.depthwise_conv2d_infer_layout
raise ValueError("Cannot find infer layout for task %s" % task_name)

@autotvm.register_customized_task("layout_transform")
@autotvm.template("layout_transform")
def layout_transform(*args):
"""Autotvm layout transform template."""
cfg = get_config()
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/autotvm/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
of typical tasks of interest.
"""

from .task import Task, create, get_config, args_to_workload, \
register_customized_task
from .task import Task, create, get_config, args_to_workload, template
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
Expand Down
69 changes: 50 additions & 19 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def get_inputs(self, out):
queue.extend(t.op.input_tensors)
return inputs

def register_task_compute(name, func=None):
def _register_task_compute(name, func=None):
"""Register compute function to autotvm task

Parameters
Expand Down Expand Up @@ -247,7 +247,7 @@ def _do_reg(f):
return _do_reg(func)
return _do_reg

def register_task_schedule(name, func=None):
def _register_task_schedule(name, func=None):
"""Register schedule function to autotvm task

Parameters
Expand Down Expand Up @@ -276,14 +276,9 @@ def _do_reg(f):
return _do_reg(func)
return _do_reg

def register_customized_task(name, func=None):
def _register_customized_task(name, func=None):
"""Register a customized function to AutoTVM task.

In most cases, you can just use register_topi_compute and register_topi_schedule
with the same task name to define an AutoTVM task. However, you can also
create a customized AutoTVM task that defines a tunable template or performs
extra layout transform before invoking compute/schedule function.

Parameters
----------
name: str
Expand All @@ -297,14 +292,45 @@ def register_customized_task(name, func=None):
-------
decorator: callable
A decorator
"""
def _do_reg(f):
if name not in TASK_TABLE:
TASK_TABLE[name] = TopiTemplate()
tmpl = TASK_TABLE[name]
if tmpl.customized_func is not None:
raise ValueError("Customized func is already registered in autoTVM task %s" % name)
tmpl.customized_func = f
return f
if func:
return _do_reg(func)
return _do_reg


def template(task_name, func=None):
"""Decorate a function as a tunable schedule template.

Parameters
----------
task_name: str
The task name

func: None or callable
A callable template function.
If it is None, return a decorator.
If is callable, decorate this function.

Returns
-------
func: callable
The decorated function

Examples
--------
The following code is a tunable template for a blocked matrix multiplication

.. code-block:: python

@autotvm.register_customized_task("matmul")
@autotvm.template("matmul")
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand All @@ -331,17 +357,22 @@ def matmul(N, L, M, dtype):

return s, [A, B, C]
"""
def _do_reg(f):
if name not in TASK_TABLE:
TASK_TABLE[name] = TopiTemplate()
tmpl = TASK_TABLE[name]
if tmpl.customized_func is not None:
raise ValueError("Customized func is already registered in autoTVM task %s" % name)
tmpl.customized_func = f
return f
def _decorate(f):
def wrapper(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
workload = args_to_workload(args, task_name)
tgt = _target.Target.current()
cfg = DispatchContext.current.query(tgt, workload)
with ApplyConfig(cfg):
return f(*args, **kwargs)

_register_customized_task(task_name, f)
return wrapper

if func:
return _do_reg(func)
return _do_reg
return _decorate(func)
return _decorate


def create(task_name, args, target, target_host=None):
"""Create a tuning task and initialize its search space
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from tvm import target as _target
from tvm.te import tensor

from .task import args_to_workload, DispatchContext, \
register_task_compute, register_task_schedule, serialize_args
from .task import args_to_workload, serialize_args, DispatchContext, \
_register_task_compute, _register_task_schedule


# Task extractor for relay program
Expand Down Expand Up @@ -142,7 +142,7 @@ def register_topi_compute(task_name, func=None):
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
def _decorate(topi_compute):
@register_task_compute(task_name)
@_register_task_compute(task_name)
def wrapper(*args, **kwargs):
"""wrapper function for topi compute"""
assert not kwargs, "Do not support kwargs in template function call"
Expand Down Expand Up @@ -212,7 +212,7 @@ def register_topi_schedule(task_name, func=None):
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
def _decorate(topi_schedule):
@register_task_schedule(task_name)
@_register_task_schedule(task_name)
def wrapper(outs, *args, **kwargs):
"""wrapper function for topi schedule"""
workload = get_workload(outs)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/integration/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm import autotvm
from tvm.autotvm.tuner import RandomTuner

@autotvm.register_customized_task("testing/conv2d_no_batching")
@autotvm.template("testing/conv2d_no_batching")
def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
"""An example template for testing"""
assert N == 1, "Only consider batch_size = 1 in this template"
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_autotvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def run(self, measure_inputs, build_results):
def get_build_kwargs(self):
return {}

@autotvm.register_customized_task("testing/matmul")
@autotvm.template("testing/matmul")
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand All @@ -64,7 +64,7 @@ def matmul(N, L, M, dtype):

return s, [A, B, C]

@autotvm.register_customized_task("testing/bad_matmul")
@autotvm.template("testing/bad_matmul")
def bad_matmul(N, L, M, dtype):
if 'bad_device' in tvm.target.Target.current().keys:
A = te.placeholder((N, L), name='A', dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_autotvm_dispatch_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

def test_fallback():

@autotvm.register_customized_task("testing/dispatch/fallback")
@autotvm.template("testing/dispatch_fallback")
def simple_template(a, b):
cfg = autotvm.get_config()
assert cfg.is_fallback
Expand Down
2 changes: 1 addition & 1 deletion tutorials/autotvm/tune_conv2d_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
# can be very large (at the level of 10^9 for some input shapes)
#

@autotvm.register_customized_task("tutorial/conv2d_no_batching")
@autotvm.template("tutorial/conv2d_no_batching")
def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
assert N == 1, "Only consider batch_size = 1 in this template"

Expand Down
4 changes: 2 additions & 2 deletions tutorials/autotvm/tune_simple_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def matmul_v0(N, L, M, dtype):
# In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.

# Matmul V1: List candidate values
@autotvm.register_customized_task("tutorial/matmul_v1") # 1. use a decorator
@autotvm.template("tutorial/matmul_v1") # 1. use a decorator
def matmul_v1(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand Down Expand Up @@ -183,7 +183,7 @@ def matmul_v1(N, L, M, dtype):
# When the high level API cannot meet your requirement, you can always fall
# back to use low level API.

@autotvm.register_customized_task("tutorial/matmul")
@autotvm.template("tutorial/matmul")
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name='A', dtype=dtype)
B = te.placeholder((L, M), name='B', dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions tutorials/optimize/opt_matmul_auto_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
#
# We use AutoTVM to search for best configurations in this schedule.

@autotvm.register_customized_task("tutorial/test_gemm")
@autotvm.template("tutorial/auto_tensorcore/test_gemm")
def test_gemm(N, L, M, dtype, layout):
if (layout == "NN"):
shape_a = (N, L)
Expand Down Expand Up @@ -265,7 +265,7 @@ def test_gemm(N, L, M, dtype, layout):
assert(major == 7 and minor == 5 and layout == 'TN')

def tune_and_evaluate(M, N, L, dtype, layout):
task = autotvm.task.create("tutorial/test_gemm", args=(N, L, M, dtype, layout),
task = autotvm.task.create("tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout),
target='cuda')
print(task.config_space)

Expand Down
2 changes: 1 addition & 1 deletion vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def my_clip(x, a_min, a_max):
# init autotvm env to register VTA operator
TaskExtractEnv()

@autotvm.register_customized_task("conv2d_packed.vta")
@autotvm.template("conv2d_packed.vta")
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
A, W = args[:2]
Expand Down