From 43167a95f5d0bc86fe00b3e3db5e7c2724f1afca Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 16 Apr 2022 00:02:27 -0700 Subject: [PATCH 1/2] [MetaSchedule][Refactor] Introduce TuneConfig --- python/tvm/meta_schedule/__init__.py | 9 +- .../search_strategy/evolutionary_search.py | 14 +- .../task_scheduler/round_robin.py | 5 + .../testing/tune_relay_meta_schedule.py | 56 +- .../testing/tune_te_meta_schedule.py | 4 +- python/tvm/meta_schedule/tune.py | 560 ++++++------------ .../test_meta_schedule_measure_callback.py | 4 +- .../test_meta_schedule_search_strategy.py | 2 + .../test_meta_schedule_task_scheduler.py | 2 + .../unittest/test_meta_schedule_tune_relay.py | 72 ++- .../unittest/test_meta_schedule_tune_te.py | 5 +- .../unittest/test_meta_schedule_tune_tir.py | 14 +- 12 files changed, 268 insertions(+), 479 deletions(-) diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 466c5e3e6699..76eebbdf23f1 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -32,12 +32,5 @@ from .extracted_task import ExtractedTask from .relay_integration import extract_task_from_relay from .search_strategy import MeasureCandidate -from .tune import ( - EvolutionarySearchConfig, - ReplayFuncConfig, - ReplayTraceConfig, - tune_relay, - tune_te, - tune_tir, -) +from .tune import TuneConfig, tune_relay, tune_te, tune_tir from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 20d0b33378e3..f54fc53935f0 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -64,13 +64,13 @@ def __init__( *, num_trials_per_iter: int, max_trials_per_task: int, - population_size: int, - init_measured_ratio: float, - init_min_unmeasured: int, - genetic_num_iters: int, - genetic_mutate_prob: float, - genetic_max_fail_count: int, - eps_greedy: float, + population_size: int = 2048, + init_measured_ratio: float = 0.2, + init_min_unmeasured: int = 50, + genetic_num_iters: int = 4, + genetic_mutate_prob: float = 0.85, + genetic_max_fail_count: int = 10, + eps_greedy: float = 0.05, ) -> None: """Constructor""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 16d06ab1fd72..6634d6193e26 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -53,10 +53,12 @@ class RoundRobin(TaskScheduler): def __init__( self, tasks: List["TuneContext"], + task_weights: List[float], builder: Builder, runner: Runner, database: Database, max_trials: int, + *, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, ) -> None: @@ -66,6 +68,8 @@ def __init__( ---------- tasks : List[TuneContext] List of tasks to schedule. + task_weights : List[float] + List of weights for each task. Not used in round robin. builder : Builder The builder. runner : Runner @@ -79,6 +83,7 @@ def __init__( measure_callbacks: Optional[List[MeasureCallback]] The list of measure callbacks of the scheduler. """ + del task_weights self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member tasks, diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index 0973c9b91bff..d8e6d38695ac 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -18,15 +18,12 @@ import argparse import json import logging -import os import numpy as np # type: ignore import tvm from tvm import meta_schedule as ms -from tvm.ir.transform import PassContext from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.relay import build as relay_build def _parse_args(): @@ -98,54 +95,6 @@ def _parse_args(): ARGS = _parse_args() -def tune_each_task( - mod, - target, - config, - runner, - work_dir, - params, -): - extracted_tasks = ms.extract_task_from_relay(mod, target, params) - database = ms.database.JSONDatabase( - path_workload=os.path.join(work_dir, "default_database_workload.json"), - path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"), - ) - for task in extracted_tasks: - # pylint: disable=protected-access - tune_context = ms.tune.Parse._tune_context( - tune_context=None, - mod=ms.tune.Parse._mod(task.dispatched[0]), - target=target, - config=config, - task_name=task.task_name, - space_generator=None, - sch_rules=None, - postprocs=None, - mutator_probs=None, - num_threads=os.cpu_count(), - ) - task_scheduler = ms.tune.Parse._task_scheduler( - None, - [tune_context], - task_weights=[1.0], - builder=ms.tune.Parse._builder(None), - runner=ms.tune.Parse._runner(runner), - database=database, - max_trials=config.max_trials_per_task, - cost_model=ms.tune.Parse._cost_model(None), - measure_callbacks=ms.tune.Parse._callbacks(None), - ) - # pylint: enable=protected-access - task_scheduler.tune() - with target, ms.ApplyHistoryBest(database): - with PassContext( - opt_level=3, - config={"relay.backend.use_meta_schedule": True}, - ): - return relay_build(mod, target=target, params=params) - - def main(): mod, params, (input_name, input_shape, input_dtype) = get_network( ARGS.workload, @@ -168,15 +117,14 @@ def main(): alloc_repeat=alloc_repeat, max_workers=ARGS.rpc_workers, ) - # lib = tune_each_task( lib = ms.tune_relay( mod=mod, target=ARGS.target, - config=ms.EvolutionarySearchConfig( + config=ms.TuneConfig( + strategy="evolutionary", num_trials_per_iter=64, max_trials_per_task=ARGS.num_trials, max_trials_global=ARGS.num_trials, - init_min_unmeasured=50, ), runner=runner, # type: ignore work_dir=ARGS.work_dir, diff --git a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py index abba94ad7a5e..2e8b538b9cc9 100644 --- a/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_te_meta_schedule.py @@ -100,11 +100,11 @@ def main(): sch: Optional[tir.Schedule] = ms.tune_tir( mod=create_te_workload(ARGS.workload, 0), target=ARGS.target, - config=ms.EvolutionarySearchConfig( + config=ms.TuneConfig( + strategy="evolutionary", num_trials_per_iter=64, max_trials_per_task=ARGS.num_trials, max_trials_global=ARGS.num_trials, - init_min_unmeasured=50, ), runner=runner, # type: ignore task_name=ARGS.workload, diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 201434665af5..05df6f3c33bd 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -18,10 +18,10 @@ # pylint: disable=import-outside-toplevel import logging import os.path -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union from tvm._ffi.registry import register_func -from tvm.ir import IRModule, structural_hash +from tvm.ir import IRModule from tvm.ir.transform import PassContext from tvm.runtime import Module, NDArray from tvm.target import Target @@ -41,7 +41,7 @@ from .schedule_rule import ScheduleRule from .search_strategy import EvolutionarySearch, ReplayFunc, ReplayTrace from .space_generator import PostOrderApply, SpaceGenerator -from .task_scheduler import GradientBased, TaskScheduler +from .task_scheduler import GradientBased, RoundRobin from .tune_context import TuneContext from .utils import autotvm_silencer @@ -51,119 +51,6 @@ FnScheduleRule = Callable[[], List[ScheduleRule]] FnPostproc = Callable[[], List[Postproc]] FnMutatorProb = Callable[[], Dict[Mutator, float]] -FnTaskScheduler = Callable[ - [ - List[TuneContext], - List[float], - Builder, - Runner, - Database, - CostModel, - List[MeasureCallback], - ], - TaskScheduler, -] - - -class ReplayFuncConfig(NamedTuple): - """Configuration for ReplayFunc - - Parameters - ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials for one task - max_trials_global : int - Total number of trials for all tasks in the task scheduler - """ - - num_trials_per_iter: int - max_trials_per_task: int - max_trials_global: int - - def create_strategy(self) -> ReplayFunc: - return ReplayFunc(self.num_trials_per_iter, self.max_trials_per_task) - - -class ReplayTraceConfig(NamedTuple): - """Configuration for ReplayTrace - - Parameters - ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials for one task - max_trials_global : int - Total number of trials for all tasks in the task scheduler - """ - - num_trials_per_iter: int - max_trials_per_task: int - max_trials_global: int - - def create_strategy(self) -> ReplayTrace: - return ReplayTrace(self.num_trials_per_iter, self.max_trials_per_task) - - -class EvolutionarySearchConfig(NamedTuple): - """Configuration for EvolutionarySearch - - Parameters - ---------- - num_trials_per_iter : int - Number of trials per iteration. - max_trials_per_task : int - Total number of trials. - max_trials_global : int - Total number of trials for all tasks in the task scheduler - population_size : int - The initial population of traces from measured samples and randomly generated samples. - init_measured_ratio : int - The ratio of measured samples in the initial population. - init_min_unmeasured : int - The minimal size of unmeasured population in the initial sampling. - genetic_num_iters : int - The number of iterations for genetic algorithm. - genetic_mutate_prob : float - The probability of mutation. - genetic_max_fail_count : int - The maximum number to retry mutation. - eps_greedy : float - The ratio of greedy selected samples in the final picks. - """ - - num_trials_per_iter: int - max_trials_per_task: int - max_trials_global: int - population_size: int = 2048 - init_measured_ratio: float = 0.2 - init_min_unmeasured: int = 50 - genetic_num_iters: int = 4 - genetic_mutate_prob: float = 0.85 - genetic_max_fail_count: int = 10 - eps_greedy: float = 0.05 - - def create_strategy(self) -> EvolutionarySearch: - return EvolutionarySearch( - num_trials_per_iter=self.num_trials_per_iter, - max_trials_per_task=self.max_trials_per_task, - population_size=self.population_size, - init_measured_ratio=self.init_measured_ratio, - init_min_unmeasured=self.init_min_unmeasured, - genetic_num_iters=self.genetic_num_iters, - genetic_mutate_prob=self.genetic_mutate_prob, - genetic_max_fail_count=self.genetic_max_fail_count, - eps_greedy=self.eps_greedy, - ) - - -SearchStrategyConfig = Union[ - ReplayFuncConfig, - ReplayTraceConfig, - EvolutionarySearchConfig, -] class DefaultLLVM: @@ -337,10 +224,10 @@ def _runner(runner: Optional[Runner]) -> Runner: return runner @staticmethod - def _database(database: Union[None, Database], task_name: str, path: str) -> Database: + def _database(database: Union[None, Database], path: str) -> Database: if database is None: - path_workload = os.path.join(path, f"{task_name}_database_workload.json") - path_tuning_record = os.path.join(path, f"{task_name}_database_tuning_record.json") + path_workload = os.path.join(path, "database_workload.json") + path_tuning_record = os.path.join(path, "database_tuning_record.json") logger.info( "Creating JSONDatabase. Workload at: %s. Tuning records at: %s", path_workload, @@ -449,95 +336,189 @@ def _mutator_probs( # pylint: enable=protected-access raise ValueError(f"Unsupported target: {target}") - @staticmethod - def _tune_context( - tune_context: Optional[TuneContext], - mod: IRModule, - target: Target, - config: SearchStrategyConfig, - task_name: str, - space_generator: Optional[FnSpaceGenerator], - sch_rules: Optional[FnScheduleRule], - postprocs: Optional[FnPostproc], - mutator_probs: Optional[FnMutatorProb], - num_threads: Optional[int], - ) -> TuneContext: - if tune_context is None: - return TuneContext( - mod=mod, - target=target, - # pylint: disable=protected-access - space_generator=Parse._space_generator(space_generator), - search_strategy=config.create_strategy(), - sch_rules=Parse._sch_rules(sch_rules, target), - postprocs=Parse._postproc(postprocs, target), - mutator_probs=Parse._mutator_probs(mutator_probs, target), - # pylint: enable=protected-access - task_name=task_name, - rand_state=-1, - num_threads=num_threads, - ) - if not isinstance(tune_context, TuneContext): - raise TypeError(f"Expected `tune_context` to be TuneContext, but gets: {tune_context}") - return tune_context - @staticmethod - def _task_scheduler( - task_scheduler: Union[None, TaskScheduler, FnTaskScheduler], - tasks: List[TuneContext], - task_weights: List[float], - builder: Builder, - runner: Runner, - database: Database, - max_trials: int, - cost_model: CostModel, - measure_callbacks: List[MeasureCallback], - ): - if task_scheduler is None: - return GradientBased( - tasks=tasks, - task_weights=task_weights, - builder=builder, - runner=runner, - database=database, - max_trials=max_trials, - cost_model=cost_model, - measure_callbacks=measure_callbacks, +class TuneConfig(NamedTuple): + """Configuration for tuning + + Parameters + ---------- + max_trials_global: int + Maximum number of trials to run. + num_trials_per_iter: int + Number of trials to run per iteration. + max_trials_per_task: int + Maximum number of trials to run per task. + task_scheduler: str + Task scheduler to use. + Valid options are: round_robin, gradient. + search_strategy: str + Search strategy to use. + Valid options are: evolutionary, replay_func, replay_trace. + task_scheduler_config: Dict[str, Any] + Configuration for task scheduler. + search_strategy_config: Dict[str, Any] + Configuration for search strategy. + """ + + max_trials_global: int + num_trials_per_iter: int + max_trials_per_task: Optional[int] = None + task_scheduler: str = "gradient" + strategy: str = "evolutionary" + task_scheduler_config: Dict[str, Any] = {} + search_strategy_config: Dict[str, Any] = {} + + def create_strategy(self, **kwargs): + """Create search strategy from configuration""" + cls_tbl = { + "evolutionary": EvolutionarySearch, + "replay_func": ReplayFunc, + "replay_trace": ReplayTrace, + } + if self.strategy not in cls_tbl: + raise ValueError( + f"Invalid search strategy: {self.strategy}. " + "Valid options are: {}".format(", ".join(cls_tbl.keys())) ) - if callable(task_scheduler): - return task_scheduler( - tasks, - task_weights, - builder, - runner, - database, - cost_model, - measure_callbacks, + max_trials_per_task = self.max_trials_per_task + if max_trials_per_task is None: + max_trials_per_task = self.max_trials_global + return cls_tbl[self.strategy]( + num_trials_per_iter=self.num_trials_per_iter, + max_trials_per_task=max_trials_per_task, + **kwargs, + **self.search_strategy_config, + ) + + def create_task_scheduler(self, **kwargs): + """Create task scheduler from configuration""" + cls_tbl = { + "round_robin": RoundRobin, + "gradient": GradientBased, + } + if self.task_scheduler not in cls_tbl: + raise ValueError( + f"Invalid task scheduler: {self.task_scheduler}. " + "Valid options are: {}".format(", ".join(cls_tbl.keys())) ) - if not isinstance(task_scheduler, TaskScheduler): - raise TypeError( - f"Expected `task_scheduler` to be TaskScheduler, but gets: {task_scheduler}" + return cls_tbl[self.task_scheduler]( + max_trials=self.max_trials_global, + **kwargs, + **self.task_scheduler_config, + ) + + +def tune_extracted_tasks( + extracted_tasks: List[ExtractedTask], + config: TuneConfig, + work_dir: str, + *, + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + space: Optional[FnSpaceGenerator] = None, + sch_rules: Optional[FnScheduleRule] = None, + postprocs: Optional[FnPostproc] = None, + mutator_probs: Optional[FnMutatorProb] = None, + num_threads: Optional[int] = None, +) -> Database: + """Tune extracted tasks with a given target. + + Parameters + ---------- + extracted_tasks : List[ExtractedTask] + The list of extraced tasks. + config : TuneConfig + The search strategy config. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + cost_model : Optional[CostModel] + The cost model to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + task_scheduler : Optional[TaskScheduler] + The task scheduler to use. + space : Optional[FnSpaceGenerator] + The space generator to use. + sch_rules : Optional[FnScheduleRule] + The search rules to use. + postprocs : Optional[FnPostproc] + The postprocessors to use. + mutator_probs : Optional[FnMutatorProb] + The probability distribution to use different mutators. + num_threads : Optional[int] + The number of threads to use. + + Returns + ------- + database : Database + The database containing all the tuning results. + + """ + logger.info("Working directory: %s", work_dir) + # pylint: disable=protected-access + database = Parse._database(database, work_dir) + builder = Parse._builder(builder) + runner = Parse._runner(runner) + cost_model = Parse._cost_model(cost_model) + measure_callbacks = Parse._callbacks(measure_callbacks) + # parse the tuning contexts + tune_contexts = [] + for task in extracted_tasks: + assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" + tune_contexts.append( + TuneContext( + mod=Parse._mod(task.dispatched[0]), + target=task.target, + space_generator=Parse._space_generator(space), + search_strategy=config.create_strategy(), + sch_rules=Parse._sch_rules(sch_rules, task.target), + postprocs=Parse._postproc(postprocs, task.target), + mutator_probs=Parse._mutator_probs(mutator_probs, task.target), + task_name=task.task_name, + num_threads=num_threads, ) - return task_scheduler + ) + # parse the task scheduler + # pylint: enable=protected-access + task_scheduler = config.create_task_scheduler( + tasks=tune_contexts, + task_weights=[float(t.weight) for t in extracted_tasks], + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + ) + task_scheduler.tune() + cost_model.save(os.path.join(work_dir, "cost_model.xgb")) + return database def tune_tir( mod: Union[IRModule, PrimFunc], target: Union[str, Target], - config: SearchStrategyConfig, + config: TuneConfig, work_dir: str, *, - task_name: str = "main", builder: Optional[Builder] = None, runner: Optional[Runner] = None, database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, space: Optional[FnSpaceGenerator] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, mutator_probs: Optional[FnMutatorProb] = None, + task_name: str = "main", num_threads: Optional[int] = None, ) -> Optional[Schedule]: """Tune a TIR IRModule with a given target. @@ -548,7 +529,7 @@ def tune_tir( The module to tune. target : Union[str, Target] The target to tune for. - config : SearchStrategyConfig + config : TuneConfig The search strategy config. work_dir : Optional[str] The working directory to save intermediate results. @@ -562,46 +543,39 @@ def tune_tir( The cost model to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. - f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] - The function to create TuneContext. - f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] - The function to create TaskScheduler. Returns ------- sch : Optional[Schedule] The tuned schedule. """ - - logger.info("Working directory: %s", work_dir) # pylint: disable=protected-access mod = Parse._mod(mod) - database = Parse._database(database, task_name, work_dir) - tune_context = Parse._tune_context( - tune_context=None, - mod=mod, - target=Parse._target(target), + target = Parse._target(target) + # pylint: enable=protected-access + database = tune_extracted_tasks( + extracted_tasks=[ + ExtractedTask( + task_name=task_name, + mod=mod, + dispatched=[mod], + target=target, + weight=1, + ), + ], config=config, - task_name=task_name, - space_generator=space, + work_dir=work_dir, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + space=space, sch_rules=sch_rules, postprocs=postprocs, mutator_probs=mutator_probs, num_threads=num_threads, ) - task_scheduler = Parse._task_scheduler( - task_scheduler, - [tune_context], - task_weights=[1.0], - builder=Parse._builder(builder), - runner=Parse._runner(runner), - database=database, - max_trials=config.max_trials_global, - cost_model=Parse._cost_model(cost_model), - measure_callbacks=Parse._callbacks(measure_callbacks), - ) - # pylint: enable=protected-access - task_scheduler.tune() bests: List[TuningRecord] = database.get_top_k( database.commit_workload(mod), top_k=1, @@ -611,14 +585,13 @@ def tune_tir( assert len(bests) == 1 sch = Schedule(mod) bests[0].trace.apply_to_schedule(sch, remove_postproc=False) - task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb")) return sch def tune_te( tensors: List[Tensor], target: Union[str, Target], - config: SearchStrategyConfig, + config: TuneConfig, work_dir: str, *, task_name: str = "main", @@ -627,7 +600,6 @@ def tune_te( database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, space: Optional[FnSpaceGenerator] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, @@ -642,7 +614,7 @@ def tune_te( The list of input/output tensors of the TE compute DAG. target : Union[str, Target] The target to tune for. - config : SearchStrategyConfig + config : TuneConfig The search strategy config. task_name : str The name of the task. @@ -656,10 +628,6 @@ def tune_te( The database to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. - f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] - The function to create TuneContext. - f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] - The function to create TaskScheduler. Returns ------- @@ -677,7 +645,6 @@ def tune_te( database=database, cost_model=cost_model, measure_callbacks=measure_callbacks, - task_scheduler=task_scheduler, space=space, sch_rules=sch_rules, postprocs=postprocs, @@ -686,144 +653,10 @@ def tune_te( ) -def deduplicate_extracted_tasks( - extracted_tasks: List[ExtractedTask], -) -> Tuple[List[ExtractedTask], List[int]]: - """Remove duplicate extraced tasks. - - Parameters - ---------- - extracted_tasks : List[ExtractedTask] - The list of extraced tasks. - - Returns - ------- - tasks : Tuple[List[ExtractedTask], List[int]] - A tuple containing the deduplicated extraced tasks and the count for each task. - """ - hash2idx: Dict[int, int] = {} - dedup: List[ExtractedTask] = [] - count: List[int] = [] - - for task in extracted_tasks: - assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" - mod = Parse._mod(task.dispatched[0]) # pylint: disable=protected-access - shash = structural_hash(mod) - if shash in hash2idx: - count[hash2idx[shash]] += 1 - else: - hash2idx[shash] = len(dedup) - dedup.append(task) - count.append(1) - return dedup, count - - -def tune_extracted_tasks( - extracted_tasks: List[ExtractedTask], - target: Target, - config: SearchStrategyConfig, - work_dir: str, - *, - builder: Optional[Builder] = None, - runner: Optional[Runner] = None, - database: Optional[Database] = None, - cost_model: Optional[CostModel] = None, - measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, - space: Optional[FnSpaceGenerator] = None, - sch_rules: Optional[FnScheduleRule] = None, - postprocs: Optional[FnPostproc] = None, - mutator_probs: Optional[FnMutatorProb] = None, - num_threads: Optional[int] = None, -) -> Database: - """Tune extracted tasks with a given target. - - Parameters - ---------- - extracted_tasks : List[ExtractedTask] - The list of extraced tasks. - target : Union[str, Target] - The target to tune for. - config : SearchStrategyConfig - The search strategy config. - work_dir : Optional[str] - The working directory to save intermediate results. - builder : Optional[Builder] - The builder to use. - runner : Optional[Runner] - The runner to use. - database : Optional[Database] - The database to use. - cost_model : Optional[CostModel] - The cost model to use. - measure_callbacks : Optional[List[MeasureCallback]] - The callbacks used during tuning. - task_scheduler : Optional[TaskScheduler] - The task scheduler to use. - space : Optional[FnSpaceGenerator] - The space generator to use. - sch_rules : Optional[FnScheduleRule] - The search rules to use. - postprocs : Optional[FnPostproc] - The postprocessors to use. - mutator_probs : Optional[FnMutatorProb] - The probability distribution to use different mutators. - num_threads : Optional[int] - The number of threads to use. - - Returns - ------- - database : Database - The database containing all the tuning results. - - """ - # deduplication - logger.info("Before task deduplication: %d tasks", len(extracted_tasks)) - extracted_tasks, _ = deduplicate_extracted_tasks(extracted_tasks) - logger.info("After task deduplication: %d tasks", len(extracted_tasks)) - # pylint: disable=protected-access - target = Parse._target(target) - # parse the tuning contexts - tune_contexts = [] - for task in extracted_tasks: - assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" - tune_contexts.append( - Parse._tune_context( - tune_context=None, - mod=Parse._mod(task.dispatched[0]), - target=target, - config=config, - task_name=task.task_name, - space_generator=space, - sch_rules=sch_rules, - postprocs=postprocs, - mutator_probs=mutator_probs, - num_threads=num_threads, - ) - ) - # parse the task scheduler - database = Parse._database(database, "default", work_dir) - task_scheduler = Parse._task_scheduler( - task_scheduler, - tune_contexts, - task_weights=[float(t.weight) for t in extracted_tasks], - builder=Parse._builder(builder), - runner=Parse._runner(runner), - database=database, - max_trials=config.max_trials_global, - cost_model=Parse._cost_model(cost_model), - measure_callbacks=Parse._callbacks(measure_callbacks), - ) - # pylint: enable=protected-access - task_scheduler.tune() - task_scheduler.cost_model.save(os.path.join(work_dir, "cost_model.xgb")) - return database - - def tune_relay( mod: IRModule, target: Union[str, Target], - config: SearchStrategyConfig, + config: TuneConfig, work_dir: str, *, params: Optional[Dict[str, NDArray]] = None, @@ -832,7 +665,6 @@ def tune_relay( database: Optional[Database] = None, cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, - task_scheduler: Optional[TaskScheduler] = None, space: Optional[FnSpaceGenerator] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, @@ -847,7 +679,7 @@ def tune_relay( The module to tune. target : Union[str, Target] The target to tune for. - config : SearchStrategyConfig + config : TuneConfig The search strategy config. params : Optional[Dict[str, tvm.runtime.NDArray]] The associated parameters of the program @@ -863,10 +695,6 @@ def tune_relay( The database to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. - f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] - The function to create TuneContext. - f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] - The function to create TaskScheduler. Returns ------- @@ -887,7 +715,6 @@ def tune_relay( extracted_tasks = extract_task_from_relay(mod, target, params) database = tune_extracted_tasks( extracted_tasks, - target, config, work_dir, builder=builder, @@ -895,7 +722,6 @@ def tune_relay( database=database, cost_model=cost_model, measure_callbacks=measure_callbacks, - task_scheduler=task_scheduler, space=space, sch_rules=sch_rules, postprocs=postprocs, diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index df8d0fe38315..a1b188930f86 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -78,7 +78,7 @@ def apply( measure_callback = FancyMeasureCallback() measure_callback.apply( - RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), + RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], @@ -102,7 +102,7 @@ def apply( measure_callback = FailingMeasureCallback() with pytest.raises(ValueError, match="test"): measure_callback.apply( - RoundRobin([], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), + RoundRobin([], [], DummyBuilder(), DummyRunner(), DummyDatabase(), max_trials=1), 0, [MeasureCandidate(Schedule(Matmul), None)], [BuilderResult("test_build", None)], diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index ca9c50b521be..b148f58ff804 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -145,6 +145,7 @@ def _schedule_matmul_small(sch: Schedule): ) _scheduler = RoundRobin( tasks=[context], + task_weights=[1.0], builder=ms.builder.LocalBuilder(), runner=ms.runner.LocalRunner(), database=DummyDatabase(), @@ -207,6 +208,7 @@ def _schedule_matmul_empty(sch: Schedule): ) _scheduler = RoundRobin( tasks=[context], + task_weights=[1.0], builder=ms.builder.LocalBuilder(), runner=ms.runner.LocalRunner(), database=DummyDatabase(), diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index 26a2733980c0..fdf4d26379ae 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -168,6 +168,7 @@ def test_meta_schedule_task_scheduler_single(): database = DummyDatabase() round_robin = RoundRobin( [task], + [1.0], DummyBuilder(), DummyRunner(), database, @@ -210,6 +211,7 @@ def test_meta_schedule_task_scheduler_multiple(): database = DummyDatabase() round_robin = RoundRobin( tasks, + [1.0], DummyBuilder(), DummyRunner(), database, diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index 64b8795c5eaf..6b45ad6f07a5 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -20,14 +20,14 @@ from os import path as osp from typing import List -import numpy as np +import numpy as np # type: ignore import pytest import tvm -from tvm import relay, tir +from tvm import relay from tvm._ffi import register_func from tvm.contrib import graph_executor from tvm.ir import IRModule -from tvm.meta_schedule import ApplyHistoryBest, ReplayTraceConfig +from tvm.meta_schedule import ApplyHistoryBest, TuneConfig from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.meta_schedule.testing import apply_fixed_schedules @@ -40,19 +40,19 @@ from tvm.tir.schedule.trace import Trace from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN - logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off + @tvm.script.ir_module class tvmgen_default_fused_layout_transform: @T.prim_func - def main( - placeholder: T.Buffer[(1, 3, 16, 16), "float32"], - T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], - ) -> None: + def main( # type: ignore + placeholder: T.Buffer[(1, 3, 16, 16), "float32"], # type: ignore + T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], # type: ignore + ) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -63,7 +63,7 @@ def main( T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( - ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, + ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore placeholder[ax0, ax1 * 3 + ax4, ax2, ax3], T.float32(0), dtype="float32", @@ -73,7 +73,7 @@ def main( @tvm.script.ir_module class tvmgen_default_fused_nn_contrib_conv2d_NCHWc: @T.prim_func - def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: + def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -84,21 +84,21 @@ def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.B i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1]) T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) - data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") + data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5): with T.block("conv2d_NCHWc"): n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) - T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) + T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) T.block_attr({"workload":["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]}) with T.init(): conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) - conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] + conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore @tvm.script.ir_module class tvmgen_default_fused_layout_transform_1: @T.prim_func - def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: + def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -106,9 +106,9 @@ def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T. for i0, i1, i2, i3 in T.grid(1, 8, 16, 16): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) + T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) - T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") + T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -144,14 +144,19 @@ def test_meta_schedule_tune_relay( mod=mod, params=params, target=target, - config=ReplayTraceConfig( + config=TuneConfig( + strategy="evolutionary", num_trials_per_iter=32, - max_trials_per_task=32, + max_trials_per_task=20000, max_trials_global=20000, + search_strategy_config={ + "genetic_num_iters": 10, + }, ), work_dir=work_dir, database=JSONDatabase( - osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json") + osp.join(work_dir, "workload.json"), + osp.join(work_dir, "records.json"), ), ) # Compile without meta-scheduler for correctness check @@ -330,7 +335,7 @@ def get_output(data, lib): assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) -def schedule_dense(dense_block, M, do_tune, sch): +def schedule_dense(dense_block, M, do_tune, sch): # pylint: disable=invalid-name """ Manually schedule a dense block, created from TE compute op via CreatePrimFunc, using VNNI instruction. @@ -392,7 +397,7 @@ def schedule_dense(dense_block, M, do_tune, sch): def manual_tir_common(do_tune=False): - M, N, K = 1024, 1024, 1024 + M, N, K = 1024, 1024, 1024 # pylint: disable=invalid-name data_shape = (M, K) weight_shape = (N, K) @@ -437,9 +442,10 @@ def manual_tir_common(do_tune=False): extracted_tasks, ) ) - config = ReplayTraceConfig( + config = TuneConfig( + strategy="replay_trace", num_trials_per_iter=64, - max_trials_per_task=64, + max_trials_per_task=20000, max_trials_global=20000, ) @@ -447,7 +453,10 @@ def manual_tir_common(do_tune=False): # postprocs=lambda: [] is important to prevent default post processors from # tampering with the manual schedule. database = tune_extracted_tasks( - tune_tasks, target, config, work_dir=work_dir, postprocs=lambda: [] + tune_tasks, + config, + work_dir=work_dir, + postprocs=lambda: [], ) else: @@ -457,7 +466,8 @@ def schedule_fn(task, sch): block = sch.get_block("compute") - # Looks up schedule_rule annotation. See the comment in test_tune_relay_manual_tir_vnni(). + # Looks up schedule_rule annotation. + # See the comment in test_tune_relay_manual_tir_vnni(). schedule_rule = sch.get(block).annotations["schedule_rule"] assert "dense_vnni" in schedule_rule @@ -473,6 +483,7 @@ def schedule_fn(task, sch): opt_level=3, config={"relay.backend.use_meta_schedule": True}, ): + # pylint: disable=W0105 """ The log should say Warning: Cannot find workload: tvmgen_default_fused_expand_dims @@ -483,6 +494,7 @@ def schedule_fn(task, sch): This means batch matmul and others are scheduled by TE, and dense (the one not warned) is found in the meta schedule tuning database during ApplyHistoryBest """ + # pylint: enable=W0105 lib = relay.build(relay_mod, target=target, params=params) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) @@ -499,6 +511,7 @@ def schedule_fn(task, sch): def test_tune_relay_manual_tir_vnni(): manual_tir_common(do_tune=False) + # pylint: disable=W0105 """ We can inject and apply a custom TIR scheduling to a TE compute of interest, using the "schedule_rule" annotation. For example, in topi/x86/dense.py we have the following @@ -510,17 +523,18 @@ def test_tune_relay_manual_tir_vnni(): ) When the meta scheduler encounters a TensorIR block with the "schedule_rule" annotation, - it looks up the packed func registry for a function that is associated with the given schedule rule - key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule functions - must be + it looks up the packed func registry for a function that is associated with the given schedule + rule key ("meta_schedule.dense_vnni" in this example). The signature of such custom schedule + functions must be (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule]. - The BlockRV argument corresponds to the TE compute annotated with "schedule_rlue". + The BlockRV argument corresponds to the TE compute annotated with "schedule_rule". The relevant code is in meta_schedule/space_generator/post_order_apply.cc. """ + # pylint: enable=W0105 def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV): schedule_dense(dense_block, None, True, sch) diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py index f58ebf34787e..52e5fde85ec9 100644 --- a/tests/python/unittest/test_meta_schedule_tune_te.py +++ b/tests/python/unittest/test_meta_schedule_tune_te.py @@ -19,7 +19,7 @@ import tempfile import pytest -from tvm.meta_schedule import ReplayTraceConfig, tune_te +from tvm.meta_schedule import TuneConfig, tune_te from tvm.meta_schedule.testing import te_workload from tvm.target.target import Target from tvm.tir import Schedule @@ -34,7 +34,8 @@ def test_tune_matmul(): sch: Schedule = tune_te( tensors=te_workload.batch_matmul_nkkm(B=1, N=128, M=128, K=128), target=Target("llvm --num-cores=16"), - config=ReplayTraceConfig( + config=TuneConfig( + strategy="replay_trace", num_trials_per_iter=32, max_trials_per_task=32, max_trials_global=32, diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index ebce33965914..a7806ebda28a 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -19,13 +19,9 @@ import tempfile import pytest -import tvm -from tvm.meta_schedule import ReplayTraceConfig, schedule_rule, tune_tir -from tvm.meta_schedule.space_generator import PostOrderApply -from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule import TuneConfig, tune_tir from tvm.script import tir as T -from tvm.target.target import Target -from tvm.te.operation import create_prim_func +from tvm.target import Target from tvm.tir import Schedule logging.basicConfig() @@ -57,7 +53,8 @@ def test_tune_matmul_cpu(): sch: Schedule = tune_tir( mod=matmul, target=Target("llvm --num-cores=16"), - config=ReplayTraceConfig( + config=TuneConfig( + strategy="replay_trace", num_trials_per_iter=32, max_trials_per_task=32, max_trials_global=32, @@ -77,7 +74,8 @@ def test_tune_matmul_cuda(): sch: Schedule = tune_tir( mod=matmul, target=Target("nvidia/geforce-rtx-3070"), - config=ReplayTraceConfig( + config=TuneConfig( + strategy="replay_trace", num_trials_per_iter=32, max_trials_per_task=32, max_trials_global=32, From 11701030b808bd6208e33239ae1fba761a7d4b3d Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 17 Apr 2022 22:37:40 -0700 Subject: [PATCH 2/2] Update docs in TuneConfig --- python/tvm/meta_schedule/tune.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 05df6f3c33bd..1b417940072b 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -346,17 +346,17 @@ class TuneConfig(NamedTuple): Maximum number of trials to run. num_trials_per_iter: int Number of trials to run per iteration. - max_trials_per_task: int - Maximum number of trials to run per task. - task_scheduler: str + max_trials_per_task: Optional[int] + Maximum number of trials to run per task. If None, use `max_trials_global`. + task_scheduler: str = "gradient" Task scheduler to use. Valid options are: round_robin, gradient. - search_strategy: str + search_strategy: str = "evolutionary" Search strategy to use. Valid options are: evolutionary, replay_func, replay_trace. - task_scheduler_config: Dict[str, Any] + task_scheduler_config: Optional[Dict[str, Any]] = None Configuration for task scheduler. - search_strategy_config: Dict[str, Any] + search_strategy_config: Optional[Dict[str, Any]] = None Configuration for search strategy. """ @@ -365,8 +365,8 @@ class TuneConfig(NamedTuple): max_trials_per_task: Optional[int] = None task_scheduler: str = "gradient" strategy: str = "evolutionary" - task_scheduler_config: Dict[str, Any] = {} - search_strategy_config: Dict[str, Any] = {} + task_scheduler_config: Optional[Dict[str, Any]] = None + search_strategy_config: Optional[Dict[str, Any]] = None def create_strategy(self, **kwargs): """Create search strategy from configuration""" @@ -380,14 +380,19 @@ def create_strategy(self, **kwargs): f"Invalid search strategy: {self.strategy}. " "Valid options are: {}".format(", ".join(cls_tbl.keys())) ) + # `max_trials_per_task` defaults to `max_trials_global` max_trials_per_task = self.max_trials_per_task if max_trials_per_task is None: max_trials_per_task = self.max_trials_global + # `search_strategy_config` defaults to empty dict + config = self.search_strategy_config + if config is None: + config = {} return cls_tbl[self.strategy]( num_trials_per_iter=self.num_trials_per_iter, max_trials_per_task=max_trials_per_task, **kwargs, - **self.search_strategy_config, + **config, ) def create_task_scheduler(self, **kwargs): @@ -401,10 +406,14 @@ def create_task_scheduler(self, **kwargs): f"Invalid task scheduler: {self.task_scheduler}. " "Valid options are: {}".format(", ".join(cls_tbl.keys())) ) + # `task_scheduler_config` defaults to empty dict + config = self.task_scheduler_config + if config is None: + config = {} return cls_tbl[self.task_scheduler]( max_trials=self.max_trials_global, **kwargs, - **self.task_scheduler_config, + **config, )