Fixes early stopping with XGBoost 2.0 #597

Merged: 17 commits, Dec 16, 2023. Changes from all commits are shown below.
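For context, a minimal sketch of the scenario this PR addresses: training an XGBClassifier with early stopping and converting it with onnxmltools. The dataset, sizes, and variable names below are illustrative, not taken from the PR; with xgboost >= 2.0 the booster no longer exposes best_ntree_limit, which the converter previously relied on.

# Hypothetical reproduction sketch (not part of the PR's test suite).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from onnxmltools.convert import convert_xgboost
from onnxmltools.convert.common.data_types import FloatTensorType

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

# With xgboost >= 2.0, early_stopping_rounds goes on the constructor, not fit().
model = XGBClassifier(n_estimators=200, early_stopping_rounds=10)
model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])

# Before this fix the converter looked up booster.best_ntree_limit, which
# xgboost 2.0 removed, so all 200 trees ended up in the ONNX graph instead of
# only the trees kept by early stopping.
onx = convert_xgboost(
    model, initial_types=[("input", FloatTensorType([None, X.shape[1]]))]
)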
14 changes: 7 additions & 7 deletions .azure-pipelines/linux-conda-CI.yml
@@ -25,13 +25,13 @@ jobs:
       numpy.version: ''
       scipy.version: ''

-      Python311-1150-RT1160-xgb175-lgbm40:
+      Python311-1150-RT1163-xgb175-lgbm40:
         python.version: '3.11'
         ONNX_PATH: 'onnx==1.15.0'
-        ONNXRT_PATH: 'onnxruntime==1.16.2'
+        ONNXRT_PATH: 'onnxruntime==1.16.3'
         COREML_PATH: NONE
         lightgbm.version: '>=4.0'
-        xgboost.version: '==1.7.5'
+        xgboost.version: '>=1.7.5,<2'
         numpy.version: ''
         scipy.version: ''

@@ -41,7 +41,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.16.2'
         COREML_PATH: NONE
         lightgbm.version: '>=4.0'
-        xgboost.version: '==1.7.5'
+        xgboost.version: '>=1.7.5,<2'
         numpy.version: ''
         scipy.version: ''

@@ -51,7 +51,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.15.1'
         COREML_PATH: NONE
         lightgbm.version: '<4.0'
-        xgboost.version: '==1.7.5'
+        xgboost.version: '>=1.7.5,<2'
         numpy.version: ''
         scipy.version: ''

@@ -61,7 +61,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.14.0'
         COREML_PATH: NONE
         lightgbm.version: '<4.0'
-        xgboost.version: '==1.7.5'
+        xgboost.version: '>=1.7.5,<2'
         numpy.version: ''
         scipy.version: ''

@@ -71,7 +71,7 @@ jobs:
         ONNXRT_PATH: 'onnxruntime==1.15.1'
         COREML_PATH: NONE
         lightgbm.version: '>=4.0'
-        xgboost.version: '==1.7.5'
+        xgboost.version: '>=1.7.5,<2'
         numpy.version: ''
         scipy.version: '==1.8.0'
8 changes: 8 additions & 0 deletions .azure-pipelines/win32-conda-CI.yml
@@ -15,6 +15,14 @@ jobs:
   strategy:
     matrix:

+      Python311-1150-RT1163:
+        python.version: '3.11'
+        ONNX_PATH: 'onnx==1.15.0'
+        ONNXRT_PATH: 'onnxruntime==1.16.3'
+        COREML_PATH: NONE
+        numpy.version: ''
+        xgboost.version: '2.0.2'
+
       Python311-1150-RT1162:
         python.version: '3.11'
         ONNX_PATH: 'onnx==1.15.0'
2 changes: 2 additions & 0 deletions CHANGELOGS.md
@@ -2,6 +2,8 @@

 ## 1.12.0

+* Fix early stopping for XGBClassifier and xgboost > 2
+  [#597](https://github.com/onnx/onnxmltools/pull/597)
 * Fix discrepancies with XGBRegressor and xgboost > 2
   [#670](https://github.com/onnx/onnxmltools/pull/670)
 * Support count:poisson for XGBRegressor
8 changes: 8 additions & 0 deletions onnxmltools/convert/xgboost/common.py
@@ -32,6 +32,14 @@ def get_xgb_params(xgb_node):
         bs = float(config["learner"]["learner_model_param"]["base_score"])
         # xgboost >= 2.0
         params["base_score"] = bs
+
+    bst = xgb_node.get_booster()
+    if hasattr(bst, "best_ntree_limit"):
+        params["best_ntree_limit"] = bst.best_ntree_limit
+    if "gradient_booster" in config["learner"]:
+        gbp = config["learner"]["gradient_booster"]["gbtree_model_param"]
+        if "num_trees" in gbp:
+            params["best_ntree_limit"] = int(gbp["num_trees"])
     return params


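The fallback above pulls the tree count from the booster's saved configuration when best_ntree_limit is no longer available. A quick way to see where that value lives (a sketch, reusing the hypothetical fitted model from the example at the top of the page):

import json

# 'model' is the fitted XGBClassifier from the earlier sketch (illustrative).
config = json.loads(model.get_booster().save_config())
gbp = config["learner"]["gradient_booster"]["gbtree_model_param"]
print(gbp["num_trees"])  # total number of trees stored in the booster, serialized as a string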
25 changes: 17 additions & 8 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -40,14 +40,20 @@ def common_members(xgb_node, inputs):
         params = XGBConverter.get_xgb_params(xgb_node)
         objective = params["objective"]
         base_score = params["base_score"]
+        if hasattr(xgb_node, "best_ntree_limit"):
+            best_ntree_limit = xgb_node.best_ntree_limit
+        elif hasattr(xgb_node, "best_iteration"):
+            best_ntree_limit = xgb_node.best_iteration + 1
+        else:
+            best_ntree_limit = params.get("best_ntree_limit", None)
         if base_score is None:
             base_score = 0.5
         booster = xgb_node.get_booster()
         # The json format was available in October 2017.
         # XGBoost 0.7 was the first version released with it.
         js_tree_list = booster.get_dump(with_stats=True, dump_format="json")
         js_trees = [json.loads(s) for s in js_tree_list]
-        return objective, base_score, js_trees
+        return objective, base_score, js_trees, best_ntree_limit

     @staticmethod
     def _get_default_tree_attribute_pairs(is_classifier):
@@ -231,17 +237,17 @@ def _get_default_tree_attribute_pairs():
     def convert(scope, operator, container):
         xgb_node = operator.raw_operator
         inputs = operator.inputs
-        objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
+        objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
+            xgb_node, inputs
+        )

         if objective in ["reg:gamma", "reg:tweedie"]:
             raise RuntimeError("Objective '{}' not supported.".format(objective))

         attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
         attr_pairs["base_values"] = [base_score]

-        bst = xgb_node.get_booster()
-        best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees))
-        if best_ntree_limit < len(js_trees):
+        if best_ntree_limit and best_ntree_limit < len(js_trees):
             js_trees = js_trees[:best_ntree_limit]

         XGBConverter.fill_tree_attributes(
Expand Down Expand Up @@ -289,7 +295,9 @@ def convert(scope, operator, container):
xgb_node = operator.raw_operator
inputs = operator.inputs

objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
xgb_node, inputs
)

params = XGBConverter.get_xgb_params(xgb_node)
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
@@ -305,8 +313,9 @@ def convert(scope, operator, container):
         else:
             ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators

-        bst = xgb_node.get_booster()
-        best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
+        best_ntree_limit = best_ntree_limit or len(js_trees)
+        if ncl > 0:
+            best_ntree_limit *= ncl
         if 0 < best_ntree_limit < len(js_trees):
             js_trees = js_trees[:best_ntree_limit]
         attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
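A possible end-to-end check for the fix: run the converted model under onnxruntime and compare its labels against the fitted estimator, which by default predicts with the best iteration after early stopping. This is only a sanity-check sketch reusing the hypothetical model and onx objects from the example at the top of the page, not a test shipped in this PR.

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
onnx_labels = sess.run(None, {input_name: X_valid.astype(np.float32)})[0]

# With the truncation applied in the converter, the ONNX graph only contains the
# trees selected by early stopping, so the labels should match xgboost's own output.
np.testing.assert_array_equal(model.predict(X_valid), onnx_labels)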