-
Notifications
You must be signed in to change notification settings - Fork 835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
XGBoost runtime fixes #5938
base: master
Are you sure you want to change the base?
XGBoost runtime fixes #5938
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# XGBoost Runtime |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import os | ||
import tempfile | ||
import numpy as np | ||
import pytest | ||
import xgboost as xgb | ||
from XGBoostServer import XGBoostServer | ||
|
||
@pytest.fixture | ||
def model_uri(): | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
X = np.random.rand(100, 5) | ||
y = np.random.randint(2, size=100) | ||
dtrain = xgb.DMatrix(X, label=y) | ||
params = {'objective': 'binary:logistic', 'eval_metric': 'error'} | ||
booster = xgb.train(params, dtrain, num_boost_round=10) | ||
model_path = os.path.join(temp_dir, "model.json") | ||
booster.save_model(model_path) | ||
yield temp_dir | ||
|
||
def test_init_metadata(model_uri): | ||
metadata = {"key": "value"} | ||
metadata_path = os.path.join(model_uri, "metadata.yaml") | ||
with open(metadata_path, "w") as f: | ||
yaml.dump(metadata, f) | ||
|
||
server = XGBoostServer(model_uri) | ||
|
||
loaded_metadata = server.init_metadata() | ||
|
||
# Assert that the loaded metadata matches the original metadata | ||
assert loaded_metadata == metadata | ||
|
||
def test_predict_invalid_input(model_uri): | ||
# Create an instance of XGBoostServer with the model URI | ||
server = XGBoostServer(model_uri) | ||
server.load() | ||
X_test = np.random.rand(10, 3) # Incorrect number of features | ||
with pytest.raises(ValueError): | ||
server.predict(X_test, names=[]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add new line |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
import tempfile | ||
import numpy as np | ||
from unittest import mock | ||
import xgboost as xgb | ||
from XGBoostServer import XGBoostServer | ||
|
||
def test_load_json_model(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we add a test for the old format as well (if this test doesnt exist elsewhere?) |
||
|
||
with tempfile.TemporaryDirectory() as temp_dir: | ||
X = np.random.rand(100, 5) | ||
y = np.random.randint(2, size=100) | ||
dtrain = xgb.DMatrix(X, label=y) | ||
params = {'objective': 'binary:logistic', 'eval_metric': 'error'} | ||
booster = xgb.train(params, dtrain, num_boost_round=10) | ||
model_path = os.path.join(temp_dir, "model.json") | ||
booster.save_model(model_path) | ||
|
||
server = XGBoostServer(temp_dir) | ||
server.load() | ||
|
||
assert server.ready | ||
assert isinstance(server._booster, xgb.Booster) | ||
|
||
def test_predict(): | ||
# Create a temporary directory for the model file | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
# Train a dummy XGBoost model and save it in .json format | ||
X = np.random.rand(100, 5) | ||
y = np.random.randint(2, size=100) | ||
dtrain = xgb.DMatrix(X, label=y) | ||
params = {'objective': 'binary:logistic', 'eval_metric': 'error'} | ||
booster = xgb.train(params, dtrain, num_boost_round=10) | ||
model_path = os.path.join(temp_dir, "model.json") | ||
booster.save_model(model_path) | ||
|
||
# Create an instance of XGBoostServer with the model URI | ||
server = XGBoostServer(temp_dir) | ||
|
||
server.load() | ||
|
||
# Prepare test data | ||
X_test = np.random.rand(10, 5) | ||
|
||
with mock.patch("seldon_core.Storage.download", return_value=temp_dir): | ||
predictions = server.predict(X_test, names=[]) | ||
|
||
# Assert the expected shape and type of predictions | ||
assert isinstance(predictions, np.ndarray) | ||
assert predictions.shape == (10,) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add new line |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,10 +6,13 @@ | |
import yaml | ||
import logging | ||
import xgboost as xgb | ||
import json | ||
from packaging import version | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
BOOSTER_FILE = "model.bst" | ||
BOOSTER_FILE = "model.json" | ||
BOOSTER_FILE_DEPRECATED = "model.bst" | ||
|
||
|
||
class XGBoostServer(SeldonComponent): | ||
|
@@ -22,7 +25,27 @@ def load(self): | |
model_file = os.path.join( | ||
seldon_core.Storage.download(self.model_uri), BOOSTER_FILE | ||
) | ||
self._booster = xgb.Booster(model_file=model_file) | ||
if not os.path.exists(model_file): | ||
# Fallback to deprecated .bst format | ||
model_file = os.path.join( | ||
seldon_core.Storage.download(self.model_uri), BOOSTER_FILE_DEPRECATED | ||
) | ||
if os.path.exists(model_file): | ||
logger.warning( | ||
"Using deprecated .bst format for XGBoost model. " | ||
"Please update to the .json format in the future." | ||
) | ||
else: | ||
raise FileNotFoundError(f"Model file not found: {BOOSTER_FILE} or {BOOSTER_FILE_DEPRECATED}") | ||
|
||
if version.parse(xgb.__version__) < version.parse("1.7.0"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this code path possible given that we build the server with a specific |
||
# Load model using deprecated method for older XGBoost versions | ||
self._booster = xgb.Booster(model_file=model_file) | ||
else: | ||
# Load model using the new .json format for XGBoost >= 1.7.0 | ||
self._booster = xgb.Booster() | ||
self._booster.load_model(model_file) | ||
|
||
self.ready = True | ||
|
||
def predict( | ||
|
@@ -39,12 +62,14 @@ def init_metadata(self): | |
|
||
try: | ||
with open(file_path, "r") as f: | ||
return yaml.safe_load(f.read()) | ||
metadata = yaml.safe_load(f.read()) | ||
# Validate and sanitize the loaded metadata if needed | ||
return metadata | ||
Comment on lines
+65
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this change really required, this looks not to be adding anything compared to the existing code. |
||
except FileNotFoundError: | ||
logger.debug(f"metadata file {file_path} does not exist") | ||
logger.debug(f"Metadata file {file_path} does not exist") | ||
return {} | ||
except yaml.YAMLError: | ||
logger.error( | ||
f"metadata file {file_path} present but does not contain valid yaml" | ||
f"Metadata file {file_path} present but does not contain valid YAML" | ||
) | ||
return {} | ||
return {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add new line |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
scikit-learn == 1.0.2 | ||
scikit-learn >= 1.0.2, <=1.5.0 | ||
numpy >= 1.8.2 | ||
xgboost == 1.4.2 | ||
xgboost >= 1.7.0, <=2.2.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.