Skip to content

Commit

Permalink
[MetaSchedule][Refactor] Introduce TuneConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Apr 18, 2022
1 parent 9c2df39 commit 43167a9
Show file tree
Hide file tree
Showing 12 changed files with 268 additions and 479 deletions.
9 changes: 1 addition & 8 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,5 @@
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
from .search_strategy import MeasureCandidate
from .tune import (
EvolutionarySearchConfig,
ReplayFuncConfig,
ReplayTraceConfig,
tune_relay,
tune_te,
tune_tir,
)
from .tune import TuneConfig, tune_relay, tune_te, tune_tir
from .tune_context import TuneContext
14 changes: 7 additions & 7 deletions python/tvm/meta_schedule/search_strategy/evolutionary_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def __init__(
*,
num_trials_per_iter: int,
max_trials_per_task: int,
population_size: int,
init_measured_ratio: float,
init_min_unmeasured: int,
genetic_num_iters: int,
genetic_mutate_prob: float,
genetic_max_fail_count: int,
eps_greedy: float,
population_size: int = 2048,
init_measured_ratio: float = 0.2,
init_min_unmeasured: int = 50,
genetic_num_iters: int = 4,
genetic_mutate_prob: float = 0.85,
genetic_max_fail_count: int = 10,
eps_greedy: float = 0.05,
) -> None:
"""Constructor"""
self.__init_handle_by_constructor__(
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ class RoundRobin(TaskScheduler):
def __init__(
self,
tasks: List["TuneContext"],
task_weights: List[float],
builder: Builder,
runner: Runner,
database: Database,
max_trials: int,
*,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
) -> None:
Expand All @@ -66,6 +68,8 @@ def __init__(
----------
tasks : List[TuneContext]
List of tasks to schedule.
task_weights : List[float]
List of weights for each task. Not used in round robin.
builder : Builder
The builder.
runner : Runner
Expand All @@ -79,6 +83,7 @@ def __init__(
measure_callbacks: Optional[List[MeasureCallback]]
The list of measure callbacks of the scheduler.
"""
del task_weights
self.__init_handle_by_constructor__(
_ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member
tasks,
Expand Down
56 changes: 2 additions & 54 deletions python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
import argparse
import json
import logging
import os

import numpy as np # type: ignore
import tvm
from tvm import meta_schedule as ms
from tvm.ir.transform import PassContext
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.relay import build as relay_build


def _parse_args():
Expand Down Expand Up @@ -98,54 +95,6 @@ def _parse_args():
ARGS = _parse_args()


def tune_each_task(
mod,
target,
config,
runner,
work_dir,
params,
):
extracted_tasks = ms.extract_task_from_relay(mod, target, params)
database = ms.database.JSONDatabase(
path_workload=os.path.join(work_dir, "default_database_workload.json"),
path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"),
)
for task in extracted_tasks:
# pylint: disable=protected-access
tune_context = ms.tune.Parse._tune_context(
tune_context=None,
mod=ms.tune.Parse._mod(task.dispatched[0]),
target=target,
config=config,
task_name=task.task_name,
space_generator=None,
sch_rules=None,
postprocs=None,
mutator_probs=None,
num_threads=os.cpu_count(),
)
task_scheduler = ms.tune.Parse._task_scheduler(
None,
[tune_context],
task_weights=[1.0],
builder=ms.tune.Parse._builder(None),
runner=ms.tune.Parse._runner(runner),
database=database,
max_trials=config.max_trials_per_task,
cost_model=ms.tune.Parse._cost_model(None),
measure_callbacks=ms.tune.Parse._callbacks(None),
)
# pylint: enable=protected-access
task_scheduler.tune()
with target, ms.ApplyHistoryBest(database):
with PassContext(
opt_level=3,
config={"relay.backend.use_meta_schedule": True},
):
return relay_build(mod, target=target, params=params)


def main():
mod, params, (input_name, input_shape, input_dtype) = get_network(
ARGS.workload,
Expand All @@ -168,15 +117,14 @@ def main():
alloc_repeat=alloc_repeat,
max_workers=ARGS.rpc_workers,
)
# lib = tune_each_task(
lib = ms.tune_relay(
mod=mod,
target=ARGS.target,
config=ms.EvolutionarySearchConfig(
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
init_min_unmeasured=50,
),
runner=runner, # type: ignore
work_dir=ARGS.work_dir,
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/testing/tune_te_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ def main():
sch: Optional[tir.Schedule] = ms.tune_tir(
mod=create_te_workload(ARGS.workload, 0),
target=ARGS.target,
config=ms.EvolutionarySearchConfig(
config=ms.TuneConfig(
strategy="evolutionary",
num_trials_per_iter=64,
max_trials_per_task=ARGS.num_trials,
max_trials_global=ARGS.num_trials,
init_min_unmeasured=50,
),
runner=runner, # type: ignore
task_name=ARGS.workload,
Expand Down
Loading

0 comments on commit 43167a9

Please sign in to comment.