[Meta Schedule][XGBoost] Update the custom callback function of xgboost in meta schedule (#12141)

* update the custom callback function of xgboost

* fix lint

* fix ci

* fix lint

* add unit test

* remove unused code

* fix lint

* add decorator

* address comment

* fix lint

* address comments

* fix mypy

* fix lint

* remove unused comments

* address comments

* Fix xgboost unit test import.

Co-authored-by: Xiyou Zhou <[email protected]>
Yuanjing Shi and zxybazh authored Sep 26, 2022
1 parent a61c1ad commit c8423a6
Showing 2 changed files with 194 additions and 60 deletions.
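Background for the diff below: xgboost 1.3 replaced plain function callbacks, which receive a CallbackEnv namedtuple each round, with the TrainingCallback class API, and xgboost 1.6 removed private helpers (such as _fmt_metric) that the old TVM callback relied on. A minimal sketch of the two callback shapes, with illustrative names only (not part of this commit):

    import xgboost as xgb
    from xgboost.callback import TrainingCallback  # available since xgboost 1.3

    # Legacy style (xgboost < 1.3): a plain function called once per round with a
    # CallbackEnv carrying .model, .iteration, .evaluation_result_list, etc.
    def legacy_logger(env: "xgb.core.CallbackEnv"):
        print(f"round {env.iteration}: {env.evaluation_result_list}")

    # Modern style (xgboost >= 1.3): override hooks on a TrainingCallback subclass.
    class ModernLogger(TrainingCallback):
        def after_iteration(self, model, epoch, evals_log) -> bool:
            print(f"round {epoch}: {evals_log}")
            return False  # False means "continue training"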
169 changes: 109 additions & 60 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -35,7 +35,26 @@
 from ..utils import cpu_count, derived_object, shash2hex
 from .metric import max_curve


+def optional_xgboost_callback(cls):
+    """Decorator for importing TrainingCallback from xgboost"""
+    # pylint:disable = import-outside-toplevel
+    try:
+        from xgboost.callback import TrainingCallback  # type: ignore
+        # pylint:enable = import-outside-toplevel
+    except ImportError:
+
+        class TrainingCallback:  # type: ignore
+            pass
+
+    class OptXGBoostCustomCallback(cls, TrainingCallback):  # type: ignore
+        pass
+
+    return OptXGBoostCustomCallback
+
+
 if TYPE_CHECKING:

     import xgboost as xgb  # type: ignore

     from ..tune_context import TuneContext
@@ -579,14 +598,12 @@ def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"):  # type: ignore
             num_boost_round=10000,
             obj=obj,
             callbacks=[
-                custom_callback(
+                XGBoostCustomCallback(
                     early_stopping_rounds=self.early_stopping_rounds,
                     verbose_eval=self.verbose_eval,
-                    fevals=[
-                        rmse,
-                        avg_peak_score,
-                    ],
+                    fevals=[rmse, avg_peak_score],
                     evals=[(self.d_train.dmatrix, "tr")],
+                    cvfolds=None,
                 )
             ],
         )
@@ -640,52 +657,83 @@ def average_peak_score(ys_pred: np.ndarray):
         return eval_result


-def custom_callback(
-    early_stopping_rounds: int,
-    verbose_eval: int,
-    fevals: List[Callable],
-    evals: List[Tuple["xgb.DMatrix", str]],
-    focused_metric: str = "tr-p-rmse",
-):
-    """Callback function for xgboost to support multiple custom evaluation functions"""
-    sort_key = make_metric_sorter(focused_metric=focused_metric)
-
-    state: Dict[str, Any] = {}
-
-    def init(env: "xgb.core.CallbackEnv"):
-        """Internal function"""
-        booster: "xgb.Booster" = env.model
+@optional_xgboost_callback
+class XGBoostCustomCallback:
+    """Custom callback class for xgboost to support multiple custom evaluation functions"""

-        state["best_iteration"] = 0
-        state["best_score"] = float("inf")
+    def __init__(
+        self,
+        early_stopping_rounds: int,
+        verbose_eval: int,
+        fevals: List[Callable],
+        evals: List[Tuple["xgb.DMatrix", str]],
+        focused_metric: str = "tr-p-rmse",
+        cvfolds: List["xgb.training.CVPack"] = None,
+    ):
+        self.early_stopping_rounds = early_stopping_rounds
+        self.verbose_eval = verbose_eval
+        self.fevals = fevals
+        self.evals = evals
+        self.state: Dict[str, Any] = {}
+        self.focused_metric = focused_metric
+        self.sort_key = make_metric_sorter(focused_metric=focused_metric)
+        self.cvfolds = cvfolds
+        if cvfolds is not None:
+            self.aggregated_cv = None
+
+    def __call__(self, env: "xgb.core.CallbackEnv"):
+        # Compatibility with xgboost < 1.3
+        return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)
+
+    def init(self, model: "xgb.Booster"):
+        """Internal function for initialization"""
+        booster: "xgb.Booster" = model
+        self.state["best_iteration"] = 0
+        self.state["best_score"] = float("inf")
         if booster is None:
-            assert env.cvfolds is not None
+            assert self.cvfolds is not None
             return
         if booster.attr("best_score") is not None:
-            state["best_score"] = float(booster.attr("best_score"))
-            state["best_iteration"] = int(booster.attr("best_iteration"))
-            state["best_msg"] = booster.attr("best_msg")
+            self.state["best_score"] = float(booster.attr("best_score"))
+            self.state["best_iteration"] = int(booster.attr("best_iteration"))
+            self.state["best_msg"] = booster.attr("best_msg")
         else:
-            booster.set_attr(best_iteration=str(state["best_iteration"]))
-            booster.set_attr(best_score=str(state["best_score"]))
+            booster.set_attr(best_iteration=str(self.state["best_iteration"]))
+            booster.set_attr(best_score=str(self.state["best_score"]))

-    def callback(env: "xgb.core.CallbackEnv"):
+    def after_iteration(
+        self, model: "xgb.Booster", epoch: int, evals_log: Dict
+    ):  # pylint: disable = unused-argument
+        """Internal function for after_iteration"""
         # pylint:disable = import-outside-toplevel
+        try:
+            from xgboost.callback import _fmt_metric  # type: ignore
+        except ImportError:
+            # Compatibility with xgboost >= 1.6
+
+            def _fmt_metric(value, show_stdv=True):
+                if len(value) == 2:
+                    return f"{value[0]}:{value[1]:.5f}"
+                if len(value) == 3:
+                    if show_stdv:
+                        return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
+                    return f"{value[0]}:{value[1]:.5f}"
+                raise ValueError("wrong metric value", value)
+
         import xgboost as xgb
-        from xgboost.callback import _fmt_metric  # type: ignore
-        from xgboost.core import EarlyStopException  # type: ignore
+        from xgboost import rabit  # type: ignore

         try:
             from xgboost.training import aggcv  # type: ignore
         except ImportError:
             from xgboost.callback import _aggcv as aggcv  # type: ignore
-        # pylint:enable = import-outside-toplevel

-        if not state:
-            init(env)
-        booster: xgb.Booster = env.model
-        iteration: int = env.iteration
-        cvfolds: List[xgb.training.CVPack] = env.cvfolds
+        # pylint:enable = import-outside-toplevel
+        if not self.state:
+            self.init(model)
+        booster: xgb.Booster = model
+        iteration: int = epoch
+        cvfolds: List[xgb.training.CVPack] = self.cvfolds
         ##### Evaluation #####
         # `eval_result` is a list of (key, score)
         eval_result: List[Tuple[str, float]] = []
@@ -697,13 +745,13 @@ def callback(env: "xgb.core.CallbackEnv"):
                         for key, value in map(
                             lambda x: x.split(":"),
                             booster.eval_set(
-                                evals=evals,
+                                evals=self.evals,
                                 iteration=iteration,
                                 feval=feval,
                             ).split()[1:],
                         )
                     ]
-                    for feval in fevals
+                    for feval in self.fevals
                 )
             )
@@ -719,14 +767,14 @@ def callback(env: "xgb.core.CallbackEnv"):
                             for fold in cvfolds
                         )
                     ]
-                    for feval in fevals
+                    for feval in self.fevals
                 )
             )
         eval_result = list(eval_result)
-        eval_result.sort(key=sort_key)
+        eval_result.sort(key=self.sort_key)

         ##### Print eval result #####
-        if verbose_eval and iteration % verbose_eval == 0:
+        if self.verbose_eval and iteration % self.verbose_eval == 0:
             info = []
             for key, score in eval_result:
                 if "null" not in key:
@@ -736,30 +784,31 @@ def callback(env: "xgb.core.CallbackEnv"):
         ##### Choose score and do early stopping #####
         score = None
         for key, _score in eval_result:
-            if key == focused_metric:
+            if key == self.focused_metric:
                 score = _score
                 break
         assert score is not None

-        best_score = state["best_score"]
-        best_iteration = state["best_iteration"]
+        best_score = self.state["best_score"]
+        best_iteration = self.state["best_iteration"]
         if score < best_score:
             tab = "\t"  # to work with f-string
-            msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}"
-            state["best_msg"] = msg
-            state["best_score"] = score
-            state["best_iteration"] = env.iteration
+            msg = f"[{epoch}] {tab.join([_fmt_metric(x) for x in eval_result])}"
+            self.state["best_msg"] = msg
+            self.state["best_score"] = score
+            self.state["best_iteration"] = epoch
             # save the properties to attributes, so they will occur in the checkpoint
-            if env.model is not None:
-                env.model.set_attr(
-                    best_score=str(state["best_score"]),
-                    best_iteration=str(state["best_iteration"]),
-                    best_msg=state["best_msg"],
+            if model is not None:
+                model.set_attr(
+                    best_score=str(self.state["best_score"]),
+                    best_iteration=str(self.state["best_iteration"]),
+                    best_msg=self.state["best_msg"],
                 )
-        elif env.iteration - best_iteration >= early_stopping_rounds:
-            best_msg = state["best_msg"]
-            if verbose_eval and env.rank == 0:
-                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
-            raise EarlyStopException(best_iteration)
+        elif epoch - best_iteration >= self.early_stopping_rounds:
+            best_msg = self.state["best_msg"]

-    return callback
+            if self.verbose_eval and rabit.get_rank() == 0:
+                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
+            return True  # instead of raising EarlyStopException, return True to end training
+        # False to indicate training should not stop.
+        return False
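Taken together, the class above serves both callback generations: optional_xgboost_callback gives it the TrainingCallback base when xgboost is available, __call__ adapts the legacy CallbackEnv invocation to after_iteration, and early stopping is now signaled by returning True rather than raising EarlyStopException. A standalone sketch of that stopping contract, with illustrative names and data (assumes xgboost >= 1.4; not part of this commit):

    import numpy as np
    import xgboost as xgb
    from xgboost.callback import TrainingCallback

    class StopAfter(TrainingCallback):
        """Toy callback: end training once a fixed number of rounds has run."""

        def __init__(self, limit: int):
            self.limit = limit

        def after_iteration(self, model, epoch, evals_log) -> bool:
            return epoch + 1 >= self.limit  # True stops training, False continues

    dtrain = xgb.DMatrix(np.random.rand(64, 8), label=np.random.rand(64))
    booster = xgb.train({"objective": "reg:squarederror"}, dtrain,
                        num_boost_round=1000, callbacks=[StopAfter(5)])
    print(booster.num_boosted_rounds())  # 5, not 1000 (method added in xgboost 1.4)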
85 changes: 85 additions & 0 deletions tests/python/unittest/test_meta_schedule_cost_model.py
@@ -27,6 +27,7 @@
 import tvm
 import tvm.testing
 from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel
+from tvm.meta_schedule.cost_model.xgb_model import XGBoostCustomCallback, PackSum
 from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor
 from tvm.meta_schedule.runner import RunnerResult
 from tvm.meta_schedule.search_strategy import MeasureCandidate
@@ -228,5 +229,89 @@ def test_meta_schedule_xgb_model_reupdate():
     model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])


+def test_meta_schedule_xgb_model_callback():
+    import xgboost as xgb
+    from itertools import chain as itertools_chain
+    from functools import partial
+
+    extractor = RandomFeatureExtractor()
+    model = XGBModel(extractor=extractor, num_warmup_samples=10)
+    update_sample_count = 20
+    predict_sample_count = 30
+
+    model.update(
+        TuneContext(),
+        [_dummy_candidate() for i in range(update_sample_count)],
+        [_dummy_result() for i in range(update_sample_count)],
+    )
+    model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])
+    with tempfile.NamedTemporaryFile() as path:
+        # Backup and train on the new TrainingCallback API
+        random_state = model.extractor.random_state  # save feature extractor's random state
+
+        model.save(path.name)
+
+        old_booster = model.booster
+        xs = [
+            x.numpy().astype("float32")
+            for x in extractor.extract_from(
+                TuneContext(),
+                [_dummy_candidate() for i in range(predict_sample_count)],
+            )
+        ]
+        d_test = PackSum(xs=xs, ys=None)
+        pred1 = old_booster.predict(d_test.dmatrix)
+
+        # Load and train on the deprecated function-callback API
+        model.extractor.random_state = random_state  # load feature extractor's random state
+        model.load(path.name)
+        d_train = PackSum(
+            xs=list(itertools_chain.from_iterable([g.features for g in model.data.values()])),
+            ys=np.concatenate(
+                [g.min_cost / g.costs for g in model.data.values()],
+                axis=0,
+            ),
+        )
+
+        def obj(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
+            return d_train.obj_square_error(ys_pred)
+
+        def rmse(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
+            return d_train.rmse(ys_pred)
+
+        def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"):  # type: ignore # pylint: disable = unused-argument
+            return d_train.average_peak_score(ys_pred, model.average_peak_n)
+
+        new_booster = xgb.train(
+            model.config.to_dict(),
+            d_train.dmatrix,
+            num_boost_round=10000,
+            obj=obj,
+            callbacks=[
+                partial(  # hide the TrainingCallback base so xgboost takes the legacy path
+                    XGBoostCustomCallback(
+                        early_stopping_rounds=model.early_stopping_rounds,
+                        verbose_eval=model.verbose_eval,
+                        fevals=[rmse, avg_peak_score],
+                        evals=[(d_train.dmatrix, "tr")],
+                        cvfolds=None,
+                    )
+                )
+            ],
+        )
+
+        xs = [
+            x.numpy().astype("float32")
+            for x in extractor.extract_from(
+                TuneContext(),
+                [_dummy_candidate() for i in range(predict_sample_count)],
+            )
+        ]
+        d_test = PackSum(xs=xs, ys=None)
+        pred2 = new_booster.predict(d_test.dmatrix)
+
+        assert np.allclose(pred1, pred2, rtol=1e-3, atol=1e-3)


 if __name__ == "__main__":
     tvm.testing.main()
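A note on the test above: wrapping the callback in functools.partial hides its TrainingCallback base class, so xgboost's isinstance-based dispatch treats it as a deprecated function callback and invokes __call__ with a CallbackEnv, while xgb_model.py passes the instance directly and exercises the new hook API; the final assertion checks that both paths produce matching predictions. A sketch of the dispatch rule this relies on (illustrative class; assumes an xgboost 1.x release that still ships the legacy path):

    from functools import partial
    from xgboost.callback import TrainingCallback

    class MyCallback(TrainingCallback):
        def after_iteration(self, model, epoch, evals_log) -> bool:
            return False

    cb = MyCallback()
    assert isinstance(cb, TrainingCallback)               # dispatched via the new API
    assert not isinstance(partial(cb), TrainingCallback)  # falls back to the legacy path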
