Skip to content

Commit

Permalink
[MetaSchedule] Fix XGBoost Import Issue (#12936)
Browse files Browse the repository at this point in the history
Previous upgrade introduced a import of xgboost in meta_schedule, removed in current version by using a function to return the call back class.

We've recently introduced a XGBoost Model upgrade to support new xgboost version of callback class in #12141. However, in this PR it uses a function called `optional_xgboost_callback` that works to avoid compatibility issue (xgboost 1.5.2 v.s. 1.6.0). In this specific function, it tries to import the newly introduced xgboost callback class and create a new class using it as base class. This actually imported xgboost when meta_schedule is imported, which is not ideal because xgboost is not a dependency of tvm and meta_schedule, it should only be required when xgboost cost model is employed. This PR fixes the problem by moving the class and the function mentioned above under a function that returns this class when needed. In this way we avoided unwanted import of xgboost in meta_schedule.
  • Loading branch information
zxybazh authored Sep 30, 2022
1 parent 77c8b6e commit 4e4089e
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 179 deletions.
348 changes: 184 additions & 164 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import tempfile
from collections import OrderedDict
from itertools import chain as itertools_chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Callable

import numpy as np # type: ignore

Expand All @@ -36,26 +36,10 @@
from .metric import max_curve


def optional_xgboost_callback(cls):
"""Decorator for importing TraningCallback from xgboost"""
# pylint:disable = import-outside-toplevel
try:
from xgboost.callback import TrainingCallback # type: ignore
# pylint:enable = import-outside-toplevel
except ImportError:

class TrainingCallback: # type: ignore
pass

class OptXGBoostCustomCallback(cls, TrainingCallback): # type: ignore
pass

return OptXGBoostCustomCallback


if TYPE_CHECKING:

import xgboost as xgb # type: ignore
from xgboost.callback import TrainingCallback # type: ignore

from ..tune_context import TuneContext

Expand Down Expand Up @@ -346,7 +330,7 @@ def __init__(
extractor: FeatureExtractor,
# xgboost model config
config: XGBConfig = XGBConfig(),
# behavior of randomness
# random result before enough samples
num_warmup_samples: int = 100,
# evaluation
early_stopping_rounds: int = 50,
Expand Down Expand Up @@ -598,7 +582,7 @@ def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore
num_boost_round=10000,
obj=obj,
callbacks=[
XGBoostCustomCallback(
_get_custom_call_back(
early_stopping_rounds=self.early_stopping_rounds,
verbose_eval=self.verbose_eval,
fevals=[rmse, avg_peak_score],
Expand Down Expand Up @@ -657,158 +641,194 @@ def average_peak_score(ys_pred: np.ndarray):
return eval_result


@optional_xgboost_callback
class XGBoostCustomCallback:
"""Custom callback class for xgboost to support multiple custom evaluation functions"""

def __init__(
self,
early_stopping_rounds: int,
verbose_eval: int,
fevals: List[Callable],
evals: List[Tuple["xgb.DMatrix", str]],
focused_metric: str = "tr-p-rmse",
cvfolds: List["xgb.training.CVPack"] = None,
):
self.early_stopping_rounds = early_stopping_rounds
self.verbose_eval = verbose_eval
self.fevals = fevals
self.evals = evals
self.state: Dict[str, Any] = {}
self.focused_metric = focused_metric
self.sort_key = make_metric_sorter(focused_metric=focused_metric)
self.cvfolds = cvfolds
if cvfolds is not None:
self.aggregated_cv = None

def __call__(self, env: "xgb.core.CallbackEnv"):
# Compatibility with xgboost < 1.3
return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)

def init(self, model: "xgb.Booster"):
"""Internal function for intialization"""
booster: "xgb.Booster" = model
self.state["best_iteration"] = 0
self.state["best_score"] = float("inf")
if booster is None:
assert self.cvfolds is not None
return
if booster.attr("best_score") is not None:
self.state["best_score"] = float(booster.attr("best_score"))
self.state["best_iteration"] = int(booster.attr("best_iteration"))
self.state["best_msg"] = booster.attr("best_msg")
else:
booster.set_attr(best_iteration=str(self.state["best_iteration"]))
booster.set_attr(best_score=str(self.state["best_score"]))
def _get_custom_call_back(
early_stopping_rounds: int,
verbose_eval: int,
fevals: List[Callable],
evals: List[Tuple["xgb.DMatrix", str]],
focused_metric: str = "tr-p-rmse",
cvfolds: List["xgb.training.CVPack"] = None,
) -> "TrainingCallback":
"""Get a customized callback function for XGBoost. Work around xgboost import."""

def after_iteration(
self, model: "xgb.Booster", epoch: int, evals_log: Dict
): # pylint: disable = unused-argument
"""Internal function for after_iteration"""
def optional_xgboost_callback(cls):
"""Decorator for importing TraningCallback from xgboost"""
# pylint:disable = import-outside-toplevel
try:
from xgboost.callback import _fmt_metric # type: ignore
from xgboost.callback import TrainingCallback # type: ignore
# pylint:enable = import-outside-toplevel
except ImportError:
# Compatibility with xgboost >= 1.6

def _fmt_metric(value, show_stdv=True):
if len(value) == 2:
return f"{value[0]}:{value[1]:.5f}"
if len(value) == 3:
if show_stdv:
return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
return f"{value[0]}:{value[1]:.5f}"
raise ValueError("wrong metric value", value)
class TrainingCallback: # type: ignore
pass

import xgboost as xgb
from xgboost import rabit # type: ignore
class OptXGBoostCustomCallback(cls, TrainingCallback): # type: ignore
pass

try:
from xgboost.training import aggcv # type: ignore
except ImportError:
from xgboost.callback import _aggcv as aggcv # type: ignore
return OptXGBoostCustomCallback

# pylint:enable = import-outside-toplevel
if not self.state:
self.init(model)
booster: xgb.Booster = model
iteration: int = epoch
cvfolds: List[xgb.training.CVPack] = self.cvfolds
##### Evaluation #####
# `eval_result` is a list of (key, score)
eval_result: List[Tuple[str, float]] = []
if cvfolds is None:
eval_result = list(
itertools_chain.from_iterable(
[
(key, float(value))
for key, value in map(
lambda x: x.split(":"),
booster.eval_set(
evals=self.evals,
iteration=iteration,
feval=feval,
).split()[1:],
)
]
for feval in self.fevals
)
)
else:
eval_result = list(
itertools_chain.from_iterable(
[
(key, score)
for key, score, _std in aggcv(
fold.eval(
iteration=iteration,
feval=feval,
@optional_xgboost_callback
class XGBoostCustomCallback:
"""Custom callback class for xgboost to support multiple custom evaluation functions"""

def __init__(
self,
early_stopping_rounds: int,
verbose_eval: int,
fevals: List[Callable],
evals: List[Tuple["xgb.DMatrix", str]],
focused_metric: str = "tr-p-rmse",
cvfolds: List["xgb.training.CVPack"] = None,
):
self.early_stopping_rounds = early_stopping_rounds
self.verbose_eval = verbose_eval
self.fevals = fevals
self.evals = evals
self.state: Dict[str, Any] = {}
self.focused_metric = focused_metric
self.sort_key = make_metric_sorter(focused_metric=focused_metric)
self.cvfolds = cvfolds
if cvfolds is not None:
self.aggregated_cv = None

def __call__(self, env: "xgb.core.CallbackEnv"):
# Compatibility with xgboost < 1.3
return self.after_iteration(env.model, env.iteration, env.evaluation_result_list)

def init(self, model: "xgb.Booster"):
"""Internal function for intialization"""
booster: "xgb.Booster" = model
self.state["best_iteration"] = 0
self.state["best_score"] = float("inf")
if booster is None:
assert self.cvfolds is not None
return
if booster.attr("best_score") is not None:
self.state["best_score"] = float(booster.attr("best_score"))
self.state["best_iteration"] = int(booster.attr("best_iteration"))
self.state["best_msg"] = booster.attr("best_msg")
else:
booster.set_attr(best_iteration=str(self.state["best_iteration"]))
booster.set_attr(best_score=str(self.state["best_score"]))

def after_iteration(
self, model: "xgb.Booster", epoch: int, evals_log: Dict
): # pylint: disable = unused-argument
"""Internal function for after_iteration"""
# pylint:disable = import-outside-toplevel
try:
from xgboost.callback import _fmt_metric # type: ignore
except ImportError:
# Compatibility with xgboost >= 1.6

def _fmt_metric(value, show_stdv=True):
if len(value) == 2:
return f"{value[0]}:{value[1]:.5f}"
if len(value) == 3:
if show_stdv:
return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
return f"{value[0]}:{value[1]:.5f}"
raise ValueError("wrong metric value", value)

import xgboost as xgb
from xgboost import rabit # type: ignore

try:
from xgboost.training import aggcv # type: ignore
except ImportError:
from xgboost.callback import _aggcv as aggcv # type: ignore

# pylint:enable = import-outside-toplevel
if not self.state:
self.init(model)
booster: xgb.Booster = model
iteration: int = epoch
cvfolds: List[xgb.training.CVPack] = self.cvfolds
##### Evaluation #####
# `eval_result` is a list of (key, score)
eval_result: List[Tuple[str, float]] = []
if cvfolds is None:
eval_result = list(
itertools_chain.from_iterable(
[
(key, float(value))
for key, value in map(
lambda x: x.split(":"),
booster.eval_set(
evals=self.evals,
iteration=iteration,
feval=feval,
).split()[1:],
)
for fold in cvfolds
)
]
for feval in self.fevals
]
for feval in self.fevals
)
)
)
eval_result = list(eval_result)
eval_result.sort(key=self.sort_key)

##### Print eval result #####
if self.verbose_eval and iteration % self.verbose_eval == 0:
info = []
for key, score in eval_result:
if "null" not in key:
info.append(f"{key}: {score:.6f}")
logger.debug("XGB iter %3d: %s", iteration, "\t".join(info))

##### Choose score and do early stopping #####
score = None
for key, _score in eval_result:
if key == self.focused_metric:
score = _score
break
assert score is not None

best_score = self.state["best_score"]
best_iteration = self.state["best_iteration"]
if score < best_score:
tab = "\t" # to work with f-string
msg = f"[{epoch}] {tab.join([_fmt_metric(x) for x in eval_result])}"
self.state["best_msg"] = msg
self.state["best_score"] = score
self.state["best_iteration"] = epoch
# save the property to attributes, so they will occur in checkpoint.
if model is not None:
model.set_attr(
best_score=str(self.state["best_score"]),
best_iteration=str(self.state["best_iteration"]),
best_msg=self.state["best_msg"],
else:
eval_result = list(
itertools_chain.from_iterable(
[
(key, score)
for key, score, _std in aggcv(
fold.eval(
iteration=iteration,
feval=feval,
)
for fold in cvfolds
)
]
for feval in self.fevals
)
)
elif epoch - best_iteration >= self.early_stopping_rounds:
best_msg = self.state["best_msg"]

if self.verbose_eval and rabit.get_rank() == 0:
logger.debug("XGB stopped. Best iteration: %s ", best_msg)
return True # instead of raising EarlyStopException, returning True to end the training
# False to indicate training should not stop.
return False
eval_result = list(eval_result)
eval_result.sort(key=self.sort_key)

##### Print eval result #####
if self.verbose_eval and iteration % self.verbose_eval == 0:
info = []
for key, score in eval_result:
if "null" not in key:
info.append(f"{key}: {score:.6f}")
logger.debug("XGB iter %3d: %s", iteration, "\t".join(info))

##### Choose score and do early stopping #####
score = None
for key, _score in eval_result:
if key == self.focused_metric:
score = _score
break
assert score is not None

best_score = self.state["best_score"]
best_iteration = self.state["best_iteration"]
if score < best_score:
tab = "\t" # to work with f-string
msg = f"[{epoch}] {tab.join([_fmt_metric(x) for x in eval_result])}"
self.state["best_msg"] = msg
self.state["best_score"] = score
self.state["best_iteration"] = epoch
# save the property to attributes, so they will occur in checkpoint.
if model is not None:
model.set_attr(
best_score=str(self.state["best_score"]),
best_iteration=str(self.state["best_iteration"]),
best_msg=self.state["best_msg"],
)
elif epoch - best_iteration >= self.early_stopping_rounds:
best_msg = self.state["best_msg"]

if self.verbose_eval and rabit.get_rank() == 0:
logger.debug("XGB stopped. Best iteration: %s ", best_msg)
# instead of raising EarlyStopException, returning True to end the training
return True
# False to indicate training should not stop.
return False

return XGBoostCustomCallback(
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
fevals=fevals,
evals=evals,
focused_metric=focused_metric,
cvfolds=cvfolds,
)
Loading

0 comments on commit 4e4089e

Please sign in to comment.