-
Notifications
You must be signed in to change notification settings - Fork 225
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for serializing detectors with scikit-learn backends and/or models (#642)
- Loading branch information
1 parent
b915d63
commit 1898ad2
Showing 14 changed files with 660 additions and 403 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from alibi_detect.saving._sklearn.saving import save_model_config as save_model_config_sk | ||
from alibi_detect.saving._sklearn.loading import load_model as load_model_sk | ||
|
||
__all__ = [ | ||
"save_model_config_sk", | ||
"load_model_sk" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import joblib | ||
from sklearn.base import BaseEstimator | ||
|
||
|
||
def load_model(filepath: Union[str, os.PathLike]) -> BaseEstimator:
    """
    Load a scikit-learn (or xgboost) model from a saved model directory.

    Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`. This includes xgboost models
    following the scikit-learn API
    (see https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn).

    Parameters
    ----------
    filepath
        Saved model directory. The model is read from the `model.joblib` file inside it.

    Returns
    -------
    Loaded model.
    """
    # NOTE: joblib deserialises via pickle, so only load model files from trusted sources.
    model_file = Path(filepath) / 'model.joblib'
    return joblib.load(model_file)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import joblib | ||
from sklearn.base import BaseEstimator | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def save_model_config(model: BaseEstimator,
                      base_path: Path,
                      local_path: Path = Path('.')) -> dict:
    """
    Serialise a scikit-learn (or xgboost) model to disk and describe it with a config dictionary.

    Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`. This includes xgboost models
    following the scikit-learn API
    (see https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn).

    Parameters
    ----------
    model
        The model to save.
    base_path
        Base filepath to save to (the location of the `config.toml` file).
    local_path
        A local (relative) filepath to append to base_path.

    Returns
    -------
    The model config dict, holding the `'flavour'` (always `'sklearn'`) and the `'src'` path of the saved
    model directory, expressed relative to `base_path`.
    """
    # The model itself is written to <base_path>/<local_path>/model/.
    save_model(model, filepath=base_path.joinpath(local_path), save_dir='model')

    # Only the relative location is recorded in the config.
    return {
        'flavour': 'sklearn',
        'src': local_path.joinpath('model'),
    }
|
||
|
||
def save_model(model: BaseEstimator,
               filepath: Union[str, os.PathLike],
               save_dir: Union[str, os.PathLike] = 'model') -> None:
    """
    Save scikit-learn (and xgboost) models with joblib.

    Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`. This includes xgboost models
    following the scikit-learn API
    (see https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn).

    Parameters
    ----------
    model
        The scikit-learn (or scikit-learn API compatible xgboost) model to save.
    filepath
        Save directory.
    save_dir
        Name of folder to save to within the filepath directory. The model is written to a `model.joblib`
        file inside this folder.
    """
    # Create the folder to save the model in (including any missing parents).
    model_dir = Path(filepath).joinpath(save_dir)
    if not model_dir.is_dir():
        logger.warning('Directory %s does not exist and is now created.', model_dir)
        model_dir.mkdir(parents=True, exist_ok=True)

    # Serialise the model; an existing `model.joblib` in the folder is overwritten.
    model_file = model_dir.joinpath('model.joblib')
    joblib.dump(model, model_file)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from pytest_cases import param_fixture, parametrize, parametrize_with_cases | ||
|
||
from alibi_detect.saving.tests.datasets import ContinuousData | ||
from alibi_detect.saving.tests.models import classifier_model, xgb_classifier_model | ||
|
||
from alibi_detect.saving.loading import _load_model_config | ||
from alibi_detect.saving.saving import _path2str, _save_model_config | ||
from alibi_detect.saving.schemas import ModelConfig | ||
|
||
backend = param_fixture("backend", ['sklearn']) | ||
|
||
|
||
@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_')
@parametrize('model', [classifier_model, xgb_classifier_model])
def test_save_model_sk(data, model, tmp_path):
    """
    Round-trip unit test for _save_model_config and _load_model_config with scikit-learn and xgboost models.

    The model is saved under `tmp_path`, the resulting config is validated through `ModelConfig`, and the
    model is then reloaded and checked to be of the same type as the original.
    """
    model_dir = tmp_path / 'model'

    # Save the model and validate the returned config against the schema
    raw_cfg, _ = _save_model_config(model, base_path=tmp_path)
    validated_cfg = ModelConfig(**_path2str(raw_cfg)).dict()
    assert model_dir.is_dir()
    assert (model_dir / 'model.joblib').is_file()

    # The saved `src` is relative, so point it at the absolute location before loading
    validated_cfg['src'] = model_dir

    # Reload and compare types
    reloaded = _load_model_config(validated_cfg)
    assert isinstance(reloaded, type(model))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from pytest_cases import param_fixture, parametrize, parametrize_with_cases | ||
|
||
from alibi_detect.saving.tests.datasets import ContinuousData | ||
from alibi_detect.saving.tests.models import encoder_model | ||
|
||
from alibi_detect.cd.tensorflow import HiddenOutput as HiddenOutput_tf | ||
from alibi_detect.saving.loading import _load_model_config, _load_optimizer_config | ||
from alibi_detect.saving.saving import _path2str, _save_model_config | ||
from alibi_detect.saving.schemas import ModelConfig | ||
|
||
backend = param_fixture("backend", ['tensorflow']) | ||
|
||
|
||
def test_load_optimizer_tf(backend):
    """
    Test the tensorflow _load_optimizer_config.

    An Adam optimizer config is built by hand and loaded; the resulting optimizer's class name and
    hyperparameters must match those specified in the config.
    """
    class_name = 'Adam'
    hyperparams = {
        'learning_rate': 0.01,
        'epsilon': 1e-7,
        'amsgrad': False,
    }

    # Assemble the optimizer config and load it
    cfg_opt = {
        'class_name': class_name,
        'config': {'name': class_name, **hyperparams},
    }
    optimizer = _load_optimizer_config(cfg_opt, backend=backend)

    # The loaded optimizer should mirror the config
    assert type(optimizer).__name__ == class_name
    for attr_name, expected in hyperparams.items():
        assert getattr(optimizer, attr_name) == expected
|
||
|
||
@parametrize_with_cases("data", cases=ContinuousData.data_synthetic_nd, prefix='data_')
@parametrize('model', [encoder_model])
@parametrize('layer', [None, -1])
def test_save_model_tf(data, model, layer, tmp_path):
    """
    Round-trip unit test for _save_model_config and _load_model_config with a tensorflow model.

    When `layer` is set in the config, the loaded object is expected to be a `HiddenOutput` wrapper rather
    than an instance of the original model's type.
    """
    model_dir = tmp_path / 'model'

    # Save the model, using the feature dimension of the data as the input shape
    n_features = data[0].shape[1]
    raw_cfg, _ = _save_model_config(model, base_path=tmp_path, input_shape=(n_features,))
    cfg = ModelConfig(**_path2str(raw_cfg)).dict()
    assert model_dir.is_dir()
    assert (model_dir / 'model.h5').is_file()

    # The saved `src` is relative, so point it at the absolute location before loading
    cfg['src'] = model_dir
    if layer is not None:
        cfg['layer'] = layer

    # Reload and check the type of the returned object
    loaded = _load_model_config(cfg)
    expected_type = type(model) if layer is None else HiddenOutput_tf
    assert isinstance(loaded, expected_type)
Oops, something went wrong.