Skip to content

Commit

Permalink
[DOCS] Fix vta tutorial (apache#4809)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and alexwong committed Feb 28, 2020
1 parent b3d3c0b commit d1d6550
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
1 change: 1 addition & 0 deletions tests/scripts/task_python_docs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ rm -rf docs/_build/html/javadoc

# remove stale tutorials and always build from scratch.
rm -rf docs/tutorials
rm -rf docs/vta/tutorials

# C++ doc
make doc
Expand Down
9 changes: 5 additions & 4 deletions vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
# ----------------------------------
# key total free pending
# ----------------------------------
# pynq 6 6 0
# pynq 6 6 0
# rpi3b 11 11 0
# ----------------------------------
#
Expand Down Expand Up @@ -223,7 +223,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
# .. note:: How to set tuning options
#
# In general, the default values provided here work well.
# If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping`
# If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping`
# to larger values, makes the tuning run for longer.
# If your device is under-powered or your conv2d operators are large, consider
# setting a longer timeout.
Expand Down Expand Up @@ -348,12 +348,13 @@ def tune_and_evaluate(tuning_opt):
# Perform task extraction on Relay program
print("Extract tasks...")
relay_prog, params = compile_network(env, target, network, start_pack, stop_pack)
tasks = autotvm.task.extract_from_program(func=relay_prog,
mod = relay.Module.from_expr(relay_prog)
tasks = autotvm.task.extract_from_program(mod,
params=params,
ops=(tvm.relay.op.nn.conv2d,),
target=target,
target_host=env.target_host)

# We should have extracted 10 convolution tasks
assert len(tasks) == 10
print("Extracted {} conv2d tasks:".format(len(tasks)))
Expand Down

0 comments on commit d1d6550

Please sign in to comment.