From c3fd2256c92bc2ff96b8d28dde946d90cde8cefa Mon Sep 17 00:00:00 2001 From: Ashley Scillitoe Date: Tue, 11 Oct 2022 14:46:46 +0100 Subject: [PATCH] Remove load_dir from sklearn load_model --- alibi_detect/saving/_sklearn/loading.py | 5 +---- alibi_detect/saving/loading.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/alibi_detect/saving/_sklearn/loading.py b/alibi_detect/saving/_sklearn/loading.py index b52ff5283..b6ed23912 100644 --- a/alibi_detect/saving/_sklearn/loading.py +++ b/alibi_detect/saving/_sklearn/loading.py @@ -7,7 +7,6 @@ def load_model(filepath: Union[str, os.PathLike], - load_dir: str = 'model', ) -> BaseEstimator: """ Load scikit-learn (or xgboost) model. Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`. @@ -18,12 +17,10 @@ def load_model(filepath: Union[str, os.PathLike], ---------- filepath Saved model directory. - load_dir - Name of saved model folder within the filepath directory. Returns ------- Loaded model. """ - model_dir = Path(filepath).joinpath(load_dir) + model_dir = Path(filepath) return joblib.load(model_dir.joinpath('model.joblib')) diff --git a/alibi_detect/saving/loading.py b/alibi_detect/saving/loading.py index a6ff8c5c9..385dfcb31 100644 --- a/alibi_detect/saving/loading.py +++ b/alibi_detect/saving/loading.py @@ -267,7 +267,7 @@ def _load_model_config(cfg: dict) -> Callable: if flavour == 'tensorflow': model = load_model_tf(src, load_dir='.', custom_objects=custom_obj, layer=layer) elif flavour == 'sklearn': - model = load_model_sk(src, load_dir='.') + model = load_model_sk(src) else: raise NotImplementedError('Loading of PyTorch models not currently supported')