Skip to content

Commit

Permalink
Remove deprecation warning for **kwargs and support them in partial_f…
Browse files Browse the repository at this point in the history
…it (#238)
  • Loading branch information
adriangb authored Aug 2, 2021
1 parent cd845db commit 86093c4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 29 deletions.
40 changes: 19 additions & 21 deletions scikeras/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@
from scikeras.utils.transformers import ClassifierLabelEncoder, RegressorTargetEncoder


_kwarg_warn = """Passing estimator parameters as keyword arguments (aka as `**kwargs`) to `{0}` is not supported by the Scikit-Learn API, and will be removed in a future version of SciKeras.
To resolve this issue, either set these parameters in the constructor (e.g., `est = BaseWrapper(..., foo=bar)`) or via `set_params` (e.g., `est.set_params(foo=bar)`). The following parameters were passed to `{0}`:
{1}
More detail is available at https://www.adriangb.com/scikeras/migration.html#variable-keyword-arguments-in-fit-and-predict
"""


class BaseWrapper(BaseEstimator):
"""Implementation of the scikit-learn classifier API for Keras.
Expand Down Expand Up @@ -734,10 +724,6 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "BaseWrapper":
BaseWrapper
A reference to the instance that can be chain called (``est.fit(X,y).transform(X)``).
"""
if kwargs:
kwarg_list = "\n * ".join([f"`{k}={v}`" for k, v in kwargs.items()])
warnings.warn(_kwarg_warn.format("fit", kwarg_list))

# epochs via kwargs > fit__epochs > epochs
kwargs["epochs"] = kwargs.get(
"epochs", getattr(self, "fit__epochs", self.epochs)
Expand Down Expand Up @@ -919,7 +905,7 @@ def _fit(
**kwargs,
)

def partial_fit(self, X, y, sample_weight=None) -> "BaseWrapper":
def partial_fit(self, X, y, sample_weight=None, **kwargs) -> "BaseWrapper":
"""Fit the estimator for a single epoch, preserving the current
training history and model parameters.
Expand All @@ -934,20 +920,32 @@ def partial_fit(self, X, y, sample_weight=None) -> "BaseWrapper":
sample_weight : array-like of shape (n_samples,), default=None
Array of weights that are assigned to individual samples.
If not provided, then each sample is given unit weight.
**kwargs : Dict[str, Any]
Extra arguments to route to ``Model.fit``.
Returns
-------
BaseWrapper
A reference to the instance that can be chain called
(ex: instance.partial_fit(X, y).transform(X) )
"""
if "epochs" in kwargs:
raise TypeError(
"Invalid argument `epochs` to `partial_fit`: `partial_fit` always trains for 1 epoch"
)
if "initial_epoch" in kwargs:
raise TypeError(
                "Invalid argument `initial_epoch` to `partial_fit`: `partial_fit` always trains from the current epoch"
)

self._fit(
X,
y,
sample_weight=sample_weight,
warm_start=True,
epochs=1,
initial_epoch=self.current_epoch,
**kwargs,
)
return self

Expand All @@ -957,10 +955,6 @@ def _predict_raw(self, X, **kwargs):
For classification, this corresponds to predict_proba.
For regression, this corresponds to predict.
"""
if kwargs:
kwarg_list = "\n * ".join([f"`{k}={v}`" for k, v in kwargs.items()])
warnings.warn(_kwarg_warn.format("predict", kwarg_list), stacklevel=2)

# check if fitted
if not self.initialized_:
raise NotFittedError(
Expand Down Expand Up @@ -1464,7 +1458,9 @@ def fit(self, X, y, sample_weight=None, **kwargs) -> "KerasClassifier":
super().fit(X=X, y=y, sample_weight=sample_weight, **kwargs)
return self

def partial_fit(self, X, y, classes=None, sample_weight=None) -> "KerasClassifier":
def partial_fit(
self, X, y, classes=None, sample_weight=None, **kwargs
) -> "KerasClassifier":
"""Fit classifier for a single epoch, preserving the current epoch
and all model parameters and state.
Expand All @@ -1485,6 +1481,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None) -> "KerasClassifie
sample_weight : array-like of shape (n_samples,), default=None
Array of weights that are assigned to individual samples.
If not provided, then each sample is given unit weight.
**kwargs : Dict[str, Any]
Extra arguments to route to ``Model.fit``.
Returns
-------
Expand All @@ -1498,7 +1496,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None) -> "KerasClassifie
if self.class_weight is not None:
sample_weight = 1 if sample_weight is None else sample_weight
sample_weight *= compute_sample_weight(class_weight=self.class_weight, y=y)
super().partial_fit(X, y, sample_weight=sample_weight)
super().partial_fit(X, y, sample_weight=sample_weight, **kwargs)
return self

def predict_proba(self, X, **kwargs):
Expand Down
18 changes: 10 additions & 8 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ def test_kwargs(wrapper, builder):
# check fit
match = "estimator parameters as keyword arguments"
with mock.patch.object(est.model_, "fit", side_effect=est.model_.fit) as mock_fit:
with pytest.warns(UserWarning, match=match.format("fit")):
est.fit(
X, y, batch_size=kwarg_batch_size, epochs=kwarg_epochs, **extra_kwargs
)
est.fit(X, y, batch_size=kwarg_batch_size, epochs=kwarg_epochs, **extra_kwargs)
call_args = mock_fit.call_args_list
assert len(call_args) == 1
call_kwargs = call_args[0][1]
Expand All @@ -297,16 +294,14 @@ def test_kwargs(wrapper, builder):
with mock.patch.object(
est.model_, "predict", side_effect=est.model_.predict
) as mock_predict:
with pytest.warns(UserWarning, match=match.format("predict")):
est.predict(X, batch_size=kwarg_batch_size, **extra_kwargs)
est.predict(X, batch_size=kwarg_batch_size, **extra_kwargs)
call_args = mock_predict.call_args_list
assert len(call_args) == 1
call_kwargs = call_args[0][1]
assert "batch_size" in call_kwargs
assert call_kwargs["batch_size"] == kwarg_batch_size
if isinstance(est, KerasClassifier):
with pytest.warns(UserWarning, match=match.format("predict")):
est.predict_proba(X, batch_size=kwarg_batch_size, **extra_kwargs)
est.predict_proba(X, batch_size=kwarg_batch_size, **extra_kwargs)
call_args = mock_predict.call_args_list
assert len(call_args) == 2
call_kwargs = call_args[1][1]
Expand All @@ -323,6 +318,13 @@ def test_kwargs(wrapper, builder):
)


@pytest.mark.parametrize("kwargs", ({"epochs": 1}, {"initial_epoch": 1}))
def test_partial_fit_epoch_kwargs(kwargs):
    """`partial_fit` always runs exactly one epoch from the current epoch,
    so passing `epochs` or `initial_epoch` must raise a TypeError."""
    clf = KerasClassifier(dynamic_classifier)
    with pytest.raises(TypeError, match="Invalid argument"):
        clf.partial_fit([[1]], [1], **kwargs)


@pytest.mark.parametrize("length", (10, 100))
@pytest.mark.parametrize("prefix", ("", "fit__"))
@pytest.mark.parametrize("base", ("validation_batch_size", "batch_size"))
Expand Down

0 comments on commit 86093c4

Please sign in to comment.