RFC: time to simplify APIs? #37
Tagging @stsievert and @sim-san because your feedback has been very valuable lately. |
Thanks for the notification @adriangb. I have a couple thoughts. Currently, there are 4 ways to create a SciKeras model:
Why are all these methods necessary? What features does each one provide that other options don't? For example, (3) is necessary because it allows the use of an already trained model. I'm not seeing the additional features that (2) and (4) provide. If they don't provide any additional features, I think they should be removed, or not mentioned in the documentation.
👍 I would encourage those steps. I might use this interface:
    class BaseWrapper(BaseEstimator):
        def __init__(self, build_fn=None, batch_size=32, epochs=100):
            self.build_fn = build_fn
            self.batch_size = batch_size
            self.epochs = epochs
Some of these are repeated for
    def fit(self, X, y=None, batch_size=None):
        kwargs = {"batch_size": batch_size or self.batch_size, ...}
        ...
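A minimal sketch of the fallback logic in the interface above, using a stand-in class rather than the real SciKeras `BaseWrapper`. The class name `WrapperSketch` and the helper `_resolve_fit_kwargs` are hypothetical. The sketch compares against `None` instead of using `or`, so an explicit falsy value passed at call time is not silently replaced by the stored default.

```python
class WrapperSketch:
    """Hypothetical stand-in for BaseWrapper; not the SciKeras implementation."""

    def __init__(self, build_fn=None, batch_size=32, epochs=100):
        # Store every parameter unmodified, as Scikit-Learn expects.
        self.build_fn = build_fn
        self.batch_size = batch_size
        self.epochs = epochs

    def _resolve_fit_kwargs(self, batch_size=None, epochs=None):
        # A per-call value wins; otherwise fall back to the __init__ value.
        return {
            "batch_size": self.batch_size if batch_size is None else batch_size,
            "epochs": self.epochs if epochs is None else epochs,
        }


wrapper = WrapperSketch(batch_size=64)
default_kwargs = wrapper._resolve_fit_kwargs()
override_kwargs = wrapper._resolve_fit_kwargs(batch_size=16)
```

Here `default_kwargs` carries the values given to `__init__`, while `override_kwargs` replaces only `batch_size` for that one call.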
👎 I am not a fan of the subclassing interface, at least when parameters are defined in
    from sklearn.base import BaseEstimator  # sklearn.__version__ == "0.23.2"

    class Foo(BaseEstimator):
        def __init__(self, x=1):
            self.x = x
            super().__init__()

    class Bar(Foo):
        def __init__(self, y=2, **kwargs):
            self.y = y
            super().__init__(**kwargs)

    if __name__ == "__main__":
        bar = Bar(x=2)
        bar.set_params(x=3)  # raises ValueError: "Invalid parameter x for estimator Bar() ..."
I think the subclassing interface should still exist in case the user wants to overwrite another function (e.g, |
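The failure above can be illustrated without Scikit-Learn. The sketch below is a simplified stand-in for the signature introspection that `BaseEstimator` relies on, not its actual code; the helper `visible_param_names` and the classes `BarWithKwargs` and `BarExplicit` are hypothetical names. It shows that a parameter forwarded through `**kwargs` is invisible to introspection, while a parameter repeated explicitly in the subclass signature is found.

```python
import inspect


def visible_param_names(cls):
    """List the __init__ parameters that signature introspection can see.

    Simplified stand-in for the inspection BaseEstimator performs;
    this is not scikit-learn's actual code.
    """
    signature = inspect.signature(cls.__init__)
    return sorted(
        p.name
        for p in signature.parameters.values()
        if p.name != "self" and p.kind is not p.VAR_KEYWORD
    )


class Foo:
    def __init__(self, x=1):
        self.x = x


class BarWithKwargs(Foo):
    def __init__(self, y=2, **kwargs):  # x is hidden behind **kwargs
        self.y = y
        super().__init__(**kwargs)


class BarExplicit(Foo):
    def __init__(self, x=1, y=2):  # x is repeated explicitly
        self.y = y
        super().__init__(x=x)


hidden = visible_param_names(BarWithKwargs)
explicit = visible_param_names(BarExplicit)
```

With `**kwargs`, only `y` is discoverable; with the explicit signature, both `x` and `y` are. The cost is that every parent parameter has to be restated in each subclass.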
Thank you for the feedback @stsievert
(1), (2) and (4) came from the TF/Keras team (I added (3)). Unfortunately, we do not have any design docs that explain the rationale for all of these features. (4) I think it has something to do with Keras' internal use of (2) This method is what I initially used to get full compatibility with the scikit-learn estimator checks (
I think that could be easily solved by explicitly passing the parameters, but I do see how that could be tedious if you are trying to just add a single parameter. We could also add a small hack to I think the following is a good argument in favor of keeping subclassing. Please correct me if you see an alternative that I am not seeing. On the other hand, we could eliminate the
    def _keras_build_fn(self, ...):
        return build_fn(...)  # where build_fn is a function that you would have previously passed to `__init__`
But eliminating I think maybe a reasonable first step would be to proceed with the steps we do agree on and then loop back here as we come up with more comments regarding subclassing/ |
Thanks for the mention, @adriangb
I completely agree with you here.
I also like the possibility to first test ideas with the |
@sim-san how would you feel if
|
I don't understand what you mean exactly.
    estimator = MLPRegressor(hidden_layer_sizes=[200], a_kwarg="saveme")
    estimator.a_kwarg == "saveme"  # True
You mean this behavior? I never used it. |
Yes, that behavior is what would not work anymore. |
This is fine for me |
To be clear on what removing
    def build_fn(hidden_layer_sizes=(100, )):
        pass
Case 1:
|
I think the subclassing interface should stay. It might be removed from the documentation, but at the least, there should still be a note that "subclassing SciKeras classes will allow overwriting of
I think SciKeras needs to roll its own implementation of
I think that's a key use-case. I think SciKeras should support these features without subclassing:
I think this is possible, but I think it will require a custom implementation of
    estimator = KerasRegressor(
        build_fn=build_fn,
        hidden_layer_sizes=(200, ),
        batch_size=64,
    )
|
We already roll our own
    def _get_param_names(self):
        parameters = super()._get_param_names()
        parameters.extend(
            inspect.signature(self._check_build_fn).parameters.keys()
        )
        return parameters
That said, I may be missing something, but I don't see how this would be fully Scikit-Learn compliant. The tests essentially call The alternative that avoids overriding
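To make the idea in the snippet above concrete, here is a self-contained sketch that merges the explicit `__init__` parameters with the parameters found in the signature of `build_fn`. It uses only the standard library; `WrapperSketch` and `param_names` are hypothetical names and do not reflect how SciKeras or Scikit-Learn actually implement this.

```python
import inspect


def build_fn(hidden_layer_sizes=(100,)):
    """Placeholder model-building function; a real one would return a Keras model."""
    return None


class WrapperSketch:
    """Hypothetical stand-in, not the SciKeras BaseWrapper."""

    def __init__(self, build_fn=None, batch_size=32, **kwargs):
        self.build_fn = build_fn
        self.batch_size = batch_size
        # Extra keyword arguments are stored as attributes so they can be read back.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def param_names(self):
        # Start from the explicit __init__ parameters...
        init_sig = inspect.signature(type(self).__init__)
        names = [
            p.name
            for p in init_sig.parameters.values()
            if p.name != "self" and p.kind is not p.VAR_KEYWORD
        ]
        # ...then extend with the parameters of build_fn, if one was given.
        if self.build_fn is not None:
            names.extend(inspect.signature(self.build_fn).parameters.keys())
        return sorted(names)


estimator = WrapperSketch(build_fn=build_fn, hidden_layer_sizes=(200,), batch_size=64)
names = estimator.param_names()
```

Note that the sketch stores attributes taken from `**kwargs` inside `__init__`, which is the kind of behavior this thread expects to conflict with some of the Scikit-Learn estimator checks.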
I think the main tradeoff here is complexity of implementation on our side and complexity of the docs vs. complexity for users. I feel that if the docs are clear and the implementation is clear for anyone that wants to inspect it, it should not be too much to ask for more users to use the "advanced" interface than previously. |
While we continue with this discussion, I want to say again that I really appreciate all of the great feedback. Another proposal would be to do something similar to what skorch does: instead of dynamically inspecting the signature of I think this is nice in the sense that it keeps the flexibility offered by the current |
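A rough sketch of prefix-based routing, assuming the skorch-style proposal refers to skorch's double-underscore convention (for example `module__...`). The prefixes `model` and `fit` and the helper `route_params` are hypothetical choices for illustration, not an existing SciKeras API.

```python
def route_params(params, prefix):
    """Return the entries of params whose key starts with 'prefix__', prefix stripped."""
    marker = prefix + "__"
    return {
        key[len(marker):]: value
        for key, value in params.items()
        if key.startswith(marker)
    }


params = {
    "batch_size": 64,
    "model__hidden_layer_sizes": (200,),
    "fit__epochs": 10,
}
model_kwargs = route_params(params, "model")
fit_kwargs = route_params(params, "fit")
```

With routing like this, the destination of each parameter is stated in its name, so no function signature needs to be inspected at runtime.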
I see a couple issues around API simplification:
The check marks represent the state of SciKeras the last time I used it. I think all of these points should be implemented. On the 2nd point: it's a fairly simple implementation, and then the user doesn't have to look at the documentation when their hyperparameter search fails. When that happened for me, I had to look at the source code before filing #18. I'm not seeing why it's necessary to write custom code via subclassing to obtain a Scikit-Learn compatible Keras model. I think it should be used if needed to overwrite
I think that'd be the simplest solution. Note: Two other points that came up while writing this comment:
|
Going from the bottom up on this one:
Would you be willing to continue work on #18 and get a fully working example that passes all of the tests without subclassing, and I will start tackling (1) cleaning up the pre/post process functions, (2) implementing the parameter routing and (3) hardcoding the |
Does "passes all tests" mean "all Scikit-Learn estimator checks pass except check_no_attributes_set_in_init"? I don't see how that test can pass: I think parameters need to be set in |
Yes, I didn't remember exactly which ones would not pass, but Thank you for helping! |
Background
Currently, this package has many ways to:
build_fn/_keras_build_fn
.fit or predict.
This comes from a combination of supporting the original tf.keras.wrappers.scikit_learn interface along with the introduction of new ways to do things to improve Scikit-Learn compatibility.
Important principles
Since this aims to be a small pure Python package, I think it is important to keep in mind some of Python's guiding principles (cherry picked from PEP20):
For the most part, the APIs mentioned above rely on dynamic parsing of function signatures and filtering of kwargs or attributes by name. This is somewhat "complex" and "implicit" in my opinion.
Next steps
I feel (and here is where I would appreciate some feedback from others) that it would be good to fully document the requirements that these wrappers have and then narrow down the API to be as simple as possible while still meeting all of those requirements. Off the top of my head:
fit and predict methods.
fit, or can we require that they be set from __init__? As far as I can tell, only the latter makes sense as far as Scikit-Learn is concerned.
Based on the above reqs (which admittedly could be shortsighted), it seems reasonable to me to take the following action:
**kwargs from fit and predict and require that these parameters be set via __init__.
**kwargs from __init__ and instead hardcode Keras Model fit and predict parameters (maybe also compile parameters?) as proposed in [feature request] add keras.fit parameters as attributes to BaseWrapper #30.
A more extreme step would be to remove the build_fn argument and force users to always use the subclassing interface since it is technically even possible to return a pre-built model from _keras_build_fn via a closure or other methods. This would greatly simplify the API but I worry that it would be an inconvenience for users (even if it is just a couple more lines of code).
All in all I hope to reduce codebase complexity and simplify documentation. Any comments are welcome.