Skip to content

Commit

Permalink
[AutoTVM] Temporary fix to the stack overflow issue in autotvm task e…
Browse files Browse the repository at this point in the history
…xtraction (#5019)

* Temporary fix to the stack overflow issue in autotvm task extraction

* fix lint

* fix graph tuner test
  • Loading branch information
icemelon authored Mar 20, 2020
1 parent b91dbca commit dbd805c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
3 changes: 2 additions & 1 deletion python/tvm/autotvm/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
of typical tasks of interest.
"""

from .task import Task, create, get_config, args_to_workload, template
from .task import Task, create, get_config, args_to_workload, template, \
serialize_args, deserialize_args
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
Expand Down
17 changes: 13 additions & 4 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,20 @@ def _lower(mod,
mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(mod["main"])
return
# default case
compiler = relay.vm.VMCompiler()
if params:
compiler.set_params(params)
compiler.lower(mod, target=target)
# Try graph codegen first to extract autotvm tasks.
# If failed to compile, then fallback to use VM compiler.
# TODO: Currently VM compiler is likely to stack overflow for large models.
try:
opt_mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
except tvm.TVMError:
compiler = relay.vm.VMCompiler()
if params:
compiler.set_params(params)
compiler.lower(mod, target=target)


def extract_from_program(mod, params, target, target_host=None, ops=None):
Expand Down
45 changes: 39 additions & 6 deletions tests/python/unittest/test_autotvm_graph_tuner_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner


def _create_args(dshape, kshape, strides, padding, dilation, layout, out_layout,
dtype, out_dtype):
data = tvm.te.placeholder(dshape, dtype=dtype)
kernel = tvm.te.placeholder(kshape, dtype=dtype)
return autotvm.task.serialize_args([data, kernel, strides, padding, dilation,
layout, layout, out_dtype])

def _create_data(target, dshape, dtype, layout):
data = relay.var("data", shape=dshape, dtype=dtype)
w0 = relay.var("w0_weight")
Expand All @@ -49,6 +56,12 @@ def _create_data(target, dshape, dtype, layout):
target=target,
params=params,
ops=(relay.op.get("nn.conv2d"),))
new_args = [
_create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]

costs = [0.04, 0.012, 0.03]
config_list = []
cfg_dict = {"index": -1,
Expand All @@ -74,7 +87,8 @@ def _create_data(target, dshape, dtype, layout):
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down Expand Up @@ -261,6 +275,12 @@ def test_many_sub_graphs():
target=target,
params=params,
ops=(conv2d,))
new_args = [
_create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]

costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"index": -1,
Expand Down Expand Up @@ -307,9 +327,10 @@ def test_many_sub_graphs():
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []

new_args = new_args + new_args
tasks = tasks + tasks
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down Expand Up @@ -359,6 +380,10 @@ def test_tuple():
target=target,
params=params,
ops=(conv2d,))
new_args = [
_create_args((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.01, 0.012, 0.03, 0.04]
config_list = []
cfg_dict = {"index": -1,
Expand Down Expand Up @@ -391,8 +416,10 @@ def test_tuple():
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []
new_args = new_args + new_args
tasks = tasks + tasks
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down Expand Up @@ -444,6 +471,11 @@ def test_triangle_block():
target=target,
params=params,
ops=(conv2d,))
new_args = [
_create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 3, 8, 8), (32, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"index": -1,
Expand Down Expand Up @@ -490,9 +522,10 @@ def test_triangle_block():
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []

new_args = new_args + new_args
tasks = tasks + tasks
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down

0 comments on commit dbd805c

Please sign in to comment.