Skip to content

Commit

Permalink
fix graph tuner test
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Mar 19, 2020
1 parent 96e460f commit 8ffa4d0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
4 changes: 2 additions & 2 deletions python/tvm/autotvm/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
of typical tasks of interest.
"""

from .task import Task, create, get_config, args_to_workload, \
register_customized_task
from .task import Task, create, get_config, serialize_args, deserialize_args, \
args_to_workload, register_customized_task
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
Expand Down
45 changes: 39 additions & 6 deletions tests/python/unittest/test_graph_tuner_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner


def _create_args(dshape, kshape, strides, padding, dilation, layout, out_layout,
dtype, out_dtype):
data = tvm.te.placeholder(dshape, dtype=dtype)
kernel = tvm.te.placeholder(kshape, dtype=dtype)
return autotvm.task.serialize_args([data, kernel, strides, padding, dilation,
layout, layout, out_dtype])

def _create_data(target, dshape, dtype, layout):
data = relay.var("data", shape=dshape, dtype=dtype)
w0 = relay.var("w0_weight")
Expand All @@ -49,6 +56,12 @@ def _create_data(target, dshape, dtype, layout):
target=target,
params=params,
ops=(relay.op.get("nn.conv2d"),))
new_args = [
_create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]

costs = [0.04, 0.012, 0.03]
config_list = []
cfg_dict = {"index": -1,
Expand All @@ -74,7 +87,8 @@ def _create_data(target, dshape, dtype, layout):
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down Expand Up @@ -261,6 +275,12 @@ def test_many_sub_graphs():
target=target,
params=params,
ops=(conv2d,))
new_args = [
_create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]

costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"index": -1,
Expand Down Expand Up @@ -307,9 +327,10 @@ def test_many_sub_graphs():
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []

new_args = new_args + new_args
tasks = tasks + tasks
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down Expand Up @@ -359,6 +380,10 @@ def test_tuple():
target=target,
params=params,
ops=(conv2d,))
new_args = [
_create_args((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.01, 0.012, 0.03, 0.04]
config_list = []
cfg_dict = {"index": -1,
Expand Down Expand Up @@ -391,8 +416,10 @@ def test_tuple():
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []
new_args = new_args + new_args
tasks = tasks + tasks
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down Expand Up @@ -444,6 +471,11 @@ def test_triangle_block():
target=target,
params=params,
ops=(conv2d,))
new_args = [
_create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
_create_args((1, 3, 8, 8), (32, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
]
costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"index": -1,
Expand Down Expand Up @@ -490,9 +522,10 @@ def test_triangle_block():
config_list.append(ConfigEntity.from_json_dict(cfg_dict))

records = []

new_args = new_args + new_args
tasks = tasks + tasks
for cost, config, task in zip(costs, config_list, tasks):
for args, cost, config, task in zip(new_args, costs, config_list, tasks):
task.args = args
ms_input = MeasureInput(target=target, task=task, config=config)
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))
Expand Down

0 comments on commit 8ffa4d0

Please sign in to comment.