Skip to content

Commit

Permalink
feat(sklearn): add recursive_strict parameter to patch() (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
34j authored Oct 11, 2023
1 parent 41f231c commit 677e40f
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/boost_loss/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def apply_custom_loss(
copy: bool = True,
target_transformer: BaseEstimator | Any | None = StandardScaler(),
recursive: bool = True,
recursive_strict: bool = False,
) -> TEstimator | TransformedTargetRegressor:
"""Apply custom loss to the estimator.
Expand Down Expand Up @@ -234,7 +235,13 @@ def predict_var(
TAny = TypeVar("TAny")


def patch(estimator: TAny, *, copy: bool = True, recursive: bool = True) -> TAny:
def patch(
estimator: TAny,
*,
copy: bool = True,
recursive: bool = True,
recursive_strict: bool = False,
) -> TAny:
"""Patch estimator if it is supported. (`patch_ngboost` and `patch_catboost`.)
The patch will not apply if the estimator is cloned using `sklearn.base.clone()`
and requires re-patching.
Expand All @@ -247,12 +254,17 @@ def patch(estimator: TAny, *, copy: bool = True, recursive: bool = True) -> TAny
Whether to copy the estimator before patching, by default True
recursive : bool, optional
Whether to recursively patch the estimator, by default True
recursive_strict : bool, optional
Whether to recursively patch the estimator's attributes,
lists, tuples, sets, and frozensets as well, by default False
Returns
-------
TAny
The patched estimator.
"""
if recursive_strict and not recursive:
raise ValueError("recursive_strict requires recursive=True")
if copy:
estimator = clone(estimator)
if importlib.util.find_spec("ngboost") is not None:
Expand All @@ -263,6 +275,14 @@ def patch(estimator: TAny, *, copy: bool = True, recursive: bool = True) -> TAny

if recursive and hasattr(estimator, "get_params"):
for _, value in estimator.get_params(deep=True).items():
patch(value, copy=False, recursive=False)
patch(value, copy=False, recursive=False, recursive_strict=recursive_strict)
if recursive_strict:
if hasattr(estimator, "__dict__"):
for _, value in estimator.__dict__.items():
patch(value, copy=False, recursive=True, recursive_strict=True)
elif isinstance(estimator, (list, tuple, set, frozenset)):
# https://github.com/scikit-learn/scikit-learn/blob/364c77e047ca08a95862becf40a04fe9d4cd2c98/sklearn/base.py#L66
for value in estimator:
patch(value, copy=False, recursive=True, recursive_strict=True)

return estimator

0 comments on commit 677e40f

Please sign in to comment.