REF/ENH: add parameter routing #67
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##           master      #67      +/-   ##
==========================================
- Coverage   99.76%   99.56%   -0.20%
==========================================
  Files           3        3
  Lines         432      465      +33
==========================================
+ Hits          431      463      +32
- Misses          1        2       +1
```

Continue to review the full report at Codecov.
scikeras/wrappers.py
Outdated
```python
_fit_params = {
    # parameters destined for keras.Model.fit
    "callbacks",
    "batch_size",
    "epochs",
    "verbose",
    "validation_split",
    "shuffle",
    "class_weight",
    "sample_weight",
    "initial_epoch",
    "validation_steps",
    "validation_batch_size",
    "validation_freq",
}

_predict_params = {
    # parameters destined for keras.Model.predict
    "batch_size",
    "verbose",
    "callbacks",
    "steps",
}

_compile_params = {
    # parameters destined for keras.Model.compile
    "optimizer",
    "loss",
    "metrics",
    "loss_weights",
    "weighted_metrics",
    "run_eagerly",
}

_wrapper_params = {
    # parameters consumed by the wrappers themselves
    "warm_start",
    "random_state",
}
```
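As a hedged illustration (not necessarily how this PR wires things up), these destination sets could be used to split user-supplied keyword arguments, for example:

```python
# Illustrative only: split user kwargs by destination using the sets above.
user_kwargs = {"epochs": 10, "batch_size": 32, "random_state": 0}

fit_kwargs = {k: v for k, v in user_kwargs.items() if k in _fit_params}
wrapper_kwargs = {k: v for k, v in user_kwargs.items() if k in _wrapper_params}

assert fit_kwargs == {"epochs": 10, "batch_size": 32}
assert wrapper_kwargs == {"random_state": 0}
```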
There's some testing to make sure that these are a subset of `keras.Model`'s parameters. In the future, once we've figured out the usage of `class_weight` and others, we should add tests that also check that the default initializer for the wrappers accepts all of these (except for `sample_weight` and such).
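For illustration, here is a minimal sketch of the kind of subset test described above, assuming a recent TF 2.x; the test name and layout are assumptions, not the repository's actual test:

```python
import inspect

from tensorflow import keras


def test_fit_params_subset_of_keras_fit():
    # Every name routed to Model.fit should be a real keyword of Model.fit.
    # _fit_params as defined in the diff above.
    fit_sig = inspect.signature(keras.Model.fit)
    assert _fit_params <= set(fit_sig.parameters)
```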
I think this PR is ready for review now. I'll try to provide a review in the next couple of days. I would say this PR also closes #18 again.
Generally, this review looks good. I'll be glad to see this PR merged: I think it'll enable easier usage.

Here's a first pass at some comments. I am having a hard time reviewing this PR because of its length. Could you add some examples of the usage, either in the documentation or in a comment? For me, writing documentation typically leads to API improvements.

I have a couple of questions:

- What if I pass `optimizer__momentum=0.9` and `optimizer="sgd"`? Right now, it looks like I'll have to create the Keras optimizer myself and pass that to `model.compile` (see the sketch after this comment). If so, I think work for a future PR is creating the optimizer in `BaseWrapper` and passing that to the `optimizer` key.
- Is this PR backwards compatible for basic usage? It doesn't look like the tests changed much for the basic uses.
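For concreteness, a hedged sketch of the workaround the first question refers to: constructing the optimizer by hand inside the build function rather than routing `optimizer__*` parameters. The function and parameter names here are illustrative, not this PR's API:

```python
from tensorflow import keras


def build_fn(n_features=20, momentum=0.9):  # hypothetical build function
    model = keras.Sequential(
        [keras.layers.Dense(1, activation="sigmoid", input_shape=(n_features,))]
    )
    # The optimizer is created manually instead of being routed via optimizer__momentum.
    sgd = keras.optimizers.SGD(learning_rate=0.01, momentum=momentum)
    model.compile(optimizer=sgd, loss="binary_crossentropy")
    return model
```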
Thank you for the review!

It's not a bad idea to start working on the docs along with this. I'm just afraid that it will make a large PR even larger. Would examples from the tests help? I can cherry-pick some.

I agree, but I think that would come after #66.

Sort of. It breaks if your
@stsievert I addressed several of the comments, leaving two open so you have a chance to look at those again. Do you still want a couple of examples for this API, or do you want to write docs as part of this PR?
That'd be great. I think that'd allow me to more easily review this PR, especially if the examples are in

What would you think of using

That actually might be an indicator that Skorch has the same interface with their
Agreed, I think this should be one of the "requirements" for this project.
Sure, more than welcome! For this issue in particular, if we want to keep the original API, I would like to understand the implications of keeping only that API. I don't want two ways to do things unless there is a clear use case where one of them works and one of them doesn't. The only issue I can think of with keeping only the original API (not using
The workaround I have in mind involves setting a (hidden) parameter at initialization to keep track of the parameters that should be routed to the model-building function:

```python
class BaseWrapper:
    def __init__(self, model=None, ..., **keras_params):
        self.model = model
        ...
        vars(self).update(**keras_params)
        self._keras_params = set(keras_params)

    def _model_params(self):
        return {
            k[len("model__"):]: v
            for k, v in self.get_params().items()
            if k[:len("model__")] == "model__" or k in self._keras_params
        }
```
Why would this usage fail with the above implementation?

```python
def model_build_fn(hidden=10):
    ...
    return model


class CustomKerasRegressor(KerasRegressor):
    def __init__(self, new_param=1, **kwargs):
        self.new_param = new_param
        super().__init__(**kwargs)

    ...  # use new_param in fit/score/etc.
```

With this, I think both of these usages would work:

```python
est = CustomKerasRegressor(model=model_build_fn, hidden=20)
est.fit(X, y).score(X, y)  # say X, y defined somewhere

est2 = CustomKerasRegressor(model=model_build_fn, model__hidden=30)
est2.fit(X, y).score(X, y)
```
Funny enough, that's how the original implementation worked (sort of, it just stored

One minor improvement: if there are no kwargs, don't set the attribute, which allows this test to pass.

```python
class BaseWrapper:
    def __init__(self, model=None, ..., **kwargs):
        self.model = model
        ...
        vars(self).update(**kwargs)
        if kwargs:
            self._init_kwargs = set(kwargs)

    def _model_params(self):
        return {
            k[len("model__"):]: v
            for k, v in self.get_params().items()
            if k[:len("model__")] == "model__"
            or k in getattr(self, "_init_kwargs", set())
        }
```

I think with this + introspecting into the parameters of `model_build_fn` we should be good.
```python
params = {"model__foo": object()}
destination = "model"
pass_filter = set()
out = route_params(params, destination, pass_filter)
assert out["foo"] is params["model__foo"]
```
@stsievert added your test here. We should probably add some more checks here.
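For context, a minimal sketch of a routing helper consistent with this test; the semantics assumed here (strip the `destination__` prefix, pass through un-prefixed names listed in `pass_filter`) are an illustration, not necessarily SciKeras' actual implementation:

```python
from typing import Any, Dict, Iterable


def route_params(
    params: Dict[str, Any], destination: str, pass_filter: Iterable[str]
) -> Dict[str, Any]:
    """Select the parameters destined for `destination`."""
    prefix = destination + "__"
    routed = {}
    for key, value in params.items():
        if key.startswith(prefix):
            # "model__foo" routed to "model" becomes "foo".
            routed[key[len(prefix):]] = value
        elif key in pass_filter:
            # Un-prefixed names are passed through only if explicitly allowed.
            routed[key] = value
    return routed
```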
> I think with this

By "this" you mean "keeping backwards compatibility with the Keras API and allowing the prefix `model__`"?

> introspecting into the parameters of `model_build_fn` we should be good.

I think introspection should only be performed for three parameters: `meta`, `params` and `compile_kwargs`.

> Do you think it will be clear how `model_build_fn` can "request" arguments from the set `self._init_kwargs`?

I think only `meta`, `params` and `compile_kwargs` should be able to be "requested." I think the documentation surrounding that is very clear:

> Arguments to `model_build_fn` include any parameter with a `model__` prefix and parameters provided at initialization. In addition, if `model_build_fn` accepts keyword arguments for `meta`, `params` or `compile_kwargs`, the relevant dictionaries will be provided, as described below: [list].

That means I think this code should raise a `ValueError` because `model_build_fn` does not accept an argument `bar`:

```python
def build(foo=42):
    return _get_model(foo)


BaseWrapper(model=build, model__bar=42)  # fails; bar not a valid kwarg to `build`
BaseWrapper(model=build, bar=42)  # fails; bar not a valid kwarg to `build`
```
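To make the proposed check concrete, a hedged sketch of how such validation could be done with `inspect`; the helper name and error message are illustrative assumptions, not this PR's code:

```python
import inspect
from typing import Any, Callable, Dict


def check_build_fn_kwargs(build_fn: Callable, routed: Dict[str, Any]) -> None:
    """Raise ValueError if `routed` contains names `build_fn` cannot accept."""
    sig = inspect.signature(build_fn)
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
        return  # a **kwargs parameter accepts anything
    unknown = set(routed) - set(sig.parameters)
    if unknown:
        raise ValueError(
            f"{sorted(unknown)} are not valid keyword arguments to {build_fn.__name__}"
        )


def build(foo=42):
    ...  # would build and return a compiled Keras model


check_build_fn_kwargs(build, {"foo": 42})  # OK
check_build_fn_kwargs(build, {"bar": 42})  # raises ValueError
```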
Yes.

Regarding the rest: I originally thought that the Keras API allowed this:

```python
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier


def build_fn(param1=0):
    ...


clf = KerasClassifier(build_fn=build_fn, param1=1, param2=3)
clf.fit(...)
```

But upon testing, it actually does raise an error:

```python
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier


def build_fn(param1=0):
    ...


clf = KerasClassifier(build_fn=build_fn, param1=1, param2=3)  # error
clf.fit(...)
```

Incidentally, it raises an error from

This confirms what I guess you were assuming or knew, which is that raising an error is compatible with the old API. That being the case, I am 100% +1 for doing the same. Sorry for the confusion otherwise.
Thank you for this last round of review @stsievert. I added the two pending comments/items to the OP to keep things a bit organized.

I'll need at least one more review; I'll try to provide one this weekend.

Great, TY for your help thus far.

@stsievert are you able to take another look at this? Thanks.
Thanks for the ping; this slipped off my radar.

I expected the `optimizer` key in `compile_kwargs` to be the rendered optimizer. I expected this test to pass:
```python
from typing import Any, Dict

from sklearn.datasets import make_classification
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
import tensorflow.keras.optimizers as opt

from scikeras.wrappers import KerasClassifier


def get_model(num_hidden=10, meta=None, compile_kwargs=None):
    inp = Input(shape=(meta["n_features_in_"],))
    hidden = Dense(num_hidden, activation="relu")(inp)
    out = [Dense(1, activation="sigmoid")(hidden)]
    model = Model(inp, out)
    assert not isinstance(compile_kwargs["optimizer"], str)
    model.compile(**compile_kwargs)
    return model


if __name__ == "__main__":
    est = KerasClassifier(
        model=get_model,
        model__num_hidden=20,
        optimizer=opt.SGD,
        optimizer__learning_rate=0.15,
        optimizer__momentum=0.5,
        loss="binary_crossentropy",
    )
    X, y = make_classification()
    est.fit(X, y)
```
That will be the work of #66; too much for this PR, in my opinion. That said, I did edit the tests to at least use the
Closes #50, closes #49, closes #37

Todo:
- `keras_expected_n_ouputs_` (REF/ENH: add parameter routing #67 (comment))