Skip to content

Commit

Permalink
Allow functools.partial in wrappers.py (#279)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Garcia Badaracco <[email protected]>
  • Loading branch information
jpgard and adriangb authored Jul 21, 2022
1 parent 2c8e9e0 commit 71cabed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
3 changes: 2 additions & 1 deletion scikeras/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wrapper for using the Scikit-Learn API with Keras models.
"""
import functools
import inspect
import warnings

Expand Down Expand Up @@ -313,7 +314,7 @@ def _check_model_param(self):
def final_build_fn():
return model

elif inspect.isfunction(model):
elif inspect.isfunction(model) or isinstance(model, functools.partial):
if hasattr(self, "_keras_build_fn"):
raise ValueError(
"This class cannot implement ``_keras_build_fn`` if"
Expand Down
19 changes: 19 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for Scikit-learn API wrapper."""
import pickle

from functools import partial
from typing import Any, Dict

import numpy as np
Expand Down Expand Up @@ -872,3 +873,21 @@ def test_prebuilt_model(self, wrapper):
np.testing.assert_allclose(y_pred_keras, y_pred_scikeras)
# Check that we are still using the same model object
assert est.model_ is m2


def build_model_for_partial_wrapping(input_size: int = 100) -> keras.Model:
inp = keras.layers.Input((input_size,))
out = keras.layers.Dense(1)(inp)
return keras.Model(inp, out)


def test_partial_model_build_fn() -> None:
X = np.random.random((100, 1))
y = np.random.uniform(low=0, high=3, size=(100,))

build_fn = partial(build_model_for_partial_wrapping, input_size=1)

reg = KerasRegressor(build_fn, loss="mse")
reg = reg.fit(X, y)
reg = pickle.loads(pickle.dumps(reg))
reg.predict(X)

0 comments on commit 71cabed

Please sign in to comment.