Skip to content

Commit

Permalink
Improve the x86 auto-tune tutorial (apache#3609)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjc123 authored and wweic committed Sep 6, 2019
1 parent 8f8e768 commit b177d54
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tutorials/autotvm/tune_relay_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype)
net = mod["main"]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
mod = relay.Module.from_expr(net)
Expand All @@ -86,6 +86,10 @@ def get_network(name, batch_size):
log_file = "%s.log" % model_name
graph_opt_sch_file = "%s_graph_opt.log" % model_name

# Set the input name of the graph
# For ONNX models, it is typically "0".
input_name = "data"

# Set number of threads used for tuning based on the number of
# physical CPU cores on your machine.
num_threads = 1
Expand Down Expand Up @@ -166,7 +170,7 @@ def tune_kernels(tasks,
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
target_op = [relay.nn.conv2d]
Tuner = DPTuner if use_DP else PBQPTuner
executor = Tuner(graph, {"data": dshape}, records, target_op, target)
executor = Tuner(graph, {input_name: dshape}, records, target_op, target)
executor.benchmark_layout_transform(min_exec_num=2000)
executor.run()
executor.write_opt_sch2record_file(opt_sch_file)
Expand Down Expand Up @@ -198,7 +202,7 @@ def tune_and_evaluate(tuning_opt):
ctx = tvm.cpu()
data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
module = runtime.create(graph, lib, ctx)
module.set_input('data', data_tvm)
module.set_input(input_name, data_tvm)
module.set_input(**params)

# evaluate
Expand Down

0 comments on commit b177d54

Please sign in to comment.