Skip to content

Commit

Permalink
marshal: Add support for XGBoost
Browse files Browse the repository at this point in the history
This commit adds support for serializing XGBoost DMatrix and Booster objects.
DMatrix is an internal data structure, which is optimized for both memory
efficiency and training speed. It is not pickle-able, so we should use its
own `save_binary` method. On the other hand, Booster, which is an XGBoost
model, is pickle-able, but it's better to use its native `save_model` method.
Similarly, we can use the corresponding load methods to load a DMatrix or
Booster object.

Signed-off-by: Dimitris Poulopoulos <[email protected]>
Reviewed-by: Stefano Fioravanzo <[email protected]>
  • Loading branch information
Dimitris Poulopoulos authored and StefanoFioravanzo committed Nov 4, 2020
1 parent 6b112b1 commit 1b9a848
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions backend/kale/marshal/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,50 @@ def resource_pandas_save(obj, path, **kwargs):
fallback_save(obj, path, **kwargs)


@resource_load.register(r'.*\.dmatrix')
def resource_dmatrix_load(uri, **kwargs):
    """Load an XGBoost DMatrix from the file at ``uri``.

    Registered with ``resource_load`` under the pattern ``.*\\.dmatrix``.

    Args:
        uri: Location of the saved DMatrix; passed unchanged to
            ``xgboost.DMatrix``.
        **kwargs: Only forwarded to ``fallback_load``; the XGBoost path
            ignores them.

    Returns:
        The constructed ``xgboost.DMatrix``, or whatever ``fallback_load``
        returns when an ``ImportError`` is raised.
    """
    try:
        # xgboost is imported lazily so it stays an optional dependency.
        import xgboost

        obj_name = _get_obj_name(uri)
        log.info("Loading XGBoost DMatrix obj: %s", obj_name)
        dmatrix = xgboost.DMatrix(uri)
    except ImportError:
        # Any ImportError raised above routes to the generic loader.
        return fallback_load(uri, **kwargs)
    return dmatrix


@resource_save.register(r'xgboost.core.DMatrix')
def resource_dmatrix_save(obj, path, **kwargs):
    """Save an XGBoost DMatrix to ``path`` with a ``.dmatrix`` suffix.

    Registered with ``resource_save`` under the pattern
    ``xgboost.core.DMatrix``.

    Args:
        obj: Object to persist; must provide ``save_binary``.
        path: Destination without extension; ``'.dmatrix'`` is appended.
        **kwargs: Only forwarded to ``fallback_save``.
    """
    try:
        obj_name = _get_obj_name(path)
        log.info("Saving XGBoost DMatrix obj: %s", obj_name)
        target = path + '.dmatrix'
        obj.save_binary(target)
    except ImportError:
        # NOTE(review): nothing in this block imports a module, so this
        # handler only fires if one of the calls above raises ImportError
        # itself. Kept to preserve the existing contract.
        fallback_save(obj, path, **kwargs)


@resource_load.register(r'.*\.bst')
def resource_xgb_load(uri, **kwargs):
    """Load an XGBoost Booster (model) from the file at ``uri``.

    Registered with ``resource_load`` under the pattern ``.*\\.bst``.

    Args:
        uri: Location of the saved model; passed unchanged to
            ``Booster.load_model``.
        **kwargs: Only forwarded to ``fallback_load``; the XGBoost path
            ignores them.

    Returns:
        A ``xgboost.Booster`` with the model loaded into it, or whatever
        ``fallback_load`` returns when an ``ImportError`` is raised.
    """
    try:
        # xgboost is imported lazily so it stays an optional dependency.
        import xgboost

        obj_name = _get_obj_name(uri)
        log.info("Loading XGBoost Model obj: %s", obj_name)
        # Create an empty Booster first, then populate it from disk.
        booster = xgboost.Booster()
        booster.load_model(uri)
    except ImportError:
        # Any ImportError raised above routes to the generic loader.
        return fallback_load(uri, **kwargs)
    return booster


@resource_save.register(r'xgboost.core.Booster')
def resource_xgb_save(obj, path, **kwargs):
    """Save an XGBoost Booster (model) to ``path`` with a ``.bst`` suffix.

    Registered with ``resource_save`` under the pattern
    ``xgboost.core.Booster``.

    Args:
        obj: Object to persist; must provide ``save_model``.
        path: Destination without extension; ``'.bst'`` is appended.
        **kwargs: Only forwarded to ``fallback_save``.
    """
    try:
        obj_name = _get_obj_name(path)
        log.info("Saving XGBoost model obj: %s", obj_name)
        target = path + '.bst'
        obj.save_model(target)
    except ImportError:
        # NOTE(review): nothing in this block imports a module, so this
        # handler only fires if one of the calls above raises ImportError
        # itself. Kept to preserve the existing contract.
        fallback_save(obj, path, **kwargs)


@resource_load.register(r'.*\.pt')
def resource_torch_load(uri, **kwargs):
"""Load a torch resource."""
Expand Down

0 comments on commit 1b9a848

Please sign in to comment.