Skip to content

Commit

Permalink
Remove load_dir from sklearn load_model
Browse files Browse the repository at this point in the history
  • Loading branch information
ascillitoe committed Oct 11, 2022
1 parent 4d530e5 commit c3fd225
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
5 changes: 1 addition & 4 deletions alibi_detect/saving/_sklearn/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


def load_model(filepath: Union[str, os.PathLike],
load_dir: str = 'model',
) -> BaseEstimator:
"""
Load scikit-learn (or xgboost) model. Models are assumed to be a subclass of :class:`~sklearn.base.BaseEstimator`.
Expand All @@ -18,12 +17,10 @@ def load_model(filepath: Union[str, os.PathLike],
----------
filepath
Saved model directory.
load_dir
Name of saved model folder within the filepath directory.
Returns
-------
Loaded model.
"""
model_dir = Path(filepath).joinpath(load_dir)
model_dir = Path(filepath)
return joblib.load(model_dir.joinpath('model.joblib'))
2 changes: 1 addition & 1 deletion alibi_detect/saving/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _load_model_config(cfg: dict) -> Callable:
if flavour == 'tensorflow':
model = load_model_tf(src, load_dir='.', custom_objects=custom_obj, layer=layer)
elif flavour == 'sklearn':
model = load_model_sk(src, load_dir='.')
model = load_model_sk(src)
else:
raise NotImplementedError('Loading of PyTorch models not currently supported')

Expand Down

0 comments on commit c3fd225

Please sign in to comment.