DOC: add transition guide & ease fit/predict kwarg compatibility #138
Codecov Report
@@           Coverage Diff           @@
##           master     #138   +/-   ##
=======================================
  Coverage   99.51%   99.51%
=======================================
  Files           5        5
  Lines         614      623    +9
=======================================
+ Hits          611      620    +9
  Misses          3        3
scikeras/wrappers.py
Outdated
for k, v in kwargs.items():
    warnings.warn(
        "``kwargs`` will be removed in a future release of SciKeras."
        f"Instead, set fit arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
    )
self.set_params(
    **{
        (k if k.startswith("fit__") else "fit__" + k): v
        for k, v in kwargs.items()
    }
)
@stsievert what are your thoughts on these implementations? The main con I see is that subsequent calls will re-use the previous parameters. That is:
est.fit(..., batch_size=32)
...
est.fit(...) # implicitly uses batch_size=32 with no warnings to the user
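The concern above can be reproduced with a toy class. This is only a sketch: `ToyWrapper` is a hypothetical stand-in, not the SciKeras wrapper, and it skips the `fit__` prefixing for brevity. It assumes `fit` stores its kwargs through `set_params`, as in the diff under review.

```python
import warnings


class ToyWrapper:
    """Hypothetical stand-in for an estimator; not the SciKeras class."""

    def __init__(self, batch_size=None):
        self.batch_size = batch_size

    def set_params(self, **params):
        # Mirrors the scikit-learn convention of storing params as attributes.
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y, **kwargs):
        if kwargs:
            warnings.warn("kwargs are deprecated; set them at initialization.")
            self.set_params(**kwargs)  # the value outlives this call
        return self.batch_size  # the batch size this call would train with


est = ToyWrapper()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    first = est.fit([[0.0]], [0], batch_size=32)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    second = est.fit([[0.0]], [0])  # no kwargs passed

print(first, second, len(caught))
```

The second call still reports a batch size of 32 and records no warning, which is the silent reuse described in the comment.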
👎 on this implementation. As a user, I would expect the params to behave the same unless I change them. Let the user do that, and issue an appropriate warning (as is being done).
PEP 20: "Explicit is better than implicit."
Which part of the implementation are you 👎 on? This is basically what you proposed, plus the prefixing of `fit__` and `predict__`. Adding this prefix makes it behave more like the old implementation, in the sense that passing a kwarg to `fit` won't "set" it for `predict` and vice versa.
Overall, I think there are 3 options:
1. We don't support `**kwargs` at all and force users to make the change.
2. This implementation.
3. This implementation + we "unset"/"revert" the kwargs at the end. I've tested this to work, but it is extra code and complexity. That said, this is the only one that is fully backwards compatible and transparent to users.
Edit: I implemented (3) above, but the diff needs some cleanup. I think we should pick between (1) and (3).
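For reference, option (3) could look roughly like the sketch below. `ToyWrapper` and its `_fit` method are hypothetical and only illustrate the set-then-revert idea; the actual diff in this PR may differ.

```python
class ToyWrapper:
    """Hypothetical estimator used only to illustrate set-then-revert."""

    def __init__(self, batch_size=None, epochs=1):
        self.batch_size = batch_size
        self.epochs = epochs

    def get_params(self):
        return {"batch_size": self.batch_size, "epochs": self.epochs}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def _fit(self, X, y):
        # Stand-in for training: report the settings that would be used.
        return {"batch_size": self.batch_size, "epochs": self.epochs}

    def fit(self, X, y, **kwargs):
        # Remember the current values of every parameter about to be overridden.
        previous = {k: v for k, v in self.get_params().items() if k in kwargs}
        self.set_params(**kwargs)
        try:
            return self._fit(X, y)
        finally:
            # Revert even if training raises, so later calls see the originals.
            self.set_params(**previous)


est = ToyWrapper(batch_size=None, epochs=1)
during = est.fit([[0.0]], [0], batch_size=32)
after = est.fit([[0.0]], [0])
print(during, after)
```

One limitation of this sketch: a kwarg that is not already a known parameter is never reverted, because it has no previous value to restore.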
`wrappers.py` diff cleaned up. (3) is looking pretty good now.
scikeras/wrappers.py
Outdated
        (k if k.startswith("fit__") else "fit__" + k): v for k, v in kwargs.items()
    }
    existing_kwargs = {k: v for k, v in self.get_params().items() if k in kwargs}
    self.set_params(**kwargs)
👎
I'm still not seeing the need to set parameters. It looks like `set_params` is a vehicle to get these `kwargs` to `self._fit` or `self._fit_keras_model`. Why can't `kwargs` be passed to those functions and this processing be done there?
You were the one who originally suggested this approach in #124 (comment).
We could pass `**kwargs` around. Perhaps at this level of complexity/LOC that is a better approach. I'll try implementing it that way instead and see how it turns out.
I was able to implement this by passing `**kwargs` around directly. It is indeed simpler and cleaner.
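A minimal sketch of the pass-through approach, assuming call-time kwargs should override values set at initialization without changing the estimator's state. `ToyWrapper`, `fit_defaults`, and `_fit` are hypothetical names, not the SciKeras implementation.

```python
import warnings


class ToyWrapper:
    """Hypothetical estimator that forwards kwargs instead of storing them."""

    def __init__(self, **fit_defaults):
        self.fit_defaults = dict(fit_defaults)

    def fit(self, X, y, **kwargs):
        if kwargs:
            warnings.warn(
                "kwargs are deprecated; set fit arguments at initialization.",
                DeprecationWarning,
            )
        return self._fit(X, y, **kwargs)

    def _fit(self, X, y, **kwargs):
        # Merge into a new dict so the stored defaults are never mutated;
        # call-time values win over initialization values (an assumption).
        fit_args = {**self.fit_defaults, **kwargs}
        return fit_args  # stand-in for calling the Keras model's fit


est = ToyWrapper(batch_size=16, epochs=2)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    overridden = est.fit([[0.0]], [0], batch_size=32)
default = est.fit([[0.0]], [0])
print(overridden, default)
```

Because nothing is written back to the estimator, the second call falls back to the initialization values and no revert step is needed.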
…cikeras into docs/transition-guide
Co-authored-by: Scott Sievert <[email protected]>
@stsievert are we ready to merge this? In light of dask/dask-ml#764 (comment) I'd like to cut a patch release, and I think it would be good if this is in it so that it can be a decent long-term release that Keras users can switch over to.
README.md
Outdated
pass your loss function to the constructor:

```python
clf = KerasClassifier(loss="categorical_crossentropy")  # or loss=CategoricalCrossentropy(), etc.
```
Suggested change, replacing:
clf = KerasClassifier(loss="categorical_crossentropy")  # or loss=CategoricalCrossentropy(), etc.
with:
clf = KerasClassifier(loss="categorical_crossentropy")
README.md
Outdated
Or to declare separate values for `fit` and `predict`:

```python
clf = KerasClassifier(fit__batch_size=32, predict__batch_size=10000)
```
Tangential to this PR, I think these would be good defaults.
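To make the `fit__` / `predict__` split in the README excerpt concrete, here is a rough sketch of prefix-based routing. `route_params` is a hypothetical helper, and the rule that a prefixed value overrides an unprefixed one is an assumption for illustration, not a statement of how SciKeras resolves parameters.

```python
def route_params(params, destination):
    """Collect the arguments meant for one destination, e.g. "fit" or "predict".

    Hypothetical helper: unprefixed names apply everywhere, while names such as
    ``fit__batch_size`` apply only to ``fit`` and take precedence (assumed).
    """
    prefix = destination + "__"
    # Start from the unprefixed parameters.
    routed = {name: value for name, value in params.items() if "__" not in name}
    # Let destination-specific values override them, with the prefix stripped.
    for name, value in params.items():
        if name.startswith(prefix):
            routed[name[len(prefix):]] = value
    return routed


params = {"batch_size": 64, "fit__batch_size": 32, "predict__batch_size": 10000}
fit_args = route_params(params, "fit")
predict_args = route_params(params, "predict")
shared_only = route_params({"batch_size": 64}, "fit")
print(fit_args, predict_args, shared_only)
```

With these inputs, `fit` would receive a batch size of 32 and `predict` a batch size of 10000, matching the README example.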
README.md
Outdated
```diff
- def get_model(my_param=123):
+ def get_model(my_param):  # You can optionally remove the default here
      ...
```
Suggested change, replacing:
      ...
with:
      ...
      return model
README.md
Outdated
+ def get_model(my_param):  # You can optionally remove the default here
      ...
- clf = KerasClassifier(get_model)
+ clf = KerasClassifier(get_model, my_param=123)
Suggested change, replacing:
+ clf = KerasClassifier(get_model, my_param=123)
with:
+ clf = KerasClassifier(get_model, my_param=123)  # option 1
+ clf = KerasClassifier(get_model, model__my_param=123)  # option 2
docs/source/index.rst
Outdated
    clf = KerasClassifier(fit__batch_size=32, predict__batch_size=10000)

Renaming of ``build_fn`` to ``model``
Are these warnings in `index.rst` necessary? I'm having a hard time seeing the reason to feature the following on the front page of the documentation:
- Removal of keyword args
- Renaming `build_fn` to `model`
- Model introspection
All of those issue warnings with tips on how to resolve them. Why are they on the front page of the documentation? When I write docs, the front page always describes why a user should use my software, not issues they might face.
I might dedicate a separate page to potential hotspots when transitioning from TF to SciKeras.
Yeah a separate page should work and sounds good. Thanks.
@@ -463,6 +464,8 @@ def _fit_keras_model(
            Number of epochs for which the model will be trained.
        initial_epoch : int
            Epoch at which to begin training.
        **kwargs : Dict[str, Any]
            Extra arguments to route to ``Model.fit``.
What arguments can these be? I would include a link to the Keras API docs here. https://keras.io/api/models/model_training_apis/
The current docs are at https://www.tensorflow.org/api_docs/python/tf/keras/Model?version=nightly#fit but yeah, I'll include that.
for k, v in kwargs.items():
    warnings.warn(
        "``**kwargs`` has been deprecated in SciKeras 0.2.1 and support will be removed be 1.0.0."
        f" Instead, set fit arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
Suggested change, replacing:
        f" Instead, set fit arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
with:
        f" Instead, set fit arguments at initialization (e.g, ``BaseWrapper({k}={v})`` "
        f"or ``BaseWrapper(fit__{k}={v})``)"
for k, v in kwargs.items():
    warnings.warn(
        "``**kwargs`` has been deprecated in SciKeras 0.2.1 and support will be removed be 1.0.0."
        f" Instead, set predict arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
Suggested change, replacing:
        f" Instead, set predict arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
with:
        f" Instead, set predict arguments at initialization (i.e., ``BaseWrapper({k}={v})`` "
        f"or ``BaseWrapper(predict__{k}={v})``)"
for k, v in kwargs.items():
    warnings.warn(
        "``**kwargs`` has been deprecated in SciKeras 0.2.1 and support will be removed be 1.0.0."
        f" Instead, set predict_proba arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
Suggested change, replacing:
        f" Instead, set predict_proba arguments at initialization (i.e., ``BaseWrapper({k}={v})``)"
with:
        f" Instead, set predict_proba arguments at initialization (i.e, ``BaseWrapper({k}={v})`` "
        f"or ``BaseWrapper(predict__{k}={v})``"
Does that behavior work?
No, for the purposes of parameter routing `predict_proba` uses `predict__` since they both call `predict` in Keras. We could alias them if you think that might make things easier. It might also be good to have an error or at least a warning for unknown prefixes. All of this would have to be a separate PR.
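The unknown-prefix warning floated above might look something like this sketch. `KNOWN_PREFIXES` and `check_prefixes` are hypothetical names, and the prefix list is illustrative only; it is not taken from SciKeras.

```python
import warnings

# Hypothetical list of routing prefixes; the real set may differ.
KNOWN_PREFIXES = {"fit", "predict", "model"}


def check_prefixes(params):
    """Warn about parameters whose routing prefix is not recognized."""
    unknown = []
    for name in params:
        prefix, sep, _ = name.partition("__")
        # Only names containing "__" are routed; plain names are skipped.
        if sep and prefix not in KNOWN_PREFIXES:
            unknown.append(name)
    if unknown:
        warnings.warn(f"Unknown routing prefix in parameters: {unknown}")
    return unknown


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    flagged = check_prefixes(
        {"predict__batch_size": 1, "predict_proba__batch_size": 1, "epochs": 2}
    )
print(flagged, len(caught))
```

Under this sketch, `predict_proba__batch_size` is flagged because `predict_proba` is not a registered prefix, which is the case the question above was probing.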
I've edited the suggested change to use the `predict__` prefix.
@stsievert I incorporated your suggestions manually instead of committing them because I moved all of this content to a
Closes #124