From dbd805c171ad8f66b78bbaeb3eaa4bd8e562de6a Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 20 Mar 2020 09:26:40 -0700 Subject: [PATCH] [AutoTVM] Temporary fix to the stack overflow issue in autotvm task extraction (#5019) * Temporary fix to the stack overflow issue in autotvm task extraction * fix lint * fix graph tuner test --- python/tvm/autotvm/task/__init__.py | 3 +- python/tvm/autotvm/task/relay_integration.py | 17 +++++-- .../unittest/test_autotvm_graph_tuner_core.py | 45 ++++++++++++++++--- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py index 7e18fca66bff..be50af77f37e 100644 --- a/python/tvm/autotvm/task/__init__.py +++ b/python/tvm/autotvm/task/__init__.py @@ -22,7 +22,8 @@ of typical tasks of interest. """ -from .task import Task, create, get_config, args_to_workload, template +from .task import Task, create, get_config, args_to_workload, template, \ + serialize_args, deserialize_args from .space import ConfigSpace, ConfigEntity from .code_hash import attach_code_hash, attach_code_hash_to_arg from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \ diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index a7cbef74e5c2..de183db41e2c 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -47,11 +47,20 @@ def _lower(mod, mod, _ = relay.optimize(mod, target, params) grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) grc.codegen(mod["main"]) + return # default case - compiler = relay.vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod, target=target) + # Try graph codegen first to extract autotvm tasks. + # If failed to compile, then fallback to use VM compiler. + # TODO: Currently VM compiler is likely to stack overflow for large models. + try: + opt_mod, _ = relay.optimize(mod, target, params) + grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) + grc.codegen(opt_mod["main"]) + except tvm.TVMError: + compiler = relay.vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target=target) def extract_from_program(mod, params, target, target_host=None, ops=None): diff --git a/tests/python/unittest/test_autotvm_graph_tuner_core.py b/tests/python/unittest/test_autotvm_graph_tuner_core.py index a7be18a5a2d3..f577b6387507 100644 --- a/tests/python/unittest/test_autotvm_graph_tuner_core.py +++ b/tests/python/unittest/test_autotvm_graph_tuner_core.py @@ -34,6 +34,13 @@ from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner +def _create_args(dshape, kshape, strides, padding, dilation, layout, out_layout, + dtype, out_dtype): + data = tvm.te.placeholder(dshape, dtype=dtype) + kernel = tvm.te.placeholder(kshape, dtype=dtype) + return autotvm.task.serialize_args([data, kernel, strides, padding, dilation, + layout, layout, out_dtype]) + def _create_data(target, dshape, dtype, layout): data = relay.var("data", shape=dshape, dtype=dtype) w0 = relay.var("w0_weight") @@ -49,6 +56,12 @@ def _create_data(target, dshape, dtype, layout): target=target, params=params, ops=(relay.op.get("nn.conv2d"),)) + new_args = [ + _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype), + _create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + ] + costs = [0.04, 0.012, 0.03] config_list = [] cfg_dict = {"index": -1, @@ -74,7 +87,8 @@ def _create_data(target, dshape, dtype, layout): config_list.append(ConfigEntity.from_json_dict(cfg_dict)) records = [] - for cost, config, task in zip(costs, config_list, tasks): + for args, cost, config, task in zip(new_args, costs, config_list, tasks): + task.args = args ms_input = MeasureInput(target=target, task=task, config=config) ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) records.append((ms_input, ms_output)) @@ -261,6 +275,12 @@ def test_many_sub_graphs(): target=target, params=params, ops=(conv2d,)) + new_args = [ + _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype), + _create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + ] + costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045] config_list = [] cfg_dict = {"index": -1, @@ -307,9 +327,10 @@ def test_many_sub_graphs(): config_list.append(ConfigEntity.from_json_dict(cfg_dict)) records = [] - + new_args = new_args + new_args tasks = tasks + tasks - for cost, config, task in zip(costs, config_list, tasks): + for args, cost, config, task in zip(new_args, costs, config_list, tasks): + task.args = args ms_input = MeasureInput(target=target, task=task, config=config) ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) records.append((ms_input, ms_output)) @@ -359,6 +380,10 @@ def test_tuple(): target=target, params=params, ops=(conv2d,)) + new_args = [ + _create_args((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + _create_args((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + ] costs = [0.01, 0.012, 0.03, 0.04] config_list = [] cfg_dict = {"index": -1, @@ -391,8 +416,10 @@ def test_tuple(): config_list.append(ConfigEntity.from_json_dict(cfg_dict)) records = [] + new_args = new_args + new_args tasks = tasks + tasks - for cost, config, task in zip(costs, config_list, tasks): + for args, cost, config, task in zip(new_args, costs, config_list, tasks): + task.args = args ms_input = MeasureInput(target=target, task=task, config=config) ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) records.append((ms_input, ms_output)) @@ -444,6 +471,11 @@ def test_triangle_block(): target=target, params=params, ops=(conv2d,)) + new_args = [ + _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype), + _create_args((1, 3, 8, 8), (32, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype), + ] costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045] config_list = [] cfg_dict = {"index": -1, @@ -490,9 +522,10 @@ def test_triangle_block(): config_list.append(ConfigEntity.from_json_dict(cfg_dict)) records = [] - + new_args = new_args + new_args tasks = tasks + tasks - for cost, config, task in zip(costs, config_list, tasks): + for args, cost, config, task in zip(new_args, costs, config_list, tasks): + task.args = args ms_input = MeasureInput(target=target, task=task, config=config) ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) records.append((ms_input, ms_output))