
ENH: set build_fn default values as model parameters #22

Merged
7 commits merged into adriangb:master on Aug 13, 2020

Conversation

stsievert
Collaborator

@stsievert stsievert commented Jul 16, 2020

What does this PR implement?
It sets the default values of build_fn's arguments as model parameters, so model selection searches mirror the Scikit-learn interface.

This PR makes this search possible:

def build_fn(solver="sgd"):
    ...
    model.compile(optimizer=solver, ...)
    return model

params = {"solver": ["sgd", "adam"]}
model = wrappers.KerasClassifier(build_fn, epochs=2)
assert "solver" in model.get_params()
search = GridSearchCV(model, params)
search.fit(*mnist())
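The mechanism behind this can be sketched in plain Python: the default values of build_fn's keyword arguments are collected via introspection so they can be surfaced as tunable parameters. This is a simplified illustration, not SciKeras's actual implementation; `build_fn_stub` and `build_fn_params` are hypothetical names.

```python
import inspect

def build_fn_stub(solver="sgd", lr=0.01):
    # Hypothetical stand-in for a Keras model-building function.
    return {"solver": solver, "lr": lr}

def build_fn_params(fn):
    # Collect the default values of fn's keyword arguments,
    # mirroring how defaults can be exposed via get_params().
    sig = inspect.signature(fn)
    return {
        name: p.default
        for name, p in sig.parameters.items()
        if p.default is not inspect.Parameter.empty
    }

params = build_fn_params(build_fn_stub)
assert params == {"solver": "sgd", "lr": 0.01}
```

With the defaults known up front, a grid search can override any of them by name, exactly as it would for a native Scikit-learn estimator.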

This almost exactly mirrors this Scikit-learn example:

from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.neural_network import MLPClassifier

X, y = make_classification()
params = {"solver": ["sgd", "adam"]}
model = MLPClassifier(max_iter=2)
assert "solver" in model.get_params()
search = GridSearchCV(model, params)
search.fit(X, y)

Reference issues/PRs
This resolves #18.

@stsievert
Collaborator Author

The Scikit-learn API in theory requires all of the tunable parameters to be passed to the model's constructor.

#18 (comment).

I'm not sure what you mean by that. The Scikit-learn API dictates that model parameters be set at initialization; if a parameter is passed to __init__, the stored attribute should match the value passed.
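That convention can be shown with a toy estimator (a minimal sketch of the Scikit-learn pattern, not the real BaseEstimator): every tunable parameter is an explicit __init__ argument, stored verbatim under the same name, and get_params() reads the names back from the signature. `MiniEstimator` is a hypothetical name.

```python
import inspect

class MiniEstimator:
    # Toy estimator illustrating the Scikit-learn convention:
    # each __init__ argument is stored unchanged as an
    # attribute with the same name.
    def __init__(self, solver="sgd", epochs=2):
        self.solver = solver
        self.epochs = epochs

    def get_params(self, deep=True):
        # Derive parameter names from __init__'s signature and
        # return the matching attribute values.
        names = [
            n for n in inspect.signature(self.__init__).parameters
            if n != "self"
        ]
        return {n: getattr(self, n) for n in names}

est = MiniEstimator(solver="adam")
assert est.get_params() == {"solver": "adam", "epochs": 2}
```

Because the attribute always equals the constructor argument, clone() and GridSearchCV can round-trip parameters safely.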

@adriangb
Owner

adriangb commented Jul 16, 2020

I think the spirit of that requirement is that it should be clear what the possible parameters to the model are based on the signature of its __init__.

These wrappers offer the ability to dynamically add kwargs to the KerasClassifier constructor directly (which is what your example is doing); however, that is only really there for backwards compatibility. The same is also achievable by sub-classing the wrappers, which is more verbose but clearer about what the parameters are:

class MyClassifier(KerasClassifier):
    def __init__(self, build_fn=None, solver="sgd"):
        self.solver = solver
        super().__init__(build_fn=build_fn)

I think that if I was writing these from scratch with no backward compatibility, I would only allow use of the wrappers via sub-classing. Ideally, these wrappers are sklearn.base but with added functionality to handle the Keras specific parts.
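One way the sub-classing route can report both the parent's and the child's parameters is to walk the class hierarchy when collecting names. The sketch below assumes a simplified `BaseWrapper` (hypothetical; not SciKeras's actual class) whose get_params() merges the __init__ signatures found along the MRO.

```python
import inspect

class BaseWrapper:
    # Simplified stand-in for a wrapper base class.
    def __init__(self, build_fn=None):
        self.build_fn = build_fn

    def get_params(self, deep=True):
        # Walk the MRO so a subclass's get_params() also
        # reports parameters declared by parent __init__s.
        names = set()
        for klass in type(self).__mro__:
            init = getattr(klass, "__init__", None)
            if init is None or init is object.__init__:
                continue
            names.update(
                n for n in inspect.signature(init).parameters
                if n not in ("self", "args", "kwargs")
            )
        return {n: getattr(self, n) for n in sorted(names)}

class MyClassifier(BaseWrapper):
    def __init__(self, build_fn=None, solver="sgd"):
        self.solver = solver
        super().__init__(build_fn=build_fn)

clf = MyClassifier()
assert clf.get_params() == {"build_fn": None, "solver": "sgd"}
```

This keeps the full parameter set discoverable from __init__ signatures alone, which is the spirit of the Scikit-learn requirement discussed above.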

Full example
from scikeras.wrappers import KerasClassifier

from sklearn.model_selection import GridSearchCV
import tensorflow as tf
from tensorflow.keras.datasets import mnist as keras_mnist
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical


def mnist():
    (X_train, y_train), _ = keras_mnist.load_data()
    X_train = X_train[:100]
    y_train = y_train[:100]
    X_train = X_train.reshape(X_train.shape[0], 784)
    X_train = X_train.astype("float32")
    X_train /= 255
    Y_train = to_categorical(y_train, 10)
    return X_train, Y_train

def build_fn(solver="sgd"):
    layers = [
        Dense(100, input_shape=(784,), activation="relu"),
        Dense(10, input_shape=(100,), activation="softmax"),
    ]
    model = Sequential(layers)
    model.compile(loss="categorical_crossentropy", optimizer=solver, metrics=["accuracy"])
    return model


class MyClassifier(KerasClassifier):

    def __init__(self, build_fn=None, solver="sgd"):
        self.solver = solver
        super().__init__(build_fn=build_fn)

params = {"solver": ["sgd", "adam"]}
model = MyClassifier(build_fn=build_fn)
assert "solver" in model.get_params()
search = GridSearchCV(model, params)
search.fit(*mnist())

@stsievert
Collaborator Author

I've re-implemented this PR. All tests pass on my own machine.

@codecov-commenter

codecov-commenter commented Aug 13, 2020

Codecov Report

Merging #22 into master will increase coverage by 0.00%.
The diff coverage is 100.00%.


@@           Coverage Diff           @@
##           master      #22   +/-   ##
=======================================
  Coverage   99.52%   99.52%           
=======================================
  Files           3        3           
  Lines         418      425    +7     
=======================================
+ Hits          416      423    +7     
  Misses          2        2           
Impacted Files Coverage Δ
scikeras/_utils.py 98.66% <100.00%> (+0.07%) ⬆️
scikeras/wrappers.py 99.71% <100.00%> (+<0.01%) ⬆️

Last update 8d88825...0c680a6.

Owner

@adriangb adriangb left a comment


Overall this looks good! Just minor comments. Thanks for this work.

Review threads on tests/test_wrappers.py and scikeras/wrappers.py were resolved.
stsievert and others added 3 commits August 12, 2020 21:37
Co-authored-by: Adrian Garcia Badaracco <[email protected]>
Co-authored-by: Adrian Garcia Badaracco <[email protected]>
@adriangb adriangb merged commit 072b3d1 into adriangb:master Aug 13, 2020
This was referenced Aug 24, 2020
Successfully merging this pull request may close these issues.

Default parameters of build_fn are not returned by get_params
3 participants