diff --git a/python/tvm/autotvm/graph_tuner/_base.py b/python/tvm/autotvm/graph_tuner/_base.py index 4002f67f43ba..e8d35ac35780 100644 --- a/python/tvm/autotvm/graph_tuner/_base.py +++ b/python/tvm/autotvm/graph_tuner/_base.py @@ -18,11 +18,6 @@ """Helper functions and global data""" -# Operators dependent on original layouts. -LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape", - "multibox_prior", "multibox_transform_loc", "where", - "non_max_suppression", "strided_slice"] - # We set a large time to represent an invalid layout-transformation. # This number is set to be 10e9 seconds to align with autotvm. INVALID_LAYOUT_TIME = 10e9 diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index 68bc6145d9f2..dca4148b058d 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -444,6 +444,7 @@ def _callback(_, inputs, results): timeout=timeout) measure_option = autotvm.measure_option(builder=builder, runner=runner) for args in args_list: + data, in_layout, out_layout = args args = serialize_args(args) ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args) if ltf_workload in self._layout_transform_perf_records: @@ -454,7 +455,18 @@ def _callback(_, inputs, results): flops = 1 for i in input_shape: flops *= i - inferred_time = flops * avg_time + + # Rule out invalid layout transformations + out = topi.layout_transform(data, in_layout, out_layout) + out_flops = 1 + for i in topi.util.get_const_tuple(out.shape): + out_flops *= i + + if flops != out_flops: + inferred_time = INVALID_LAYOUT_TIME + else: + inferred_time = flops * avg_time + record_input = MeasureInput(target=self._target, task=None, config=None) record_output = MeasureResult(costs=(inferred_time,), error_no=0, all_cost=-1, timestamp=-1) diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index a6eea6d243e4..19c319361ce5 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -26,7 +26,7 @@ from tvm.relay.ty import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv -from .utils import has_multiple_inputs, is_boundary_node +from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node # Setup relay op base name -> topi compute functions @@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names): visited_dict = {} in_node_dict = {} for i, node in enumerate(node_list): - if is_boundary_node(node, input_names): + if is_boundary_node(node, input_names) or is_skipped_node(node): continue get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names) for key, val in visited_dict.items(): @@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names): boundary_nodes.append(key) if boundary_nodes: for idx in boundary_nodes: - del in_node_dict[idx] + if idx in in_node_dict: + del in_node_dict[idx] else: has_reduced_node = False + # Remove empty nodes to ignore pre-computed sub-graph has_empty_node = True while has_empty_node: diff --git a/python/tvm/autotvm/graph_tuner/utils/utils.py b/python/tvm/autotvm/graph_tuner/utils/utils.py index 2570d81bae4e..d73f2c35f50e 100644 --- a/python/tvm/autotvm/graph_tuner/utils/utils.py +++ b/python/tvm/autotvm/graph_tuner/utils/utils.py @@ -19,8 +19,6 @@ from tvm import relay from tvm.relay import transform -from .._base import LAYOUT_FIXED_OP - def has_multiple_inputs(node_list, node_idx, input_names): """Check whether a node has multiple input nodes @@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names): out : bool whether node is a boundary node. """ - out = node_entry["op"] in LAYOUT_FIXED_OP or \ + # Operators dependent on original layouts. + _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape", + "multibox_prior", "multibox_transform_loc", "where", + "non_max_suppression", "strided_slice"] + + out = node_entry["op"] in _LAYOUT_FIXED_OP or \ ("name" in node_entry and node_entry["name"] in input_names) return out +def is_skipped_node(node_entry): + """Whether a node is not counted. + + Parameters + ---------- + node_entry : dict + Node entry. + + Returns + ------- + out : bool + whether node is skipped. + """ + # Operators not counted in graph tuner. + _SKIPPED_OP = ["Tuple"] + + return node_entry["op"] in _SKIPPED_OP + + def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): """Bind input variables of a relay function expression to new shapes and/or dtypes. diff --git a/tests/python/unittest/test_graph_tuner_core.py b/tests/python/unittest/test_graph_tuner_core.py index 30b037e1598e..c26b4b83cb9b 100644 --- a/tests/python/unittest/test_graph_tuner_core.py +++ b/tests/python/unittest/test_graph_tuner_core.py @@ -354,25 +354,107 @@ def test_many_sub_graphs(): ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1) ltf_records.append((ms_input, ms_output)) - ltf_keys = [] - ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"] - ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) - ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) - ltf_keys.append(ltf_wkl) - ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"] - ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) - ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) - ltf_keys.append(ltf_wkl) - ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"] + executor = DPTuner(net, {"data": dshape}, records, target_ops, target) + executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) + executor.run() + out = [record[0].config for record in executor.get_optimal_records()] + expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] + assert expected_out == out, "Output mismatch: expecting %s but got %s" \ + % (str(expected_out), str(out)) + + executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target) + executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) + executor.run() + out = [record[0].config for record in executor.get_optimal_records()] + expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] + assert expected_out == out, "Output mismatch: expecting %s but got %s" \ + % (str(expected_out), str(out)) + + +def test_tuple(): + target = "llvm" + dtype = "float32" + dshape = (1, 5, 32, 32) + layout = "NCHW" + target_ops = [relay.nn.conv2d] + + data = relay.var("data", shape=dshape, dtype=dtype) + w0 = relay.var("w0_weight") + conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1)) + w1 = relay.var("w1_weight") + conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1)) + out = relay.concatenate([conv0, conv1], axis=1) + net = relay.Function(relay.analysis.free_vars(out), out) + net, params = relay.testing.create_workload(net) + + tasks = autotvm.task.extract_from_program(net["main"], + target=target, + params=params, + ops=(relay.op.nn.conv2d,)) + wkl_list = [ + create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype), + create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype), + ] + costs = [0.01, 0.012, 0.03, 0.04] + config_list = [] + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [1, 5]], + ["tile_oc", "sp", [1, 2]], + ["tile_ow", "sp", [4, 8]], + ["unroll_kw", "ot", True]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [1, 5]], + ["tile_oc", "sp", [1, 3]], + ["tile_ow", "sp", [2, 16]], + ["unroll_kw", "ot", False]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [1, 5]], + ["tile_oc", "sp", [2, 1]], + ["tile_ow", "sp", [4, 8]], + ["unroll_kw", "ot", True]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + cfg_dict = {"i": -1, + "c": None, + "e": [["tile_ic", "sp", [1, 5]], + ["tile_oc", "sp", [3, 1]], + ["tile_ow", "sp", [2, 16]], + ["unroll_kw", "ot", False]], + "t": ""} + config_list.append(ConfigEntity.from_json_dict(cfg_dict)) + + records = [] + + wkl_list = wkl_list + wkl_list + tasks = tasks + tasks + for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks): + task.workload = wkl + ms_input = MeasureInput(target=target, task=task, config=config) + ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1) + records.append((ms_input, ms_output)) + + ltf_records = [] + ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"] ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg) ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg) - ltf_keys.append(ltf_wkl) + ltf_task = copy.deepcopy(tasks[0]) + ltf_task.workload = ltf_wkl + ms_input = MeasureInput(target=target, task=ltf_task, config=None) + ms_output = MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1) + ltf_records.append((ms_input, ms_output)) executor = DPTuner(net, {"data": dshape}, records, target_ops, target) executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) executor.run() out = [record[0].config for record in executor.get_optimal_records()] - expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] + expected_out = [records[2][0].config, records[1][0].config] assert expected_out == out, "Output mismatch: expecting %s but got %s" \ % (str(expected_out), str(out)) @@ -380,7 +462,7 @@ def test_many_sub_graphs(): executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True) executor.run() out = [record[0].config for record in executor.get_optimal_records()] - expected_out = [records[3][0].config, records[1][0].config, records[2][0].config] + expected_out = [records[2][0].config, records[1][0].config] assert expected_out == out, "Output mismatch: expecting %s but got %s" \ % (str(expected_out), str(out)) @@ -390,3 +472,4 @@ def test_many_sub_graphs(): test_DPTuner_run() test_PBQPTuner_run() test_many_sub_graphs() + test_tuple()