[MetaSchedule] Enable Adaptive Training For XGBoost Cost Model (#11892)
Cost-model retraining is a time-consuming part of MetaSchedule tuning. As in AutoScheduler, we can alleviate it by adaptively increasing the waiting period between retrainings. This PR introduces an argument called `adaptive_training` in `TuneConfig` and in the constructor of `XGBModel` to enable this capability. The testing tuning scripts are also updated.
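To make the policy concrete, here is a minimal sketch (an illustration of the same 20% growth threshold the PR uses, not the PR's actual code): retraining is skipped until the dataset has grown by at least one fifth since the last training run, so retraining becomes progressively rarer as measurement results accumulate.

# Minimal sketch of adaptive retraining: train only when the dataset has
# grown by >= 20% since the last training run (mirrors `last_train_size / 5`).
class AdaptiveRetrainGate:
    def __init__(self, growth_ratio: float = 0.2):
        self.growth_ratio = growth_ratio
        self.last_train_size = 0

    def should_train(self, data_size: int) -> bool:
        if data_size - self.last_train_size < self.last_train_size * self.growth_ratio:
            return False  # too few new samples; reuse the current model
        self.last_train_size = data_size
        return True

gate = AdaptiveRetrainGate()
for size in [100, 150, 180, 250, 290, 400]:
    print(size, "retrain" if gate.should_train(size) else "skip")
# 100/150/180/250 retrain; 290 skips (only 40 new samples < 50); 400 retrains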
zxybazh authored Jun 28, 2022
1 parent 97b3076 commit 0e23122
Showing 10 changed files with 90 additions and 17 deletions.
4 changes: 2 additions & 2 deletions python/tvm/auto_scheduler/search_task.py
@@ -481,7 +481,7 @@ def __init__(
            desc,
        )

-    def tune(self, tuning_options, search_policy=None):
+    def tune(self, tuning_options, search_policy=None, adaptive_training=False):
        """Run auto scheduling search for a task
        Parameters
@@ -492,7 +492,7 @@ def tune(self, tuning_options, search_policy=None):
            The search policy to be used for schedule search.
        """
        if search_policy is None:
-            cost_model = XGBModel()
+            cost_model = XGBModel(adaptive_training=adaptive_training)
            search_policy = SketchPolicy(self, cost_model)

        _ffi_api.AutoSchedule(search_policy, tuning_options)
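Downstream, a caller can now opt in through SearchTask.tune. A hedged usage sketch (the workload function my_workload is hypothetical; SearchTask and TuningOptions are existing auto_scheduler APIs):

from tvm import auto_scheduler

# `my_workload` stands in for any function registered with
# @auto_scheduler.register_workload; note tune() itself defaults to False.
task = auto_scheduler.SearchTask(func=my_workload, args=(1024,), target="llvm")
task.tune(
    auto_scheduler.TuningOptions(num_measure_trials=200),
    adaptive_training=True,  # forwarded into the XGBModel cost model
)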
14 changes: 11 additions & 3 deletions python/tvm/auto_scheduler/testing/tune_onnx.py
@@ -99,7 +99,14 @@ def _parse_args():
        "--cpu-flush",
        type=lambda x: bool(strtobool(x)),
        required=True,
-        help="example: `True / False",
+        help="example: True / False",
    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
+    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
@@ -108,7 +115,7 @@ def _parse_args():
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
    )
    return parsed

@@ -179,7 +186,8 @@ def main():
            measure_callbacks=[
                auto_scheduler.RecordToFile(log_file),
            ],
-        )
+        ),
+        adaptive_training=ARGS.adaptive_training,
    )

    with auto_scheduler.ApplyHistoryBest(log_file):
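Since the same --adaptive-training flag is added to every testing script in this PR, one parsing example covers them all. A self-contained sketch of the convention used above (strtobool accepts "True"/"False", "yes"/"no", "1"/"0"; the flag defaults to True when omitted):

import argparse
from distutils.util import strtobool

args = argparse.ArgumentParser()
args.add_argument(
    "--adaptive-training",
    type=lambda x: bool(strtobool(x)),  # parse "True"/"False" strings to bool
    required=False,
    help="example: True / False",
    default=True,
)
print(args.parse_args([]).adaptive_training)  # True (default)
print(args.parse_args(["--adaptive-training", "False"]).adaptive_training)  # False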
14 changes: 11 additions & 3 deletions python/tvm/auto_scheduler/testing/tune_relay.py
@@ -97,7 +97,14 @@ def _parse_args():
        "--cpu-flush",
        type=lambda x: bool(strtobool(x)),
        required=True,
-        help="example: `True / False",
+        help="example: True / False",
    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
+    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
@@ -106,7 +113,7 @@ def _parse_args():
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
    )
    return parsed

@@ -180,7 +187,8 @@ def main():
            measure_callbacks=[
                auto_scheduler.RecordToFile(log_file),
            ],
-        )
+        ),
+        adaptive_training=ARGS.adaptive_training,
    )

    with auto_scheduler.ApplyHistoryBest(log_file):
12 changes: 10 additions & 2 deletions python/tvm/auto_scheduler/testing/tune_te.py
@@ -82,7 +82,14 @@ def _parse_args():
        "--cpu-flush",
        type=lambda x: bool(strtobool(x)),
        required=True,
-        help="example: `True / False",
+        help="example: True / False",
    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
+    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
@@ -135,6 +142,7 @@ def main():
        repeat=ARGS.repeat,
        min_repeat_ms=ARGS.min_repeat_ms,
        enable_cpu_cache_flush=ARGS.cpu_flush,
+        # TODO(zxybazh): set session timeout to 60, same as MetaSchedule
    )

    # Inspect the computational graph
@@ -147,7 +155,7 @@ def main():
        runner=runner,
    )
    print("Running AutoTuning:")
-    task.tune(tune_option)
+    task.tune(tune_option, adaptive_training=ARGS.adaptive_training)
    print("History Best:")
    print(task.print_best(log_file))
    sch, args = task.apply_best(log_file)
18 changes: 18 additions & 0 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -298,6 +298,8 @@ class XGBModel(PyCostModel):
        The verbose level when doing evaluation.
    average_peak_n : int
        The number to calculate average peak score.
+    adaptive_training : bool
+        Whether to use adaptive training to reduce tuning time.
    """

    # feature extractor
@@ -314,6 +316,9 @@ class XGBModel(PyCostModel):
    data: Dict[str, FeatureGroup]
    data_size: int
    booster: Optional["xgb.Booster"]
+    # adaptive training
+    adaptive_training: bool
+    last_train_size: int

    def __init__(
        self,
@@ -328,6 +333,7 @@ def __init__(
        early_stopping_rounds: int = 50,
        verbose_eval: int = 25,
        average_peak_n: int = 32,
+        adaptive_training: bool = True,
    ):
        super().__init__()
        # feature extractor
@@ -347,6 +353,9 @@ def __init__(
        self.data = OrderedDict()
        self.data_size = 0
        self.booster = None
+        # adaptive training
+        self.adaptive_training = adaptive_training
+        self.last_train_size = 0

    def load(self, path: str) -> None:
        """Load the cost model from given file location.
@@ -491,6 +500,15 @@ def _mean_cost(x: RunnerResult) -> float:
        self.data[new_group_hash] = group
        self.data_size += len(new_features)

+        if (
+            self.adaptive_training
+            and self.data_size - self.last_train_size < self.last_train_size / 5
+        ):
+            # Set a training threshold based on `last_train_size` to reduce the
+            # training overhead when there are too many results
+            return
+        self.last_train_size = self.data_size
+
        # Step 5. Re-train the model
        self._train(
            xs=list(itertools_chain.from_iterable([g.features for g in self.data.values()])),
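A worked example of the threshold in the update step above: with last_train_size = 1000, at least 200 new samples (1000 / 5) must accumulate before the model is retrained again.

# Worked example of the adaptive check, standalone (not the class itself).
last_train_size = 1000
for data_size in (1050, 1199, 1200, 1500):
    skip = data_size - last_train_size < last_train_size / 5
    print(data_size, "skip" if skip else "retrain")
# 1050 skip, 1199 skip, 1200 retrain, 1500 retrain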
6 changes: 5 additions & 1 deletion python/tvm/meta_schedule/default_config.py
@@ -141,10 +141,14 @@ def callbacks(  # pylint: disable=redefined-outer-name

def cost_model(
    cost_model: Optional[CostModel],  # pylint: disable=redefined-outer-name
+    adaptive_training: Optional[bool],
) -> CostModel:
    """Normalize the input to tvm.meta_schedule.CostModel"""
    if cost_model is None:
-        return XGBModel(extractor=PerStoreFeature())  # type: ignore
+        return XGBModel(  # type: ignore
+            extractor=PerStoreFeature(),
+            adaptive_training=adaptive_training is None or adaptive_training,
+        )
    if not isinstance(cost_model, CostModel):
        raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}")
    return cost_model
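The expression `adaptive_training is None or adaptive_training` makes adaptive training the default: it yields True when the caller passes None (unset) or True, and False only when the caller explicitly passes False, as a quick check shows:

# Truth table for the normalization used above.
for flag in (None, True, False):
    print(flag, flag is None or flag)
# None True / True True / False False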
12 changes: 10 additions & 2 deletions python/tvm/meta_schedule/testing/tune_onnx.py
@@ -96,7 +96,14 @@ def _parse_args():
        "--cpu-flush",
        type=lambda x: bool(strtobool(x)),
        required=True,
-        help="example: `True / False",
+        help="example: True / False",
    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
+    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
@@ -105,7 +112,7 @@ def _parse_args():
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
    )
    return parsed

@@ -147,6 +154,7 @@ def main():
            num_trials_per_iter=64,
            max_trials_per_task=ARGS.num_trials,
            max_trials_global=ARGS.num_trials,
+            adaptive_training=ARGS.adaptive_training,
        ),
        runner=runner,  # type: ignore
        work_dir=ARGS.work_dir,
12 changes: 10 additions & 2 deletions python/tvm/meta_schedule/testing/tune_relay.py
@@ -94,7 +94,14 @@ def _parse_args():
        "--cpu-flush",
        type=lambda x: bool(strtobool(x)),
        required=True,
-        help="example: `True / False",
+        help="example: True / False",
    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
+    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
@@ -103,7 +110,7 @@ def _parse_args():
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
-        session_timeout_sec=3600,
+        session_timeout_sec=600,
    )
    return parsed

@@ -148,6 +155,7 @@ def main():
            num_trials_per_iter=64,
            max_trials_per_task=ARGS.num_trials,
            max_trials_global=ARGS.num_trials,
+            adaptive_training=ARGS.adaptive_training,
        ),
        runner=runner,  # type: ignore
        work_dir=ARGS.work_dir,
10 changes: 9 additions & 1 deletion python/tvm/meta_schedule/testing/tune_te.py
@@ -83,7 +83,14 @@ def _parse_args():
        "--cpu-flush",
        type=lambda x: bool(strtobool(x)),
        required=True,
-        help="example: `True / False",
+        help="example: True / False",
    )
+    args.add_argument(
+        "--adaptive-training",
+        type=lambda x: bool(strtobool(x)),
+        required=False,
+        help="example: True / False",
+        default=True,
+    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
@@ -125,6 +132,7 @@ def main():
            num_trials_per_iter=64,
            max_trials_per_task=ARGS.num_trials,
            max_trials_global=ARGS.num_trials,
+            adaptive_training=ARGS.adaptive_training,
        ),
        runner=runner,  # type: ignore
        task_name=ARGS.workload,
5 changes: 4 additions & 1 deletion python/tvm/meta_schedule/tune.py
@@ -78,6 +78,8 @@ class TuneConfig(NamedTuple):
        Configuration for search strategy.
    logger_config: Optional[Dict[str, Any]] = None
        Configuration for logger.
+    adaptive_training: Optional[bool] = None
+        Whether adaptive training is enabled for the cost model.
    """

    max_trials_global: int
@@ -88,6 +90,7 @@ class TuneConfig(NamedTuple):
    task_scheduler_config: Optional[Dict[str, Any]] = None
    search_strategy_config: Optional[Dict[str, Any]] = None
    logger_config: Optional[Dict[str, Any]] = None
+    adaptive_training: Optional[bool] = None

    def create_strategy(self):
        """Create search strategy from configuration"""
@@ -310,7 +313,7 @@ def tune_extracted_tasks(
    database = default_config.database(database, work_dir)
    builder = default_config.builder(builder)
    runner = default_config.runner(runner)
-    cost_model = default_config.cost_model(cost_model)
+    cost_model = default_config.cost_model(cost_model, config.adaptive_training)
    measure_callbacks = default_config.callbacks(measure_callbacks)
    # parse the tuning contexts
    tune_contexts = []
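Putting it together, a hedged usage sketch (field names as introduced in this commit; the trial counts are illustrative): leaving adaptive_training unset (None) behaves like True after default_config.cost_model normalizes it, so only an explicit False disables the feature.

from tvm.meta_schedule import TuneConfig

config = TuneConfig(
    max_trials_global=2048,
    num_trials_per_iter=64,
    adaptive_training=False,  # explicitly retrain the cost model on every update
)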
