-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[python-package] Correctly recognize LGBMClassifier(num_class=2, objective="multiclass") as multiclass classification #6524
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, could you please add a unit test covering this behavior, to be sure it isn't broken by future refactorings?
@microsoft-github-policy-service agree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks very much.
#6519 is reporting the same root cause as #3636, and the solution you've arrived at here looks very similar to what @jmoralez recommended there (#3636 (comment)). There he recommended using result.ndim
instead of len(result.shape)
... I don't have a strong opinion on one of those being preferable to another. I think the len(result.shape)
approach you've taken is fine.
Re-reading #3636 though, I think there is one other place where self._n_classes > 2
is used to decide whether binary of multiclass classification is being performed, and which needs to be changed.
@rabyj mentioned something related to metrics at #3636 (comment), which I think is about this behavior where LGBMClassifier
rewrites binary classification metrics to multiclass metrics and vice-versa:
LightGBM/python-package/lightgbm/sklearn.py
Line 1254 in a5054f7
if self._n_classes > 2: |
This is making me think that maybe LGBMClassifier
should stop relying on self._n_classes > 2
as a proxy for "is doing multiclass classification".
What do you think about adding a private property on LGBMClassifier
that has something like this?
@property
def __is_multiclass(self) -> bool:
multiclass_objectives = {"multiclass", "softmax", "multiclassova", "multiclass_ova", "ova", "ovr"}
return (
self._n_classes > 2
or isinstance(self._objective, str) and self._objective in multiclass_objectives
)
And then using that in place of these self._n_classes > 2
conditions, where appropriate, e.g.
if self.__is_multiclass:
# do multiclass classification things
else:
# do binary classification things
If you agree, then please also add a case within this test checking that the behavior is correct when using num_class=2, objective="multiclass"
:
def test_metrics(): |
Sorry this is a lot... there is a lot of indirection in the scikit-learn
estimator, some of it imposed by scikit-learn
itself, so seemingly simple things can become complicated fast 😅
Other references you might find relevant:
Thanks for review, @jameslamb. As you mentioned, I added |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New tests look great, thanks very much for your attention to detail there making them minimal and similar to the other existing tests! And for fixing typos in comments, much appreciated.
I just have a few more small suggestions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks very much! The changes and tests look great to me.
I left one more very very very minor suggestion.
Don't worry about the failing CI job (https://github.com/microsoft/LightGBM/actions/runs/9852647887/job/27218413942?pr=6524). I think that's the result of some recent change in the numpy
or pandas
nightlies, not anything done in this PR.
@jmoralez could you also review? I'm not 100% confident I've thought of all the implications of this change.
fixes #6519
fixes #3636
I created a pull request addressing issue #6519, which I've also encountered.
I modified the code to use
np.vstack
only whenresult.shape
look like(num_data,)
, and some typos.