Improve graph tuner dealing with Tuple (apache#3649)
* Improve graph tuner dealing with Tuple

* Add test case

* Move some data out of _base.py

* Fix lint
kevinthesun authored and yzhliu committed Aug 11, 2019
1 parent 3d4ba8d commit 4f12046
Showing 5 changed files with 139 additions and 25 deletions.
5 changes: 0 additions & 5 deletions python/tvm/autotvm/graph_tuner/_base.py
@@ -18,11 +18,6 @@
"""Helper functions and global data"""


# Operators dependent on original layouts.
LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
                   "multibox_prior", "multibox_transform_loc", "where",
                   "non_max_suppression", "strided_slice"]

# We set a large time to represent an invalid layout-transformation.
# This number is set to be 10e9 seconds to align with autotvm.
INVALID_LAYOUT_TIME = 10e9
14 changes: 13 additions & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -444,6 +444,7 @@ def _callback(_, inputs, results):
                                     timeout=timeout)
        measure_option = autotvm.measure_option(builder=builder, runner=runner)
        for args in args_list:
            data, in_layout, out_layout = args
            args = serialize_args(args)
            ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
            if ltf_workload in self._layout_transform_perf_records:
@@ -454,7 +455,18 @@
                flops = 1
                for i in input_shape:
                    flops *= i
                inferred_time = flops * avg_time

                # Rule out invalid layout transformations
                out = topi.layout_transform(data, in_layout, out_layout)
                out_flops = 1
                for i in topi.util.get_const_tuple(out.shape):
                    out_flops *= i

                if flops != out_flops:
                    inferred_time = INVALID_LAYOUT_TIME
                else:
                    inferred_time = flops * avg_time

                record_input = MeasureInput(target=self._target, task=None, config=None)
                record_output = MeasureResult(costs=(inferred_time,), error_no=0,
                                              all_cost=-1, timestamp=-1)
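The hunk above changes how layout-transform cost is inferred: instead of always scaling the input element count by an average per-element time, the tuner first checks that the transform preserves the total element count. A pure layout transform only reorders data, so an input/output element-count mismatch means the source/destination layout pair is inconsistent, and the record gets the INVALID_LAYOUT_TIME sentinel. A standalone sketch of that rule in plain Python (no TVM; the helper name and shapes are illustrative, not from the codebase):

from functools import reduce

INVALID_LAYOUT_TIME = 10e9  # same sentinel the tuner uses for impossible transforms

def infer_transform_time(in_shape, out_shape, avg_time_per_elem):
    """Estimate layout-transform time from element count, or flag it invalid."""
    in_elems = reduce(lambda a, b: a * b, in_shape, 1)
    out_elems = reduce(lambda a, b: a * b, out_shape, 1)
    if in_elems != out_elems:
        # Element counts differ: not a pure relayout, rule it out.
        return INVALID_LAYOUT_TIME
    return in_elems * avg_time_per_elem

# e.g. (1, 4, 8, 8, 4) NCHW4c -> (1, 2, 8, 8, 8) NCHW8c keeps 1024 elements: valid.
print(infer_transform_time((1, 4, 8, 8, 4), (1, 2, 8, 8, 8), 1e-9))  # 1.024e-06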
8 changes: 5 additions & 3 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -26,7 +26,7 @@
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv

from .utils import has_multiple_inputs, is_boundary_node
from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node


# Setup relay op base name -> topi compute functions
@@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names):
    visited_dict = {}
    in_node_dict = {}
    for i, node in enumerate(node_list):
        if is_boundary_node(node, input_names):
        if is_boundary_node(node, input_names) or is_skipped_node(node):
            continue
        get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
    for key, val in visited_dict.items():
@@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names):
                boundary_nodes.append(key)
        if boundary_nodes:
            for idx in boundary_nodes:
                del in_node_dict[idx]
                if idx in in_node_dict:
                    del in_node_dict[idx]
        else:
            has_reduced_node = False


    # Remove empty nodes to ignore pre-computed sub-graph
    has_empty_node = True
    while has_empty_node:
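The membership check added above makes the boundary-node pruning tolerant of an index that is no longer (or never was) in in_node_dict once Tuple nodes are skipped; a bare del would raise KeyError. A toy illustration with hypothetical node indices, not the tuner's real data structures:

in_node_dict = {1: [0], 4: [1], 5: [1]}
boundary_nodes = [4, 4, 7]    # duplicates / already-removed indices can occur

for idx in boundary_nodes:
    if idx in in_node_dict:   # guard: a plain `del` would raise KeyError
        del in_node_dict[idx] # on the second 4 and on the missing 7

print(in_node_dict)           # {1: [0], 5: [1]}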
28 changes: 25 additions & 3 deletions python/tvm/autotvm/graph_tuner/utils/utils.py
@@ -19,8 +19,6 @@
from tvm import relay
from tvm.relay import transform

from .._base import LAYOUT_FIXED_OP


def has_multiple_inputs(node_list, node_idx, input_names):
"""Check whether a node has multiple input nodes
@@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names):
    out : bool
        whether node is a boundary node.
    """
    out = node_entry["op"] in LAYOUT_FIXED_OP or \
    # Operators dependent on original layouts.
    _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
                        "multibox_prior", "multibox_transform_loc", "where",
                        "non_max_suppression", "strided_slice"]

    out = node_entry["op"] in _LAYOUT_FIXED_OP or \
          ("name" in node_entry and node_entry["name"] in input_names)
    return out


def is_skipped_node(node_entry):
    """Whether a node is not counted.

    Parameters
    ----------
    node_entry : dict
        Node entry.

    Returns
    -------
    out : bool
        whether node is skipped.
    """
    # Operators not counted in graph tuner.
    _SKIPPED_OP = ["Tuple"]

    return node_entry["op"] in _SKIPPED_OP


def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
    """Bind input variables of a relay function expression
    to new shapes and/or dtypes.
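Together the two predicates decide which graph nodes the tuner plans layouts for: boundary nodes keep their original layout, and skipped nodes such as Tuple are left out of the cost graph entirely. A minimal sketch of applying them, using the predicates as added in this commit but with node entries reduced to the fields they actually read (the real traversal in traverse_graph.py tracks far more state):

_LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
                    "multibox_prior", "multibox_transform_loc", "where",
                    "non_max_suppression", "strided_slice"]
_SKIPPED_OP = ["Tuple"]

def is_boundary_node(node_entry, input_names):
    return node_entry["op"] in _LAYOUT_FIXED_OP or \
           ("name" in node_entry and node_entry["name"] in input_names)

def is_skipped_node(node_entry):
    return node_entry["op"] in _SKIPPED_OP

# Simplified node list: only the "op"/"name" fields matter to the predicates.
node_list = [
    {"op": "null", "name": "data"},   # graph input  -> boundary
    {"op": "conv2d"},                 # tunable      -> kept
    {"op": "Tuple"},                  # tuple node   -> skipped
    {"op": "concatenate"},            # kept
]
kept = [n for n in node_list
        if not is_boundary_node(n, ["data"]) and not is_skipped_node(n)]
print([n["op"] for n in kept])        # ['conv2d', 'concatenate']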
109 changes: 96 additions & 13 deletions tests/python/unittest/test_graph_tuner_core.py
@@ -354,33 +354,115 @@ def test_many_sub_graphs():
    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
    ltf_records.append((ms_input, ms_output))

    ltf_keys = []
    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
    ltf_keys.append(ltf_wkl)
    ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
    ltf_keys.append(ltf_wkl)
    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
    executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
    executor.run()
    out = [record[0].config for record in executor.get_optimal_records()]
    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                % (str(expected_out), str(out))

    executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
    executor.run()
    out = [record[0].config for record in executor.get_optimal_records()]
    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                % (str(expected_out), str(out))


def test_tuple():
    target = "llvm"
    dtype = "float32"
    dshape = (1, 5, 32, 32)
    layout = "NCHW"
    target_ops = [relay.nn.conv2d]

    data = relay.var("data", shape=dshape, dtype=dtype)
    w0 = relay.var("w0_weight")
    conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1))
    w1 = relay.var("w1_weight")
    conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1))
    out = relay.concatenate([conv0, conv1], axis=1)
    net = relay.Function(relay.analysis.free_vars(out), out)
    net, params = relay.testing.create_workload(net)

    tasks = autotvm.task.extract_from_program(net["main"],
                                              target=target,
                                              params=params,
                                              ops=(relay.op.nn.conv2d,))
    wkl_list = [
        create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
        create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
    ]
    costs = [0.01, 0.012, 0.03, 0.04]
    config_list = []
    cfg_dict = {"i": -1,
                "c": None,
                "e": [["tile_ic", "sp", [1, 5]],
                      ["tile_oc", "sp", [1, 2]],
                      ["tile_ow", "sp", [4, 8]],
                      ["unroll_kw", "ot", True]],
                "t": ""}
    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
    cfg_dict = {"i": -1,
                "c": None,
                "e": [["tile_ic", "sp", [1, 5]],
                      ["tile_oc", "sp", [1, 3]],
                      ["tile_ow", "sp", [2, 16]],
                      ["unroll_kw", "ot", False]],
                "t": ""}
    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
    cfg_dict = {"i": -1,
                "c": None,
                "e": [["tile_ic", "sp", [1, 5]],
                      ["tile_oc", "sp", [2, 1]],
                      ["tile_ow", "sp", [4, 8]],
                      ["unroll_kw", "ot", True]],
                "t": ""}
    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
    cfg_dict = {"i": -1,
                "c": None,
                "e": [["tile_ic", "sp", [1, 5]],
                      ["tile_oc", "sp", [3, 1]],
                      ["tile_ow", "sp", [2, 16]],
                      ["unroll_kw", "ot", False]],
                "t": ""}
    config_list.append(ConfigEntity.from_json_dict(cfg_dict))

    records = []

    wkl_list = wkl_list + wkl_list
    tasks = tasks + tasks
    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
        task.workload = wkl
        ms_input = MeasureInput(target=target, task=task, config=config)
        ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
        records.append((ms_input, ms_output))

    ltf_records = []
    ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
    ltf_keys.append(ltf_wkl)
    ltf_task = copy.deepcopy(tasks[0])
    ltf_task.workload = ltf_wkl
    ms_input = MeasureInput(target=target, task=ltf_task, config=None)
    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
    ltf_records.append((ms_input, ms_output))

    executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
    executor.run()
    out = [record[0].config for record in executor.get_optimal_records()]
    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
    expected_out = [records[2][0].config, records[1][0].config]
    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                % (str(expected_out), str(out))

    executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
    executor.run()
    out = [record[0].config for record in executor.get_optimal_records()]
    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
    expected_out = [records[2][0].config, records[1][0].config]
    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                % (str(expected_out), str(out))
@@ -390,3 +472,4 @@ def test_many_sub_graphs():
    test_DPTuner_run()
    test_PBQPTuner_run()
    test_many_sub_graphs()
    test_tuple()
