
REF/ENH: add parameter routing #67

Merged: adriangb merged 30 commits into master from param-routing on Sep 15, 2020

Conversation

adriangb (Owner) commented Aug 25, 2020

Closes #50, closes #49, closes #37

Todo:

codecov-commenter commented Aug 25, 2020

Codecov Report

Merging #67 into master will decrease coverage by 0.19%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master      #67      +/-   ##
==========================================
- Coverage   99.76%   99.56%   -0.20%     
==========================================
  Files           3        3              
  Lines         432      465      +33     
==========================================
+ Hits          431      463      +32     
- Misses          1        2       +1     
Impacted Files         Coverage Δ
scikeras/_utils.py     98.91% <100.00%> (+0.21%) ⬆️
scikeras/wrappers.py   99.72% <100.00%> (-0.28%) ⬇️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

adriangb changed the title from "initial attempt at param routing" to "REF/ENH: add parameter routing" Aug 25, 2020
adriangb mentioned this pull request Aug 25, 2020
(Resolved review thread on tests/test_api.py)
Comment on lines 68 to 107
_fit_params = {
    # parameters destined for keras.Model.fit
    "callbacks",
    "batch_size",
    "epochs",
    "verbose",
    "validation_split",
    "shuffle",
    "class_weight",
    "sample_weight",
    "initial_epoch",
    "validation_steps",
    "validation_batch_size",
    "validation_freq",
}

_predict_params = {
    # parameters destined for keras.Model.predict
    "batch_size",
    "verbose",
    "callbacks",
    "steps",
}

_compile_params = {
    # parameters destined for keras.Model.compile
    "optimizer",
    "loss",
    "metrics",
    "loss_weights",
    "weighted_metrics",
    "run_eagerly",
}

_wrapper_params = {
    # parameters consumed by the wrappers themselves
    "warm_start",
    "random_state",
}
adriangb (Owner, Author):

There's some testing to make sure that these are a subset of keras.Model's parameters. In the future, once we've figured out the usage of class_weight and others, we should add tests that also check that the default initializer for the wrappers also accepts all of these (except for sample_weight and such).
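
A minimal sketch of such a subset check (hypothetical test code, not the PR's; only an excerpt of the set above is inlined to keep it self-contained):

import inspect

from tensorflow.keras import Model

# An excerpt of the _fit_params set quoted above.
_fit_params = {"batch_size", "epochs", "verbose", "callbacks", "shuffle"}

def test_fit_params_are_keras_fit_kwargs():
    # Every routed name must be an actual keyword of keras.Model.fit.
    accepted = set(inspect.signature(Model.fit).parameters)
    assert _fit_params <= accepted, _fit_params - accepted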

adriangb marked this pull request as ready for review Aug 25, 2020
stsievert (Collaborator) commented:

I think this PR is ready for review now. I'll try to provide a review in the next couple days. I would say this PR also closes #18 again.

stsievert (Collaborator) left a comment:

Generally, this PR looks good. I'll be glad to see it merged: I think it'll enable easier usage.

Here's a first pass at some comments. I am having a hard time reviewing this PR because of the length. Could you add some examples of the usage, either in the documentation or in a comment? For me, writing documentation typically leads to API improvements.

I have a couple questions:

  • What if I pass optimizer__momentum=0.9 and optimizer="sgd"? Right now, it looks like I'll have to create the Keras optimizer myself and pass that to model.compile (a sketch of that workaround follows this list). If so, I think work for a future PR is creating the optimizer in BaseWrapper and passing that to the optimizer key.
  • Is this PR backwards compatible for basic usage? It doesn't look like the tests changed much for the basic uses.
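
A sketch of the workaround described in the first bullet (hypothetical build function; it assumes this PR's routing of meta and constructs the optimizer by hand):

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD

def get_model(meta=None):
    # Minimal architecture built from the routed metadata.
    inp = Input(shape=(meta["n_features_in_"],))
    out = Dense(1)(inp)
    model = Model(inp, out)
    # Workaround: instantiate the optimizer yourself instead of relying on
    # optimizer__momentum routing (future work per the bullet above).
    model.compile(optimizer=SGD(momentum=0.9), loss="mse")
    return model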

(Resolved review threads on scikeras/wrappers.py, tests/test_api.py, scikeras/_utils.py, and tests/test_param_routing.py)
adriangb (Owner, Author) commented Aug 26, 2020

Thank you for the review!

> Could you add some examples of the usage, either in the documentation or in a comment? For me, writing documentation typically leads to API improvements.

It's not a bad idea to start working on the docs along with this. I'm just afraid that it will make a large PR even larger. Would examples from the tests help? I can cherry-pick some.

> • What if I pass optimizer__momentum=0.9 and optimizer="sgd"? Right now, it looks like I'll have to create the Keras optimizer myself and pass that to model.compile. If so, I think work for a future PR is creating the optimizer in BaseWrapper and passing that to the optimizer key.

I agree, but I think that would come after #66.

> • Is this PR backwards compatible for basic usage? It doesn't look like the tests changed much for the basic uses.

Sort of. It breaks if your build_fn was previously expecting parameters like n_outputs_, but it will be backwards compatible as long as build_fn was only expecting build parameters (like hidden_layer_sizes).
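
Illustrative only (hypothetical signatures showing where that line falls):

def build_fn_still_ok(hidden_layer_sizes=(100,)):
    ...  # only build parameters; keeps working after this PR

def build_fn_now_broken(n_outputs_=1):
    ...  # relied on the old wrappers injecting n_outputs_ automatically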

stsievert mentioned this pull request Aug 27, 2020
adriangb (Owner, Author) commented:

@stsievert I addressed several of the comments, leaving two open so you have a chance to look at those again. Do you still want a couple of examples for this API / do you want to write docs as part of this PR?

stsievert (Collaborator) commented:

> Do you still want a couple of examples for this API / do you want to write docs as part of this PR?

That'd be great. I think that'd allow me to more easily review this PR, especially if the examples are in .rst files.

adriangb (Owner, Author) commented:

What would you think of using model_builder or something other than just model to replace build_fn? I started working on docs and felt that it might be a bit confusing with model and Model.

stsievert (Collaborator) commented Aug 29, 2020

> I started working on docs and felt that it might be a bit confusing with model and Model.

That actually might be an indicator that model is a good name for the function because it returns a Keras Model. If you keep the name (I would), I would refer to the two in distinct ways, probably as "the Keras Model" and "the model parameter." I think the capitalization and monospace text is enough to distinguish the two.

Skorch has the same interface with their module parameter and PyTorch Modules.
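
For reference, a minimal sketch of skorch's version of this pattern (real skorch API; MyModule is a placeholder, and torch/skorch are assumed installed):

import torch
import torch.nn as nn
from skorch import NeuralNetRegressor

class MyModule(nn.Module):
    def __init__(self, num_units=10):
        super().__init__()
        self.hidden = nn.Linear(20, num_units)
        self.out = nn.Linear(num_units, 1)

    def forward(self, X):
        # Simple two-layer regressor; num_units arrives via module__num_units.
        return self.out(torch.relu(self.hidden(X)))

# `module` names the class; module__* arguments are routed to its __init__.
net = NeuralNetRegressor(module=MyModule, module__num_units=20)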

adriangb (Owner, Author) commented Aug 31, 2020

> only an import statement needs to be changed

Agreed, I think this should be one of the "requirements" for this project.

> If you'd like, I can mention some workarounds I have in mind.

Sure, more than welcome!

For this issue in particular, if we want to keep the original API, I would like to understand the implications of keeping only that API. I don't want two ways to do things unless there is a clear use case where one of them works and the other doesn't. The only issue I can think of with keeping only the original API (not using model__) is with subclassed models. I fear that if users want to add an arbitrary parameter, they would be forced to also accept this parameter in build_fn. If that's the case, I see two ways around it:

  • keep the introspection into build_fn's arguments
  • somehow make it so that if the user overrides __init__ then we don't route all parameters, but I don't see how to do this cleanly

stsievert (Collaborator) commented:

> some workarounds I have in mind.

The workaround I have in mind involves setting a (hidden) parameter at initialization to keep track of the parameters that should be routed to model_build_fn:

class BaseWrapper:
    def __init__(self, model=None, ..., **keras_params):
        self.model = model
        ...
        vars(self).update(**keras_params)
        self._keras_params = set(keras_params)

    def _model_params(self):
        return {k[len("model__"):]: v
                for k, v in self.get_params().items()
                if k[:len("model__")] == "model__" or k in self._keras_params}

> I fear that if users want to add an arbitrary parameter, they would be forced to also accept this parameter in build_fn.

Why would this usage fail with the above implementation?

def model_build_fn(hidden=10):
    ...
    return model

class CustomKerasRegressor(KerasRegressor):
    def __init__(self, new_param=1, **kwargs):
        self.new_param = new_param
        super().__init__(**kwargs)

    ... # use new_param in fit/score/etc

With this, I think both of these usages would work:

est = CustomKerasRegressor(model=model_build_fn, hidden=20)
est.fit(X, y).score(X, y)  # say X, y defined somewhere

est2 = CustomKerasRegressor(model=model_build_fn, model__hidden=30)
est2.fit(X, y).score(X, y)

(Review threads on scikeras/wrappers.py and scikeras/_utils.py)
adriangb (Owner, Author) commented Sep 1, 2020

> The workaround I have in mind involves setting a (hidden) parameter at initialization to keep track of the parameters that should be routed to model_build_fn

Funny enough, that's how the original implementation worked (sort of; it just stored self.kwargs = kwargs, which did not work at all with the scikit-learn API and broke get_params, etc.). I then changed it to store a set and unpack kwargs into self in my first commit.

One minor improvement: if there are no kwargs, don't set the attribute, which allows this test to pass.

class BaseWrapper:
    def __init__(self, model=None, ..., **kwargs):
        self.model = model
        ...
        vars(self).update(**kwargs)
        if kwargs:
            self._init_kwargs = set(kwargs)

    def _model_params(self):
        return {k[len("model__"):]: v
                for k, v in self.get_params().items()
                if k[:len("model__")] == "model__" or k in getattr(self, "_init_kwargs", set())}

I think with this + introspecting into the parameters of model_build_fn we should be good. Do you think it will be clear how model_build_fn can "request" arguments from the set self._init_kwargs?

Comment on lines +17 to +21
params = {"model__foo": object()}
destination = "model"
pass_filter = set()
out = route_params(params, destination, pass_filter)
assert out["foo"] is params["model__foo"]
adriangb (Owner, Author):

@stsievert added your test here. We should probably add some more checks here.
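
For instance, one more hypothetical check (it assumes route_params drops parameters addressed to other destinations, which is an assumption, not verified here):

params = {"model__foo": 1, "fit__bar": 2}
out = route_params(params, "model", set())
assert out == {"foo": 1}  # fit__bar targets a different destination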

stsievert (Collaborator) left a comment:

> I think with this

By "this" you mean "keeping backwards compatibility with the Keras API and allowing the prefix model__"?

> introspecting into the parameters of model_build_fn we should be good.

I think introspection should only be performed for three parameters: meta, params and compile_kwargs.

> Do you think it will be clear how model_build_fn can "request" arguments from the set self._init_kwargs?

I think only meta, params and compile_kwargs should be able to be "requested." I think the documentation surrounding that is very clear:

> Arguments to model_build_fn include any parameter with a model__ prefix and parameters provided at initialization. In addition, if model_build_fn accepts keyword arguments for meta, params or compile_kwargs, the relevant dictionaries will be provided, as described below: [list].

That means I think this code should raise a value error because model_build_fn does not accept an argument bar:

def build(foo=42):
    return _get_model(foo)

BaseWrapper(model=build, model__bar=42)  # fails; bar not a valid kwarg to `build`
BaseWrapper(model=build, bar=42)  # fails; bar not a valid kwarg to `build`
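
Conversely (a hypothetical counterpart under the same rule), a name that build does accept would be routed without error:

BaseWrapper(model=build, model__foo=10)  # ok; `foo` is a valid kwarg to `build`
BaseWrapper(model=build, foo=10)         # ok; the same parameter via the bare Keras-style API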

(Resolved review threads on scikeras/wrappers.py)
adriangb (Owner, Author) commented Sep 2, 2020

By "this" you mean "keeping backwards compatibility with the Keras API and allowing the prefix model__"?

Yes

Regarding the rest:

> That means I think this code should raise a value error because model_build_fn does not accept an argument bar:

def build(foo=42):
    return _get_model(foo)

BaseWrapper(model=build, model__bar=42)  # fails; bar not a valid kwarg to `build`
BaseWrapper(model=build, bar=42)  # fails; bar not a valid kwarg to `build`

I originally thought that the Keras API allowed this:

from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
def build_fn(param1=0):
    ...
clf = KerasClassifier(build_fn=build_fn, param1=1, param2=3)
clf.fit(...)

But upon testing, it actually does raise an error:

from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
def build_fn(param1=0):
    ...
clf = KerasClassifier(build_fn=build_fn, param1=1, param2=3)  # error
clf.fit(...)

Incidentally, it raises an error from __init__, which as you suggested is wrong.

This confirms what I guess you were assuming or knew, which is that raising an error is compatible with the old API. That being the case, I am 100% +1 for doing the same. Sorry for the confusion otherwise.

(commit) …ename xyz_params -> xyz_kwargs for meta, fit and predict
adriangb (Owner, Author) commented Sep 2, 2020

Thank you for this last round of review @stsievert. I added the two pending comments/items to the OP to keep things a bit organized.

stsievert (Collaborator) commented:

I'll need at least one more review; I'll try to provide one this weekend.

adriangb (Owner, Author) commented Sep 2, 2020

Great, TY for your help thus far

adriangb mentioned this pull request Sep 3, 2020
adriangb (Owner, Author) commented:

@stsievert are you able to take another look at this? Thanks.

stsievert (Collaborator) left a comment:

Thanks for the ping; this slipped off my radar.

I expected the optimizer key in compile_kwargs to be the rendered optimizer. I expected this test to pass:

from sklearn.datasets import make_classification
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
import tensorflow.keras.optimizers as opt
from scikeras.wrappers import KerasClassifier

def get_model(num_hidden=10, meta=None, compile_kwargs=None):
    inp = Input(shape=(meta["n_features_in_"],))
    hidden = Dense(num_hidden, activation="relu")(inp)
    out = [Dense(1, activation="sigmoid")(hidden)]
    model = Model(inp, out)

    assert not isinstance(compile_kwargs["optimizer"], str)
    model.compile(**compile_kwargs)
    return model

if __name__ == "__main__":
    est = KerasClassifier(
        model=get_model,
        model__num_hidden=20,
        optimizer=opt.SGD,
        optimizer__learning_rate=0.15,
        optimizer__momentum=0.5,
        loss="binary_crossentropy",
    )
    X, y = make_classification()
    est.fit(X, y)

(Review threads on README.md and scikeras/_utils.py)
adriangb (Owner, Author) commented Sep 15, 2020

> I expected the optimizer key in compile_kwargs to be the rendered optimizer. I expected this test to pass:

That will be the work of #66; it's too much for this PR in my opinion.

That said, I did edit the tests to at least use the model.compile(**compile_kwargs) syntax since it works to pass the kwargs, but it won't compile the optimizer like in your example.

stsievert (Collaborator) commented Sep 15, 2020

> That will be the work of #66; it's too much for this PR in my opinion.

👍 I don't see a reason to hold back on merging; nothing immediately jumps out. #73 and #66 probably deserve more attention now.

adriangb merged commit 1c045dc into master Sep 15, 2020
adriangb deleted the param-routing branch Sep 15, 2020
Development

Successfully merging this pull request may close these issues:

  • Parameter routing
  • REF: Deprecate callable class as build_fn
  • RFC: time to simplify APIs?