diff --git a/backend/kale/marshal/backends.py b/backend/kale/marshal/backends.py index 6742a47ce..2ecefb945 100644 --- a/backend/kale/marshal/backends.py +++ b/backend/kale/marshal/backends.py @@ -87,6 +87,50 @@ def resource_pandas_save(obj, path, **kwargs): fallback_save(obj, path, **kwargs) +@resource_load.register(r'.*\.dmatrix') +def resource_dmatrix_load(uri, **kwargs): + """Load an XGBoost DMatrix resource.""" + try: + import xgboost as xgb + log.info("Loading XGBoost DMatrix obj: %s", _get_obj_name(uri)) + return xgb.DMatrix(uri) + except ImportError: + return fallback_load(uri, **kwargs) + + +@resource_save.register(r'xgboost.core.DMatrix') +def resource_dmatrix_save(obj, path, **kwargs): + """Save an XGBoost DMatrix object.""" + try: + log.info("Saving XGBoost DMatrix obj: %s", _get_obj_name(path)) + obj.save_binary(path + '.dmatrix') + except ImportError: + fallback_save(obj, path, **kwargs) + + +@resource_load.register(r'.*\.bst') +def resource_xgb_load(uri, **kwargs): + """Load an XGBoost Model resource.""" + try: + import xgboost as xgb + log.info("Loading XGBoost Model obj: %s", _get_obj_name(uri)) + obj_xgb = xgb.Booster() + obj_xgb.load_model(uri) + return obj_xgb + except ImportError: + return fallback_load(uri, **kwargs) + + +@resource_save.register(r'xgboost.core.Booster') +def resource_xgb_save(obj, path, **kwargs): + """Save an XGBoost Model object.""" + try: + log.info("Saving XGBoost model obj: %s", _get_obj_name(path)) + obj.save_model(path + '.bst') + except ImportError: + fallback_save(obj, path, **kwargs) + + @resource_load.register(r'.*\.pt') def resource_torch_load(uri, **kwargs): """Load a torch resource."""