diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index c570517d..dace2d3f 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -1,5 +1,6 @@ """Wrapper for using the Scikit-Learn API with Keras models. """ +import functools import inspect import warnings @@ -313,7 +314,7 @@ def _check_model_param(self): def final_build_fn(): return model - elif inspect.isfunction(model): + elif inspect.isfunction(model) or isinstance(model, functools.partial): if hasattr(self, "_keras_build_fn"): raise ValueError( "This class cannot implement ``_keras_build_fn`` if" diff --git a/tests/test_api.py b/tests/test_api.py index 4c251bc7..3c51d6e3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ """Tests for Scikit-learn API wrapper.""" import pickle +from functools import partial from typing import Any, Dict import numpy as np @@ -872,3 +873,21 @@ def test_prebuilt_model(self, wrapper): np.testing.assert_allclose(y_pred_keras, y_pred_scikeras) # Check that we are still using the same model object assert est.model_ is m2 + + +def build_model_for_partial_wrapping(input_size: int = 100) -> keras.Model: + inp = keras.layers.Input((input_size,)) + out = keras.layers.Dense(1)(inp) + return keras.Model(inp, out) + + +def test_partial_model_build_fn() -> None: + X = np.random.random((100, 1)) + y = np.random.uniform(low=0, high=3, size=(100,)) + + build_fn = partial(build_model_for_partial_wrapping, input_size=1) + + reg = KerasRegressor(build_fn, loss="mse") + reg = reg.fit(X, y) + reg = pickle.loads(pickle.dumps(reg)) + reg.predict(X)