From 677e40f29136365e96c8aa22f50e53d64562d92a Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Wed, 11 Oct 2023 17:21:26 +0900 Subject: [PATCH] feat(sklearn): add `recursive_strict` parameter to `patch()` (#86) --- src/boost_loss/sklearn.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/boost_loss/sklearn.py b/src/boost_loss/sklearn.py index 8e73e8c..d2bdcf0 100644 --- a/src/boost_loss/sklearn.py +++ b/src/boost_loss/sklearn.py @@ -48,6 +48,7 @@ def apply_custom_loss( copy: bool = True, target_transformer: BaseEstimator | Any | None = StandardScaler(), recursive: bool = True, + recursive_strict: bool = False, ) -> TEstimator | TransformedTargetRegressor: """Apply custom loss to the estimator. @@ -234,7 +235,13 @@ def predict_var( TAny = TypeVar("TAny") -def patch(estimator: TAny, *, copy: bool = True, recursive: bool = True) -> TAny: +def patch( + estimator: TAny, + *, + copy: bool = True, + recursive: bool = True, + recursive_strict: bool = False, +) -> TAny: """Patch estimator if it is supported. (`patch_ngboost` and `patch_catboost`.) The patch will not apply if the estimator is cloned using `sklearn.base.clone()` and requires re-patching. @@ -247,12 +254,17 @@ def patch(estimator: TAny, *, copy: bool = True, recursive: bool = True) -> TAny Whether to copy the estimator before patching, by default True recursive : bool, optional Whether to recursively patch the estimator, by default True + recursive_strict : bool, optional + Whether to recursively patch the estimator's attributes, + lists, tuples, sets, and frozensets as well, by default False Returns ------- TAny The patched estimator. """ + if recursive_strict and not recursive: + raise ValueError("recursive_strict requires recursive=True") if copy: estimator = clone(estimator) if importlib.util.find_spec("ngboost") is not None: @@ -263,6 +275,14 @@ def patch(estimator: TAny, *, copy: bool = True, recursive: bool = True) -> TAny if recursive and hasattr(estimator, "get_params"): for _, value in estimator.get_params(deep=True).items(): - patch(value, copy=False, recursive=False) + patch(value, copy=False, recursive=False, recursive_strict=recursive_strict) + if recursive_strict: + if hasattr(estimator, "__dict__"): + for _, value in estimator.__dict__.items(): + patch(value, copy=False, recursive=True, recursive_strict=True) + elif isinstance(estimator, (list, tuple, set, frozenset)): + # https://github.com/scikit-learn/scikit-learn/blob/364c77e047ca08a95862becf40a04fe9d4cd2c98/sklearn/base.py#L66 + for value in estimator: + patch(value, copy=False, recursive=True, recursive_strict=True) return estimator