Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jun 16, 2022
1 parent b9f47dd commit 4f08249
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 37 deletions.
29 changes: 16 additions & 13 deletions python/tvm/meta_schedule/testing/tune_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,22 @@ def main():
alloc_repeat=1,
max_workers=ARGS.rpc_workers,
)
lib = ms.tune_relay(
mod=mod,
target=ARGS.target,
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
),
runner=runner, # type: ignore
work_dir=ARGS.work_dir,
params=params,
)
with ms.Profiler() as profiler:
lib = ms.tune_relay(
mod=mod,
target=ARGS.target,
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
),
runner=runner, # type: ignore
work_dir=ARGS.work_dir,
params=params,
)
print("Tuning Time:")
print(profiler.table())
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
input_data = {}
for item in ARGS.input_shape:
Expand Down
31 changes: 17 additions & 14 deletions python/tvm/meta_schedule/testing/tune_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,23 @@ def main():
alloc_repeat=1,
max_workers=ARGS.rpc_workers,
)
sch: Optional[tir.Schedule] = ms.tune_tir(
mod=create_te_workload(ARGS.workload, 0),
target=ARGS.target,
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
),
runner=runner, # type: ignore
task_name=ARGS.workload,
work_dir=ARGS.work_dir,
num_threads=cpu_count(),
)
with ms.Profiler() as profiler:
sch: Optional[tir.Schedule] = ms.tune_tir(
mod=create_te_workload(ARGS.workload, 0),
target=ARGS.target,
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
),
runner=runner, # type: ignore
task_name=ARGS.workload,
work_dir=ARGS.work_dir,
num_threads=cpu_count(),
)
print("Tuning Time:")
print(profiler.table())
if sch is None:
print("No valid schedule found!")
else:
Expand Down
20 changes: 10 additions & 10 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,13 @@ def tune_tir(
mutator_probs=mutator_probs,
num_threads=num_threads,
)
bests: List[TuningRecord] = database.get_top_k(
database.commit_workload(mod),
top_k=1,
)
if not bests:
return None
assert len(bests) == 1
sch = Schedule(mod)
bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
with Profiler.timeit("ApplyHistoryBest"):
bests: List[TuningRecord] = database.get_top_k(database.commit_workload(mod), top_k=1)
if not bests:
return None
assert len(bests) == 1
sch = Schedule(mod)
bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
return sch


Expand Down Expand Up @@ -488,8 +486,10 @@ def tune_te(
sch : Optional[Schedule]
The tuned schedule.
"""
with Profiler.timeit("CreatePrimFunc"):
func = create_prim_func(tensors)
return tune_tir(
mod=create_prim_func(tensors),
mod=func,
target=target,
config=config,
work_dir=work_dir,
Expand Down

0 comments on commit 4f08249

Please sign in to comment.