ENH: set build_fn default values as model parameters #22
Conversation
I'm not sure what you mean by that. The Scikit-learn API dictates that model parameters be set at initialization. If passed to …
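For context, the convention being referenced can be sketched as follows. This is a hypothetical, simplified re-implementation of what `sklearn.base.BaseEstimator` enforces (the name `ToyEstimator` and the `get_params` body are illustrative, not sklearn's actual code): every model parameter is an explicit `__init__` argument, stored unmodified under the same name, so that tools like `GridSearchCV` can discover and clone over it.

```python
import inspect


class ToyEstimator:
    """Hypothetical minimal estimator following the Scikit-learn convention."""

    def __init__(self, solver="sgd", lr=0.01):
        # Parameters are stored verbatim under their own names.
        self.solver = solver
        self.lr = lr

    def get_params(self):
        # Read the parameter names off __init__'s signature
        # (self is excluded for a bound method).
        names = inspect.signature(self.__init__).parameters
        return {name: getattr(self, name) for name in names}


est = ToyEstimator(solver="adam")
print(est.get_params())  # {'solver': 'adam', 'lr': 0.01}
```

Because the parameters are discoverable from the signature alone, a search tool never needs to know anything Keras-specific about the model.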
I think the spirit of that requirement is that it should be clear what the possible parameters to the model are based on the signature of its `__init__`. These wrappers offer the ability to dynamically add kwargs to the model; sub-classing makes the parameters explicit:

```python
class MyClassifier(KerasClassifier):
    def __init__(self, build_fn=None, solver="sgd"):
        self.solver = solver
        super().__init__(build_fn=build_fn)
```

I think that if I were writing these from scratch with no backward compatibility, I would only allow use of the wrappers via sub-classing. Ideally, these wrappers are …

Full example:

```python
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import GridSearchCV
import tensorflow as tf
from tensorflow.keras.datasets import mnist as keras_mnist
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical


def mnist():
    # Load a small subset of MNIST and flatten/scale the images.
    (X_train, y_train), _ = keras_mnist.load_data()
    X_train = X_train[:100]
    y_train = y_train[:100]
    X_train = X_train.reshape(X_train.shape[0], 784)
    X_train = X_train.astype("float32")
    X_train /= 255
    # One-hot encode the labels to match categorical_crossentropy.
    y_train = to_categorical(y_train, 10)
    return X_train, y_train


def build_fn(solver="sgd"):
    layers = [
        Dense(100, input_shape=(784,), activation="relu"),
        Dense(10, activation="softmax"),
    ]
    model = Sequential(layers)
    model.compile(loss="categorical_crossentropy", optimizer=solver, metrics=["accuracy"])
    return model


class MyClassifier(KerasClassifier):
    def __init__(self, build_fn=None, solver="sgd"):
        self.solver = solver
        super().__init__(build_fn=build_fn)


params = {"solver": ["sgd", "adam"]}
model = MyClassifier(build_fn=build_fn)
assert "solver" in model.get_params()
search = GridSearchCV(model, params)
search.fit(*mnist())
```
I've re-implemented this PR. All tests pass on my own machine.
Codecov Report
```diff
@@           Coverage Diff           @@
##           master      #22   +/-   ##
=======================================
  Coverage   99.52%   99.52%
=======================================
  Files           3        3
  Lines         418      425    +7
=======================================
+ Hits          416      423    +7
  Misses          2        2
```
Continue to review full report at Codecov.
Overall this looks good! Just minor comments. Thanks for this work.
Co-authored-by: Adrian Garcia Badaracco <[email protected]>
What does this PR implement?
It sets the default values of `build_fn`'s arguments as model parameters. This means model selection searches mirror the Scikit-learn interface. This PR makes the grid search over `solver` shown in the conversation above possible.
This almost exactly mirrors the corresponding Scikit-learn example.
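The Scikit-learn example being referenced can be sketched as follows (reconstructed from memory of the sklearn grid-search documentation, not quoted verbatim):

```python
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {"kernel": ("linear", "rbf"), "C": [1, 10]}
svc = svm.SVC()
# Because SVC declares kernel and C in its __init__, GridSearchCV can
# enumerate and clone over them -- the same behavior this PR gives to
# build_fn's default arguments.
clf = GridSearchCV(svc, parameters)
clf.fit(iris.data, iris.target)
print(clf.best_params_)
```

The point of the comparison: after this PR, swapping `svm.SVC()` for a `KerasClassifier` requires no extra plumbing to expose the tunable parameters.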
Reference issues/PRs
This resolves #18.