Skip to content

Commit

Permalink
GraphTuner supports relay.module as input (#3434)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun authored and tqchen committed Jun 27, 2019
1 parent a074daf commit 6c43019
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
3 changes: 3 additions & 0 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ def __init__(self, graph, input_shapes, records, target_ops,
self._logger.propagate = False

# Generate workload and schedule dictionaries.
if isinstance(graph, relay.Module):
graph = graph[graph.entry_func]

if isinstance(graph, relay.expr.Function):
node_dict = {}
graph = bind_inputs(graph, input_shapes, dtype)
Expand Down
4 changes: 3 additions & 1 deletion tests/python/unittest/test_graph_tuner_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def test_DPTuner_run():
target_ops = [relay.nn.conv2d]

g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
mod = relay.module.Module()
mod[mod.entry_func] = g
costs = [0.02, 0.02, 0.045]
config_list = []
cfg_dict = {"i": -1,
Expand Down Expand Up @@ -190,7 +192,7 @@ def test_DPTuner_run():
ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
records.append((ms_input, ms_output))

executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file)
executor = DPTuner(mod, {"data": dshape}, records, target_ops, target, log_file=log_file)
executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
executor.run()
out = [record[0].config for record in executor.get_optimal_records()]
Expand Down

0 comments on commit 6c43019

Please sign in to comment.