Skip to content

Commit

Permalink
[Meta Schedule][XGBoost] enable custom callback func test with xgboos…
Browse files Browse the repository at this point in the history
…t>=1.6.0 (#17168)

enable callback func test with xgboost>=1.6.0
  • Loading branch information
mshr-h authored Jul 18, 2024
1 parent 22a8978 commit 70d86e3
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions tests/python/meta_schedule/test_meta_schedule_cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,6 @@ def test_meta_schedule_xgb_model_reupdate():
model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])


def xgb_version_check():

# pylint: disable=import-outside-toplevel
import xgboost as xgb
from packaging import version

# pylint: enable=import-outside-toplevel
return version.parse(xgb.__version__) >= version.parse("1.6.0")


@unittest.skipIf(xgb_version_check(), "test not supported for xgboost version after 1.6.0")
def test_meta_schedule_xgb_model_callback_as_function():
# pylint: disable=import-outside-toplevel
from itertools import chain as itertools_chain
Expand Down Expand Up @@ -330,14 +319,12 @@ def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignor
num_boost_round=10000,
obj=obj,
callbacks=[
partial(
_get_custom_call_back(
early_stopping_rounds=model.early_stopping_rounds,
verbose_eval=model.verbose_eval,
fevals=[rmse, avg_peak_score],
evals=[(d_train.dmatrix, "tr")],
cvfolds=None,
)
_get_custom_call_back(
early_stopping_rounds=model.early_stopping_rounds,
verbose_eval=model.verbose_eval,
fevals=[rmse, avg_peak_score],
evals=[(d_train.dmatrix, "tr")],
cvfolds=None,
)
],
)
Expand Down

0 comments on commit 70d86e3

Please sign in to comment.