Improve graph tuner dealing with Tuple #3649

Merged 4 commits on Aug 11, 2019
python/tvm/autotvm/graph_tuner/_base.py (0 additions, 5 deletions)
@@ -18,11 +18,6 @@
 """Helper functions and global data"""
 
 
-# Operators dependent on original layouts.
-LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
-                   "multibox_prior", "multibox_transform_loc", "where",
-                   "non_max_suppression", "strided_slice"]
-
 # We set a large time to represent an invalid layout-transformation.
 # This number is set to be 10e9 seconds to align with autotvm.
 INVALID_LAYOUT_TIME = 10e9
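For context, INVALID_LAYOUT_TIME is the sentinel cost that the new check in base_graph_tuner.py (next file) assigns to impossible layout transforms: a candidate priced at 10e9 seconds can never win a cost comparison, which is the entire pruning mechanism. A toy sketch of that effect, assuming a tuner that simply minimizes cost (the candidate names here are illustrative):

```python
INVALID_LAYOUT_TIME = 10e9  # same sentinel as in _base.py

# Hypothetical per-transform costs: one inferred normally, one ruled out as invalid.
candidate_costs = {
    "NCHW4c -> NCHW8c": 1.9e-05,
    "NCHW -> NCHW8c (element count changes)": INVALID_LAYOUT_TIME,
}

best = min(candidate_costs, key=candidate_costs.get)
print(best)  # the invalid transform is effectively pruned from consideration
```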
python/tvm/autotvm/graph_tuner/base_graph_tuner.py (13 additions, 1 deletion)
@@ -444,6 +444,7 @@ def _callback(_, inputs, results):
                                        timeout=timeout)
         measure_option = autotvm.measure_option(builder=builder, runner=runner)
         for args in args_list:
+            data, in_layout, out_layout = args
             args = serialize_args(args)
             ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
             if ltf_workload in self._layout_transform_perf_records:
@@ -454,7 +455,18 @@ def _callback(_, inputs, results):
                 flops = 1
                 for i in input_shape:
                     flops *= i
-                inferred_time = flops * avg_time
+
+                # Rule out invalid layout transformations
+                out = topi.layout_transform(data, in_layout, out_layout)
+                out_flops = 1
+                for i in topi.util.get_const_tuple(out.shape):
+                    out_flops *= i
+
+                if flops != out_flops:
+                    inferred_time = INVALID_LAYOUT_TIME
+                else:
+                    inferred_time = flops * avg_time
+
                 record_input = MeasureInput(target=self._target, task=None, config=None)
                 record_output = MeasureResult(costs=(inferred_time,), error_no=0,
                                               all_cost=-1, timestamp=-1)
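The added check compares element counts before and after the candidate transform: a transform whose output holds a different number of elements than its input (for example NCHW to NCHW8c on a 5-channel tensor, which pads channels up to 8) cannot be a pure relayout, so its inferred time is pinned to INVALID_LAYOUT_TIME. A minimal standalone sketch of the same logic, with a plain num_elements helper standing in for the topi calls (these names are illustrative, not the tuner's API):

```python
INVALID_LAYOUT_TIME = 10e9


def num_elements(shape):
    """Total element count of a shape (stand-in for the topi shape math)."""
    count = 1
    for dim in shape:
        count *= dim
    return count


def infer_transform_time(in_shape, out_shape, avg_time_per_element):
    """Estimate a layout-transform cost, flagging shape-changing transforms as invalid."""
    flops = num_elements(in_shape)
    out_flops = num_elements(out_shape)
    if flops != out_flops:
        return INVALID_LAYOUT_TIME  # element count changed: not a pure relayout
    return flops * avg_time_per_element


# (1, 4, 8, 8, 4) in NCHW4c relaid to NCHW8c is (1, 2, 8, 8, 8): 1024 == 1024, valid.
print(infer_transform_time((1, 4, 8, 8, 4), (1, 2, 8, 8, 8), 1e-9))
# (1, 5, 32, 32) in NCHW to NCHW8c pads 5 channels to 8: 5120 != 8192, invalid.
print(infer_transform_time((1, 5, 32, 32), (1, 1, 32, 32, 8), 1e-9))
```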
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py (5 additions, 3 deletions)
@@ -26,7 +26,7 @@
 from tvm.relay.ty import TupleType, TensorType
 from tvm.autotvm.task import TaskExtractEnv
 
-from .utils import has_multiple_inputs, is_boundary_node
+from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
 
 
 # Setup relay op base name -> topi compute functions
@@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names):
     visited_dict = {}
     in_node_dict = {}
     for i, node in enumerate(node_list):
-        if is_boundary_node(node, input_names):
+        if is_boundary_node(node, input_names) or is_skipped_node(node):
             continue
         get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
     for key, val in visited_dict.items():
@@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names):
                 boundary_nodes.append(key)
         if boundary_nodes:
             for idx in boundary_nodes:
-                del in_node_dict[idx]
+                if idx in in_node_dict:
+                    del in_node_dict[idx]
         else:
             has_reduced_node = False
+
 
     # Remove empty nodes to ignore pre-computed sub-graph
     has_empty_node = True
     while has_empty_node:
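With these two changes, Tuple nodes are left out of the in-node dictionary entirely, and the boundary-removal pass now tolerates indices that were never inserted. A rough standalone sketch of the filtering, using simplified node entries (plain dicts, which is an illustration rather than the tuner's real node format):

```python
_LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape"]
_SKIPPED_OP = ["Tuple"]


def build_in_node_dict(node_list, input_names):
    """Map node index -> input indices, skipping boundary and skipped nodes."""
    in_node_dict = {}
    for i, node in enumerate(node_list):
        is_boundary = node["op"] in _LAYOUT_FIXED_OP or node.get("name") in input_names
        if is_boundary or node["op"] in _SKIPPED_OP:
            continue
        in_node_dict[i] = list(node.get("inputs", []))
    return in_node_dict


nodes = [
    {"op": "null", "name": "data", "inputs": []},
    {"op": "conv2d", "inputs": [0]},
    {"op": "conv2d", "inputs": [0]},
    {"op": "Tuple", "inputs": [1, 2]},       # skipped: not a tunable op
    {"op": "concatenate", "inputs": [3]},
]
print(build_in_node_dict(nodes, {"data"}))   # index 3 never appears
```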
python/tvm/autotvm/graph_tuner/utils/utils.py (25 additions, 3 deletions)
@@ -19,8 +19,6 @@
 from tvm import relay
 from tvm.relay import transform
 
-from .._base import LAYOUT_FIXED_OP
-
 
 def has_multiple_inputs(node_list, node_idx, input_names):
     """Check whether a node has multiple input nodes
@@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names):
     out : bool
         whether node is a boundary node.
     """
-    out = node_entry["op"] in LAYOUT_FIXED_OP or \
+    # Operators dependent on original layouts.
+    _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
+                        "multibox_prior", "multibox_transform_loc", "where",
+                        "non_max_suppression", "strided_slice"]
+
+    out = node_entry["op"] in _LAYOUT_FIXED_OP or \
           ("name" in node_entry and node_entry["name"] in input_names)
     return out
 
 
+def is_skipped_node(node_entry):
"""Whether a node is not counted.

Parameters
----------
node_entry : dict
Node entry.

Returns
-------
out : bool
whether node is skipped.
"""
# Operators not counted in graph tuner.
_SKIPPED_OP = ["Tuple"]

return node_entry["op"] in _SKIPPED_OP


def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
"""Bind input variables of a relay function expression
to new shapes and/or dtypes.
Expand Down
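A quick usage sketch of the two predicates on hand-built node entries. The dict fields mirror what the rest of the graph tuner passes around, though real entries carry more keys, and the import path is an assumption based on the module layout (it may need to be tvm.autotvm.graph_tuner.utils.utils if the package does not re-export them):

```python
from tvm.autotvm.graph_tuner.utils import is_boundary_node, is_skipped_node

input_names = ["data"]

assert not is_boundary_node({"op": "conv2d", "name": "conv0"}, input_names)
assert is_boundary_node({"op": "reshape", "name": "reshape0"}, input_names)  # layout-fixed op
assert is_boundary_node({"op": "null", "name": "data"}, input_names)         # graph input
assert is_skipped_node({"op": "Tuple"})        # excluded from tuning entirely
assert not is_skipped_node({"op": "conv2d"})
```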
tests/python/unittest/test_graph_tuner_core.py (96 additions, 13 deletions)
@@ -354,33 +354,115 @@ def test_many_sub_graphs():
     ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
 
-    ltf_keys = []
-    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
-    ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
-    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
+    executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                                % (str(expected_out), str(out))
+
+    executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                                % (str(expected_out), str(out))
+
+
+def test_tuple():
+    target = "llvm"
+    dtype = "float32"
+    dshape = (1, 5, 32, 32)
+    layout = "NCHW"
+    target_ops = [relay.nn.conv2d]
+
+    data = relay.var("data", shape=dshape, dtype=dtype)
+    w0 = relay.var("w0_weight")
+    conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1))
+    w1 = relay.var("w1_weight")
+    conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1))
+    out = relay.concatenate([conv0, conv1], axis=1)
+    net = relay.Function(relay.analysis.free_vars(out), out)
+    net, params = relay.testing.create_workload(net)
+
+    tasks = autotvm.task.extract_from_program(net["main"],
+                                              target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
+    wkl_list = [
+        create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+        create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
+    costs = [0.01, 0.012, 0.03, 0.04]
+    config_list = []
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [1, 2]],
+                      ["tile_ow", "sp", [4, 8]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [1, 3]],
+                      ["tile_ow", "sp", [2, 16]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [2, 1]],
+                      ["tile_ow", "sp", [4, 8]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [3, 1]],
+                      ["tile_ow", "sp", [2, 16]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+
+    records = []
+
+    wkl_list = wkl_list + wkl_list
+    tasks = tasks + tasks
+    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
+        task.workload = wkl
+        ms_input = MeasureInput(target=target, task=task, config=config)
+        ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
+        records.append((ms_input, ms_output))
+
+    ltf_records = []
+    ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
     ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
     ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
+    ltf_task = copy.deepcopy(tasks[0])
+    ltf_task.workload = ltf_wkl
+    ms_input = MeasureInput(target=target, task=ltf_task, config=None)
+    ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ltf_records.append((ms_input, ms_output))
 
     executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
-    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    expected_out = [records[2][0].config, records[1][0].config]
     assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                 % (str(expected_out), str(out))
 
     executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
-    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    expected_out = [records[2][0].config, records[1][0].config]
     assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                 % (str(expected_out), str(out))
 
@@ -390,3 +472,4 @@ def test_many_sub_graphs():
     test_DPTuner_run()
     test_PBQPTuner_run()
     test_many_sub_graphs()
+    test_tuple()