[BUG] Logistic regression coefficients (for feature importance) significantly differ from Scikit-learn #3645
cc @avantikalal
cuML's QN solver stops much earlier than sklearn for this test case. A few issues with the QN stopping condition were spotted by @achirkin while updating the QN solver default parameters. He will fix the issues with the stopping condition and have a look at how they affect this issue.
Change the starting coefficients of the QN model from `ones` to `zeros` for a few reasons:

- This behavior better matches the sklearn reference implementation
- It makes the initial model state predict all classes with the same probabilities (for both sigmoid and softmax losses)
- It makes the model converge faster in some cases

In addition, it enables the `warm_start` feature (same as in sklearn).

Contributes to solving #3645

Authors:
- Artem M. Chirkin (https://github.com/achirkin)

Approvers:
- Tamas Bela Feher (https://github.com/tfeher)
- Dante Gama Dessavre (https://github.com/dantegd)

URL: #3774
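As an illustration of the second point, here is a minimal numpy sketch (not the cuML code itself) showing that zero-initialized coefficients yield uniform class probabilities under both losses:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

n_samples, n_features, n_classes = 4, 3, 5
X = np.random.randn(n_samples, n_features)

# Binary case: zero weights -> all logits are zero -> p = 0.5 for every sample.
w = np.zeros(n_features)
print(sigmoid(X @ w))       # [0.5 0.5 0.5 0.5]

# Multinomial case: zero weight matrix -> identical logits -> uniform 1/n_classes.
W = np.zeros((n_features, n_classes))
print(softmax(X @ W))       # every row is [0.2 0.2 0.2 0.2 0.2]
```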
…sklearn closer (#3766)

Change the QN solver (logistic regression) stopping conditions to avoid early stops in some cases (#3645):

- primary:
  ```
  || f' ||_inf <= fmag * param.epsilon
  ```
- secondary:
  ```
  |f - f_prev| <= fmag * param.delta
  ```

where `fmag = max(|f|, param.epsilon)`.

Also change the default value of `tol` in the QN solver (which sets `param.delta`) to be consistent (`1e-4`) with the logistic regression solver.

#### Background

The original primary stopping condition is inconsistent with the sklearn reference implementation and is often triggered too early:

```
|| f' ||_2 <= param.epsilon * max(1.0, || x ||_2)
```

Here are the sklearn conditions for reference:

- primary:
  ```
  || grad f ||_inf <= gtol
  ```
- secondary:
  ```
  |f - f_prev| <= ftol * max(|f|, |f_prev|, 1.0)
  ```

where `gtol` is an exposed parameter like `param.epsilon`, and `ftol = 2.2e-9` (hardcoded). In addition, `f` in sklearn is scaled with the sample size (softmax or sigmoid over the dataset), so it's not exactly comparable to the cuML version.

Currently, cuML checks the gradient w.r.t. the logistic regression weights `x`. As a result, the tolerance value goes up with the number of classes and features; the model stops too early and stays underfit. This may in part be a reason for #3645.

In this proposal I change the stopping condition to be closer to the sklearn version, but compromise consistency with sklearn for better scaling (the tolerance scales with the absolute value of the objective function). Without this scaling, the sklearn version often seems to run until the maximum iteration limit is reached.

Authors:
- Artem M. Chirkin (https://github.com/achirkin)

Approvers:
- Corey J. Nolet (https://github.com/cjnolet)
- Tamas Bela Feher (https://github.com/tfeher)

URL: #3766
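A small numpy sketch of the new stopping checks described above; the function and variable names are illustrative, not cuML's actual internals:

```python
import numpy as np

def qn_should_stop(grad, f, f_prev, epsilon=1e-5, delta=1e-4):
    """Illustrative version of the new QN stopping conditions (not cuML's code).

    grad   : current gradient of the objective w.r.t. the weights
    f      : current objective value
    f_prev : objective value from the previous iteration
    """
    fmag = max(abs(f), epsilon)
    # primary: infinity norm of the gradient, scaled by the objective magnitude
    if np.max(np.abs(grad)) <= fmag * epsilon:
        return True
    # secondary: change of the objective value, scaled by the objective magnitude
    if abs(f - f_prev) <= fmag * delta:
        return True
    return False
```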
- Expose a parameter `delta` of the `QN` solver to control the loss-value-change stopping condition
- Set a reasonable default for the parameter value that should keep the behavior close to sklearn in most cases

Note that this change does not expose `delta` to the wrapper class `LogisticRegression`.

Note also that, although this change does not break the Python API, it does break the C/C++ API.

Contributes to solving #3645

Authors:
- Artem M. Chirkin (https://github.com/achirkin)

Approvers:
- Tamas Bela Feher (https://github.com/tfeher)
- Dante Gama Dessavre (https://github.com/dantegd)

URL: #3777
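Assuming the Python `cuml.solvers.QN` class exposes the new keyword as described, usage might look roughly like the sketch below; the other constructor arguments reflect my reading of the QN solver API and should be treated as assumptions rather than a verified signature:

```python
# Hypothetical usage sketch, not verified against a specific cuML release.
import numpy as np
from cuml.solvers import QN

X = np.random.randn(1000, 20).astype(np.float32)
y = (np.random.rand(1000) > 0.5).astype(np.float32)

qn = QN(
    loss="sigmoid",   # binary logistic loss
    max_iter=1000,
    tol=1e-4,         # existing tolerance parameter
    delta=1e-5,       # newly exposed: controls the loss-value-change stopping condition
)
qn.fit(X, y)
print(qn.predict(X)[:10])
```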
As we anticipated, the issue was in the stopping conditions, which forced the cuML solver to stop much earlier than sklearn. With PRs #3766, #3774, and #3777, cuML is now much more careful about when to stop the optimization process. Yet, if you use the default parameters for both variants, some difference still remains:
However, this discrepancy is now due to sklearn's default iteration limit (`max_iter=100`).
Here is the convergence plot after the fixes (1000 iterations max). Mind the constant difference between the loss functions; it suggests the results must be very close. With this I believe we can close the issue now :)
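For anyone re-running the comparison with matched settings, here is a minimal sketch on placeholder data; the plot above was produced on the issue's own dataset, and the parameters here are illustrative only:

```python
import numpy as np
import cuml
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression as skLogisticRegression
from cuml.linear_model import LogisticRegression as cuLogisticRegression

cuml.set_global_output_type("numpy")  # return fitted attributes as numpy arrays

# Placeholder data standing in for the issue's dataset.
X, y = make_classification(n_samples=5000, n_features=50, n_informative=30,
                           n_classes=4, random_state=0)
X, y = X.astype(np.float32), y.astype(np.float32)

# Raise the iteration limit so neither solver stops on max_iter,
# and use comparable tolerances.
sk = skLogisticRegression(penalty="l2", C=1.0, max_iter=1000, tol=1e-4).fit(X, y)
cu = cuLogisticRegression(penalty="l2", C=1.0, max_iter=1000, tol=1e-4).fit(X, y)

# The coefficient layout has differed between cuML versions; align it with
# sklearn's (n_classes, n_features) shape before comparing.
cu_coef = cu.coef_ if cu.coef_.shape == sk.coef_.shape else cu.coef_.T

for k in range(sk.coef_.shape[0]):
    r = np.corrcoef(sk.coef_[k], cu_coef[k])[0, 1]
    print(f"class {k}: coefficient correlation = {r:.4f}")
```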
Linking in the initial issue from our rapids-single-cell-examples repository: NVIDIA-Genomics-Research/rapids-single-cell-examples#29

TL;DR: there has been an ongoing discrepancy in the resulting gene rankings between Scikit-learn and cuML. We originally thought a bug fix in cuML's regularization would solve this issue, but it doesn't appear to have done so. This use case specifically relies on regularization for feature ranking.
Here's a small reproducer:
This code is executed on highly variable genes, so there shouldn't be much correlation at all between features. I've attached the datasets to run the example. The input has also been centered and normalized into z-scores.
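A minimal sketch of this kind of comparison (not the author's exact script); the placeholder random matrix stands in for the attached z-scored (cells × genes) data and cluster labels:

```python
import numpy as np
import cuml
from sklearn.linear_model import LogisticRegression as skLogisticRegression
from cuml.linear_model import LogisticRegression as cuLogisticRegression

cuml.set_global_output_type("numpy")  # return fitted attributes as numpy arrays

# Placeholder stand-ins for the attached dataset: a z-scored (cells x genes)
# matrix X and integer cluster labels y.
rng = np.random.default_rng(0)
X = rng.standard_normal((2000, 500)).astype(np.float32)
y = rng.integers(0, 8, size=2000).astype(np.float32)

sk = skLogisticRegression(penalty="l2").fit(X, y)
cu = cuLogisticRegression(penalty="l2").fit(X, y)

# The coefficient layout has differed between cuML versions; align with sklearn's.
cu_coef = cu.coef_ if cu.coef_.shape == sk.coef_.shape else cu.coef_.T

# Rank genes per class by absolute coefficient and report where the rankings disagree.
top_k = 50
for k in range(sk.coef_.shape[0]):
    sk_top = set(np.argsort(np.abs(sk.coef_[k]))[::-1][:top_k])
    cu_top = set(np.argsort(np.abs(cu_coef[k]))[::-1][:top_k])
    print(f"class {k}: genes in the top {top_k} of only one ranking:",
          sorted(sk_top ^ cu_top))
```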
Here's the output for `penalty='l2'`. The last line contains the elements that differed between the two rankings. I've also tried setting the regularization weight to a few different values without much benefit to the corresponding rankings.

Here are the correlations between the coefficients for each class:
Output:
Here are the data files to run the MRE: https://drive.google.com/file/d/1SU5EVP8Om0Q7ZBo9ifkk4AYlIvWNTO32/view?usp=sharing