Skip to content

Commit

Permalink
[Graph tuner]Add opt out operator for has_multiple_inputs for graph t…
Browse files Browse the repository at this point in the history
…uner (#5000)

* consider layout_transform in has_multiple_inputs

* refactor code

* remove debug info

* remove subclass assignment

* refactoring a little bit

* remove default value

* remove trailing whitespace

* modify test for has_multiple_inputs

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
zhen-jia and Ubuntu authored Mar 13, 2020
1 parent 64bc997 commit 2e913f0
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 11 deletions.
2 changes: 2 additions & 0 deletions python/tvm/autotvm/graph_tuner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@
INVALID_LAYOUT_TIME = 10e9

MAX_OUTPUT_NODES = 16

OPT_OUT_OP = ["layout_transform"]
6 changes: 4 additions & 2 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
bind_inputs, expr2graph
from ._base import INVALID_LAYOUT_TIME

from ._base import OPT_OUT_OP

def get_infer_layout(task_name):
if task_name.startswith("conv2d"):
Expand Down Expand Up @@ -153,6 +154,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
self._fetch_cfg()
self._opt_out_op = OPT_OUT_OP

# Setup infer_layout for elemwise-like nodes
# Note: graph tuner currently only supports tuning of single input and single output
Expand All @@ -162,7 +164,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
# elemwise-like node, and use infer_layout function from input op to generate layouts.
input_names = self._input_shapes.keys()
for idx in sorted(self._in_nodes_dict.keys()):
if has_multiple_inputs(self._node_list, idx, input_names):
if has_multiple_inputs(self._node_list, idx, input_names, self._opt_out_op):
node_entry = self._node_list[idx]
node_entry["topi_op"] = []
node_entry["workloads"] = []
Expand Down Expand Up @@ -246,7 +248,7 @@ def _iterate_layout_transform(self, callback):
node_entry = self._node_list[key]
target_input_idx = -1
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _backward(self):
continue
optimal_sch_idx = optimal_record_dict[node_idx]
full_states = self._stage_dict[node_idx].full_states
if not has_multiple_inputs(self._node_list, node_idx, input_names):
if not has_multiple_inputs(self._node_list, node_idx, input_names, self._opt_out_op):
input_idx = self._in_nodes_dict[node_idx][0]
input_node = self._node_list[input_idx]
if is_boundary_node(input_node, input_names):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/pbqp_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def run(self, **kwargs):
for key, val in self._in_nodes_dict.items():
target_input_idx = -1
target_input_pos = -1
if has_multiple_inputs(self._node_list, key, input_names):
if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
Expand Down
8 changes: 5 additions & 3 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm.autotvm.task import TaskExtractEnv

from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node

from .._base import OPT_OUT_OP

def expr2graph(expr, target_ops, node_dict, node_list):
"""Convert relay expr to graph data structure
Expand Down Expand Up @@ -204,7 +204,8 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_nam
node_direct_ancestor = []
for item_idx in node["inputs"]:
item = node_list[item_idx[0]]
is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names)
is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], \
input_names, OPT_OUT_OP)
if item["op"] in target_ops or is_multiple_inputs:
node_direct_ancestor.append(item_idx[0])
else:
Expand Down Expand Up @@ -245,7 +246,8 @@ def get_in_nodes(node_list, target_ops, input_names):
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items():
node = node_list[key]
is_multiple_inputs = has_multiple_inputs(node_list, key, input_names)
is_multiple_inputs = has_multiple_inputs(node_list, key, \
input_names, OPT_OUT_OP)
if node["op"] in target_ops or is_multiple_inputs:
in_node_dict[key] = val

Expand Down
12 changes: 9 additions & 3 deletions python/tvm/autotvm/graph_tuner/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from tvm import relay
from tvm.relay import transform


def has_multiple_inputs(node_list, node_idx, input_names):
def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op):
"""Check whether a node has multiple input nodes
except variable nodes.
Expand All @@ -47,7 +46,14 @@ def has_multiple_inputs(node_list, node_idx, input_names):
in_idx = in_idx[0]
in_node = node_list[in_idx]
# Exclude parameter nodes
if in_node["op"] is not None or \
if(in_node["op"] is not None and in_node["op"].name in opt_out_op):
increase = False
for t_idx in in_node["inputs"]:
increase = has_multiple_inputs(node_list, t_idx[0], \
input_names, opt_out_op)
if increase:
num_inputs += 1
elif in_node["op"] is not None or \
("name" in in_node and in_node["name"] in input_names):
num_inputs += 1
return num_inputs > 1
Expand Down
3 changes: 2 additions & 1 deletion tests/python/unittest/test_graph_tuner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@
from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
get_out_nodes, expr2graph, bind_inputs
from tvm.autotvm.graph_tuner._base import OPT_OUT_OP
from tvm.relay.expr import Call, TupleGetItem, Tuple, Var


def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
out = has_multiple_inputs(node_list, node_idx, input_names)
out = has_multiple_inputs(node_list, node_idx, input_names, OPT_OUT_OP)
assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
% (node_list[node_idx]["op"], str(expected_result), str(out))

Expand Down

0 comments on commit 2e913f0

Please sign in to comment.