From c8423a6843edec5e85003a33d260f2214fd16c42 Mon Sep 17 00:00:00 2001
From: Yuanjing Shi
Date: Sun, 25 Sep 2022 22:50:20 -0700
Subject: [PATCH] [Meta Schedule][XGBoost] Update the custom callback function
 of xgboost in meta schedule (#12141)

* update the custom callback function of xgboost
* fix lint
* fix ci
* fix lint
* add unit test
* remove unused code
* fix lint
* add decorator
* address comment
* fix lint
* address comments
* fix mypy
* fix lint
* remove unused comments
* address comments
* Fix xgboost unit test import.

Co-authored-by: Xiyou Zhou
---
 .../tvm/meta_schedule/cost_model/xgb_model.py | 169 +++++++++++-------
 .../unittest/test_meta_schedule_cost_model.py |  85 +++++++++
 2 files changed, 194 insertions(+), 60 deletions(-)

diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py
index 8de034758b4b..1171e081b90a 100644
--- a/python/tvm/meta_schedule/cost_model/xgb_model.py
+++ b/python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -35,7 +35,26 @@
 from ..utils import cpu_count, derived_object, shash2hex
 from .metric import max_curve
 
+
+def optional_xgboost_callback(cls):
+    """Decorator for importing TrainingCallback from xgboost"""
+    # pylint:disable = import-outside-toplevel
+    try:
+        from xgboost.callback import TrainingCallback  # type: ignore
+    # pylint:enable = import-outside-toplevel
+    except ImportError:
+
+        class TrainingCallback:  # type: ignore
+            pass
+
+    class OptXGBoostCustomCallback(cls, TrainingCallback):  # type: ignore
+        pass
+
+    return OptXGBoostCustomCallback
+
+
 if TYPE_CHECKING:
+    import xgboost as xgb  # type: ignore
+
     from ..tune_context import TuneContext
 
 
@@ -579,14 +598,12 @@ def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # type: ignore
             num_boost_round=10000,
             obj=obj,
             callbacks=[
-                custom_callback(
+                XGBoostCustomCallback(
                     early_stopping_rounds=self.early_stopping_rounds,
                     verbose_eval=self.verbose_eval,
-                    fevals=[
-                        rmse,
-                        avg_peak_score,
-                    ],
+                    fevals=[rmse, avg_peak_score],
                     evals=[(self.d_train.dmatrix, "tr")],
+                    cvfolds=None,
                 )
             ],
         )
@@ -640,52 +657,83 @@ def average_peak_score(ys_pred: np.ndarray):
         return eval_result
 
 
-def custom_callback(
-    early_stopping_rounds: int,
-    verbose_eval: int,
-    fevals: List[Callable],
-    evals: List[Tuple["xgb.DMatrix", str]],
-    focused_metric: str = "tr-p-rmse",
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    sort_key = make_metric_sorter(focused_metric=focused_metric)
-
-    state: Dict[str, Any] = {}
-
-    def init(env: "xgb.core.CallbackEnv"):
-        """Internal function"""
-        booster: "xgb.Booster" = env.model
+@optional_xgboost_callback
+class XGBoostCustomCallback:
+    """Custom callback class for xgboost to support multiple custom evaluation functions"""
 
-        state["best_iteration"] = 0
-        state["best_score"] = float("inf")
+    def __init__(
+        self,
+        early_stopping_rounds: int,
+        verbose_eval: int,
+        fevals: List[Callable],
+        evals: List[Tuple["xgb.DMatrix", str]],
+        focused_metric: str = "tr-p-rmse",
+        cvfolds: List["xgb.training.CVPack"] = None,
+    ):
+        self.early_stopping_rounds = early_stopping_rounds
+        self.verbose_eval = verbose_eval
+        self.fevals = fevals
+        self.evals = evals
+        self.state: Dict[str, Any] = {}
+        self.focused_metric = focused_metric
+        self.sort_key = make_metric_sorter(focused_metric=focused_metric)
+        self.cvfolds = cvfolds
+        if cvfolds is not None:
+            self.aggregated_cv = None
+
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
+
+    def init(self, model: "xgb.Booster"):
+        """Internal function for initialization"""
+        booster: "xgb.Booster" = model
+        self.state["best_iteration"] = 0
+        self.state["best_score"] = float("inf")
         if booster is None:
-            assert env.cvfolds is not None
+            assert self.cvfolds is not None
             return
         if booster.attr("best_score") is not None:
-            state["best_score"] = float(booster.attr("best_score"))
-            state["best_iteration"] = int(booster.attr("best_iteration"))
-            state["best_msg"] = booster.attr("best_msg")
+            self.state["best_score"] = float(booster.attr("best_score"))
+            self.state["best_iteration"] = int(booster.attr("best_iteration"))
+            self.state["best_msg"] = booster.attr("best_msg")
         else:
-            booster.set_attr(best_iteration=str(state["best_iteration"]))
-            booster.set_attr(best_score=str(state["best_score"]))
+            booster.set_attr(best_iteration=str(self.state["best_iteration"]))
+            booster.set_attr(best_score=str(self.state["best_score"]))
 
-    def callback(env: "xgb.core.CallbackEnv"):
+    def after_iteration(
+        self, model: "xgb.Booster", epoch: int, evals_log: Dict
+    ):  # pylint: disable = unused-argument
+        """Internal function for after_iteration"""
         # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+
+            def _fmt_metric(value, show_stdv=True):
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
         import xgboost as xgb
-        from xgboost.callback import _fmt_metric  # type: ignore
-        from xgboost.core import EarlyStopException  # type: ignore
+        from xgboost import rabit  # type: ignore
 
         try:
             from xgboost.training import aggcv  # type: ignore
         except ImportError:
             from xgboost.callback import _aggcv as aggcv  # type: ignore
 
-        # pylint:enable = import-outside-toplevel
-        if not state:
-            init(env)
-        booster: xgb.Booster = env.model
-        iteration: int = env.iteration
-        cvfolds: List[xgb.training.CVPack] = env.cvfolds
+        # pylint:enable = import-outside-toplevel
+        if not self.state:
+            self.init(model)
+        booster: xgb.Booster = model
+        iteration: int = epoch
+        cvfolds: List[xgb.training.CVPack] = self.cvfolds
         ##### Evaluation #####
         # `eval_result` is a list of (key, score)
         eval_result: List[Tuple[str, float]] = []
@@ -697,13 +745,13 @@ def callback(env: "xgb.core.CallbackEnv"):
                         for key, value in map(
                             lambda x: x.split(":"),
                             booster.eval_set(
-                                evals=evals,
+                                evals=self.evals,
                                 iteration=iteration,
                                 feval=feval,
                             ).split()[1:],
                         )
                     ]
-                    for feval in fevals
+                    for feval in self.fevals
                 )
             )
         else:
@@ -719,14 +767,14 @@ def callback(env: "xgb.core.CallbackEnv"):
                             for fold in cvfolds
                         )
                     ]
-                    for feval in fevals
+                    for feval in self.fevals
                 )
             )
         eval_result = list(eval_result)
-        eval_result.sort(key=sort_key)
+        eval_result.sort(key=self.sort_key)
         ##### Print eval result #####
-        if verbose_eval and iteration % verbose_eval == 0:
+        if self.verbose_eval and iteration % self.verbose_eval == 0:
             info = []
             for key, score in eval_result:
                 if "null" not in key:
@@ -736,30 +784,31 @@ def callback(env: "xgb.core.CallbackEnv"):
         ##### Choose score and do early stopping #####
         score = None
         for key, _score in eval_result:
-            if key == focused_metric:
+            if key == self.focused_metric:
                 score = _score
                 break
         assert score is not None
-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
state["best_iteration"] + best_score = self.state["best_score"] + best_iteration = self.state["best_iteration"] if score < best_score: tab = "\t" # to work with f-string - msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}" - state["best_msg"] = msg - state["best_score"] = score - state["best_iteration"] = env.iteration + msg = f"[{epoch}] {tab.join([_fmt_metric(x) for x in eval_result])}" + self.state["best_msg"] = msg + self.state["best_score"] = score + self.state["best_iteration"] = epoch # save the property to attributes, so they will occur in checkpoint. - if env.model is not None: - env.model.set_attr( - best_score=str(state["best_score"]), - best_iteration=str(state["best_iteration"]), - best_msg=state["best_msg"], + if model is not None: + model.set_attr( + best_score=str(self.state["best_score"]), + best_iteration=str(self.state["best_iteration"]), + best_msg=self.state["best_msg"], ) - elif env.iteration - best_iteration >= early_stopping_rounds: - best_msg = state["best_msg"] - if verbose_eval and env.rank == 0: - logger.debug("XGB stopped. Best iteration: %s ", best_msg) - raise EarlyStopException(best_iteration) + elif epoch - best_iteration >= self.early_stopping_rounds: + best_msg = self.state["best_msg"] - return callback + if self.verbose_eval and rabit.get_rank() == 0: + logger.debug("XGB stopped. Best iteration: %s ", best_msg) + return True # instead of raising EarlyStopException, returning True to end the training + # False to indicate training should not stop. + return False diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index d1d558181324..94b7bce246f4 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -27,6 +27,7 @@ import tvm import tvm.testing from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.cost_model.xgb_model import XGBoostCustomCallback, PackSum from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import MeasureCandidate @@ -228,5 +229,89 @@ def test_meta_schedule_xgb_model_reupdate(): model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) +def test_meta_schedule_xgb_model_callback(): + import xgboost as xgb + from itertools import chain as itertools_chain + from functools import partial + + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=10) + update_sample_count = 20 + predict_sample_count = 30 + + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + with tempfile.NamedTemporaryFile() as path: + # Backup and train on new TrainingCallBack api + random_state = model.extractor.random_state # save feature extractor's random state + + model.save(path.name) + + old_booster = model.booster + xs = [ + x.numpy().astype("float32") + for x in extractor.extract_from( + TuneContext(), + [_dummy_candidate() for i in range(predict_sample_count)], + ) + ] + d_test = PackSum(xs=xs, ys=None) + pred1 = old_booster.predict(d_test.dmatrix) + + # Load and train on deprecated TrainingCallBack api + model.extractor.random_state = random_state # load feature extractor's random 
+        model.load(path.name)
+        d_train = PackSum(
+            xs=list(itertools_chain.from_iterable([g.features for g in model.data.values()])),
+            ys=np.concatenate(
+                [g.min_cost / g.costs for g in model.data.values()],
+                axis=0,
+            ),
+        )
+
+        def obj(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
+            return d_train.obj_square_error(ys_pred)
+
+        def rmse(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
+            return d_train.rmse(ys_pred)
+
+        def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
+            return d_train.average_peak_score(ys_pred, model.average_peak_n)
+
+        new_booster = xgb.train(
+            model.config.to_dict(),
+            d_train.dmatrix,
+            num_boost_round=10000,
+            obj=obj,
+            callbacks=[
+                partial(
+                    XGBoostCustomCallback(
+                        early_stopping_rounds=model.early_stopping_rounds,
+                        verbose_eval=model.verbose_eval,
+                        fevals=[rmse, avg_peak_score],
+                        evals=[(d_train.dmatrix, "tr")],
+                        cvfolds=None,
+                    )
+                )
+            ],
+        )
+
+        xs = [
+            x.numpy().astype("float32")
+            for x in extractor.extract_from(
+                TuneContext(),
+                [_dummy_candidate() for i in range(predict_sample_count)],
+            )
+        ]
+        d_test = PackSum(xs=xs, ys=None)
+        pred2 = new_booster.predict(d_test.dmatrix)
+
+        assert np.allclose(pred1, pred2, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
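Usage sketch (illustrative, not part of the patch): the snippet below shows how the new class-based callback plugs into xgb.train, mirroring what XGBModel.train and the unit test above do. The synthetic data, the hyper-parameters, and the p_rmse helper are assumptions made for this example; the eval is named "p-rmse" and run on the "tr" set so that its reported key matches the callback's default focused_metric, "tr-p-rmse". An xgboost 1.x release is assumed, since the callback itself imports xgboost.rabit.

import numpy as np
import xgboost as xgb

from tvm.meta_schedule.cost_model.xgb_model import XGBoostCustomCallback

# Synthetic regression data; shapes and values are illustrative only.
rng = np.random.default_rng(0)
d_train = xgb.DMatrix(
    rng.random((128, 16)).astype("float32"),
    label=rng.random(128).astype("float32"),
)

def p_rmse(ys_pred: np.ndarray, dmatrix: xgb.DMatrix):
    # Custom eval function: the "tr" eval-set tag plus "p-rmse" yields the
    # key "tr-p-rmse", which the callback watches for early stopping.
    labels = dmatrix.get_label()
    return "p-rmse", float(np.sqrt(np.mean((ys_pred - labels) ** 2)))

booster = xgb.train(
    {"max_depth": 6, "eta": 0.3},  # illustrative hyper-parameters
    d_train,
    num_boost_round=1000,
    callbacks=[
        # On xgboost >= 1.3 this object is driven through after_iteration();
        # on older releases it is invoked via __call__(env).
        XGBoostCustomCallback(
            early_stopping_rounds=10,
            verbose_eval=25,
            fevals=[p_rmse],
            evals=[(d_train, "tr")],
            cvfolds=None,
        )
    ],
)

Training stops once "tr-p-rmse" has not improved for early_stopping_rounds iterations: after_iteration returns True rather than raising EarlyStopException, which is exactly the behavioral change this patch makes for the new callback API.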