Improve QN solver stopping conditions (logistic regression) to match sklearn closer (#3766)

Change the QN solver (logistic regression) stopping conditions to avoid early stops in some cases (#3645):
  - primary: 
    ```
    || f' ||_inf <= fmag * param.epsilon
    ```
  - secondary:
    ```
    |f - f_prev| <= fmag * param.delta
    ```
where `fmag = max(|f|, param.epsilon)`.

Also change the default value of `tol` in the QN solver (which sets `param.delta`) to be consistent (`1e-4`) with the logistic regression solver.
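
To make the two conditions concrete, here is a rough host-side sketch in Python/NumPy (the real check runs on the GPU in `check_convergence`, see the `qn_util.cuh` diff below; the function and argument names here are illustrative only):

```python
import numpy as np

def qn_should_stop(f, f_prev, grad, epsilon, delta):
    """Illustrative restatement of the new QN stopping conditions.

    f      : current objective value
    f_prev : objective value recorded `param.past` iterations ago (or None)
    grad   : current gradient vector
    epsilon, delta : correspond to param.epsilon and param.delta
    """
    fmag = max(abs(f), epsilon)            # fmag = max(|f|, param.epsilon)
    # Primary condition: L-inf norm of the gradient, scaled by fmag
    if np.max(np.abs(grad)) <= fmag * epsilon:
        return True
    # Secondary condition: insufficient change in the objective value
    if f_prev is not None and abs(f - f_prev) <= fmag * delta:
        return True
    return False
```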


#### Background

The original primary stopping condition is inconsistent with the sklearn reference implementation and is often triggered too early:
```
|| f' ||_2 <= param.epsilon * max(1.0, || x ||_2)
```

Here are the sklearn conditions for reference:
  - primary: 
    ```
    || grad f ||_inf <= gtol
    ```
  - secondary:
    ```
    |f - f_prev| <= ftol * max(|f|, |f_prev|, 1.0)
    ```
where `gtol` is an exposed parameter like `param.epsilon`, and `ftol = 2.2e-9` (hardcoded).
In addition, `f` in sklearn is scaled with the sample size (it sums the softmax or sigmoid loss over the dataset), so it is not directly comparable to the cuML version.

Currently, cuML scales the gradient tolerance with the L2 norm of the logistic regression weights `x`. As a result, the effective tolerance grows with the number of classes and features; the model stops too early and stays underfit. This may in part be a reason for #3645.
In this proposal I change the stopping condition to be closer to the sklearn version, but trade exact consistency with sklearn for better scaling (the tolerance scales with the absolute value of the objective function). Without this scaling, the sklearn version often seems to run until the maximum iteration limit is reached.
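
As a toy illustration of the scaling argument (not part of the PR; the weight distribution and loss value are made up), compare how the old and the new gradient thresholds behave as the number of weights grows:

```python
import numpy as np

rng = np.random.default_rng(0)
epsilon = 1e-4   # plays the role of param.epsilon
f = 0.3          # some hypothetical value of the (mean) objective

for n_weights in (10, 1_000, 100_000):
    x = rng.normal(scale=0.1, size=n_weights)              # hypothetical weights
    old_threshold = epsilon * max(1.0, np.linalg.norm(x))   # scales with ||x||_2
    new_threshold = epsilon * max(abs(f), epsilon)          # fmag * epsilon
    print(f"{n_weights:>7d}  old={old_threshold:.2e}  new={new_threshold:.2e}")

# The old threshold grows roughly with sqrt(n_weights) once ||x||_2 > 1, which is
# how the effective tolerance can inflate with the number of classes and features;
# the new threshold depends only on the magnitude of the objective value.
```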

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #3766
achirkin authored Apr 23, 2021
1 parent 0a21e87 commit 6584bbf
Showing 5 changed files with 65 additions and 35 deletions.
9 changes: 2 additions & 7 deletions cpp/src/glm/qn/qn_solvers.cuh
@@ -108,13 +108,11 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T> &param,

   // Evaluate function and compute gradient
   fx = f(x, grad, dev_scalar, stream);
-  T xnorm = nrm2(x, dev_scalar, stream);
-  T gnorm = nrm2(grad, dev_scalar, stream);

   if (param.past > 0) fx_hist[0] = fx;

   // Early exit if the initial x is already a minimizer
-  if (gnorm <= param.epsilon * std::max(xnorm, T(1.0))) {
+  if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
     CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
     return OPT_SUCCESS;
   }
@@ -255,13 +253,10 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T> &param, Function &f,
   // pseudo.assign_binary(x, grad, pseudo_grad);
   update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);

-  T xnorm = nrm2(x, dev_scalar, stream);
-  T gnorm = nrm2(pseudo, dev_scalar, stream);
-
   if (param.past > 0) fx_hist[0] = fx;

   // Early exit if the initial x is already a minimizer
-  if (gnorm <= param.epsilon * std::max(xnorm, T(1.0))) {
+  if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
     CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
     return OPT_SUCCESS;
   }
20 changes: 11 additions & 9 deletions cpp/src/glm/qn/qn_util.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -120,21 +120,23 @@ inline bool check_convergence(const LBFGSParam<T> &param, const int k,
                               const T fx, SimpleVec<T> &x, SimpleVec<T> &grad,
                               std::vector<T> &fx_hist, T *dev_scalar,
                               cudaStream_t stream) {
-  // New x norm and gradient norm
-  T xnorm = nrm2(x, dev_scalar, stream);
-  T gnorm = nrm2(grad, dev_scalar, stream);
-
-  CUML_LOG_DEBUG("%04d: f(x)=%.8f conv.crit=%.8f (gnorm=%.8f, xnorm=%.8f)", k,
-                 fx, gnorm / std::max(T(1), xnorm), gnorm, xnorm);
+  // Gradient norm is now in Linf to match the reference implementation
+  // (originally it was L2-norm)
+  T gnorm = nrmMax(grad, dev_scalar, stream);
+  // Positive scale factor for the stop condition
+  T fmag = std::max(fx, param.epsilon);
+
+  CUML_LOG_DEBUG("%04d: f(x)=%.8f conv.crit=%.8f (gnorm=%.8f, fmag=%.8f)", k,
+                 fx, gnorm / fmag, gnorm, fmag);
   // Convergence test -- gradient
-  if (gnorm <= param.epsilon * std::max(xnorm, T(1.0))) {
+  if (gnorm <= param.epsilon * fmag) {
     CUML_LOG_DEBUG("Converged after %d iterations: f(x)=%.6f", k, fx);
     return true;
   }
   // Convergence test -- objective function value
   if (param.past > 0) {
     if (k >= param.past &&
-        std::abs((fx_hist[k % param.past] - fx) / fx) < param.delta) {
+        std::abs(fx_hist[k % param.past] - fx) <= param.delta * fmag) {
       CUML_LOG_DEBUG("Insufficient change in objective value");
       return true;
     }
11 changes: 11 additions & 0 deletions cpp/src/glm/qn/simple_mat.cuh
@@ -239,6 +239,17 @@ inline T squaredNorm(const SimpleVec<T> &u, T *tmp_dev, cudaStream_t stream) {
   return dot(u, u, tmp_dev, stream);
 }
 
+template <typename T>
+inline T nrmMax(const SimpleVec<T> &u, T *tmp_dev, cudaStream_t stream) {
+  auto f = [] __device__(const T x) { return raft::myAbs<T>(x); };
+  auto r = [] __device__(const T x, const T y) { return raft::myMax<T>(x, y); };
+  raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data);
+  T tmp_host;
+  raft::update_host(&tmp_host, tmp_dev, 1, stream);
+  cudaStreamSynchronize(stream);
+  return tmp_host;
+}
+
 template <typename T>
 inline T nrm2(const SimpleVec<T> &u, T *tmp_dev, cudaStream_t stream) {
   return raft::mySqrt<T>(squaredNorm(u, tmp_dev, stream));
21 changes: 13 additions & 8 deletions python/cuml/linear_model/logistic_regression.pyx
@@ -56,10 +56,10 @@ class LogisticRegression(Base,
     algorithms. Even though it is presented as a single option, this solver
     resolves to two different algorithms underneath:
-      - Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
-        regularization
+      - Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
+        regularization
-      - Limited Memory BFGS (L-BFGS) otherwise.
+      - Limited Memory BFGS (L-BFGS) otherwise.
     Note that, just like in Scikit-learn, the bias will not be regularized.
@@ -119,12 +119,17 @@ class LogisticRegression(Base,
         If 'elasticnet' is selected, OWL-QN will be used if l1_ratio > 0,
         otherwise L-BFGS will be used.
     tol: float (default = 1e-4)
-        The training process will stop if current_loss > previous_loss - tol
+        Tolerance for stopping criteria.
+        The exact stopping conditions depend on the chosen solver.
+        Check the solver's documentation for more details:
+          * :class:`Quasi-Newton (L-BFGS/OWL-QN)<cuml.QN>`
     C: float (default = 1.0)
-        Inverse of regularization strength; must be a positive float.
+        Inverse of regularization strength; must be a positive float.
     fit_intercept: boolean (default = True)
-        If True, the model tries to correct for the global mean of y.
-        If False, the model expects that you have centered the data.
+        If True, the model tries to correct for the global mean of y.
+        If False, the model expects that you have centered the data.
-    class_weight: None
-        Custom class weighs are currently not supported.
+    class_weight: dict or 'balanced', default=None
@@ -181,7 +186,7 @@ class LogisticRegression(Base,
     coefficients and predictions of the model, similar to
     using different solvers in Scikit-learn.
-    For additional information, see Scikit-learn's LogistRegression
+    For additional information, see `Scikit-learn's LogisticRegression
     <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html>`_.
"""

39 changes: 28 additions & 11 deletions python/cuml/solvers/qn.pyx
@@ -132,10 +132,10 @@ class QN(Base,
     Two algorithms are implemented underneath cuML's QN class, and which one
     is executed depends on the following rule:
-      * Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
-        regularization
+      * Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
+        regularization
-      * Limited Memory BFGS (L-BFGS) otherwise.
+      * Limited Memory BFGS (L-BFGS) otherwise.
     cuML's QN class can take array-like objects, either in host as
     NumPy arrays or in device (as Numba or __cuda_array_interface__ compliant).
@@ -206,8 +206,25 @@ class QN(Base,
         will not be regularized.
     max_iter: int (default = 1000)
         Maximum number of iterations taken for the solvers to converge.
-    tol: float (default = 1e-3)
-        The training process will stop if current_loss > previous_loss - tol
+    tol: float (default = 1e-4)
+        The training process will stop if
+        `norm(current_loss_grad, inf) <= tol * max(current_loss, tol)`.
+        This differs slightly from the `gtol`-controlled stopping condition in
+        `scipy.optimize.minimize(method='L-BFGS-B')
+        <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html>`_:
+        `norm(current_loss_projected_grad, inf) <= gtol`.
+        Note, `sklearn.LogisticRegression()
+        <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html>`_
+        uses the sum of softmax/logistic loss over the input data, whereas cuML
+        uses the average. As a result, Scikit-learn's loss is usually
+        `sample_size` times larger than cuML's.
+        To account for the differences you may divide the `tol` by the sample
+        size; this would ensure that the cuML solver does not stop earlier than
+        the Scikit-learn solver.
     linesearch_max_iter: int (default = 50)
         Max number of linesearch iterations per outer iteration of the
         algorithm.
@@ -245,19 +262,19 @@ class QN(Base,
     ------
     This class contains implementations of two popular Quasi-Newton methods:
-      - Limited-memory Broyden Fletcher Goldfarb Shanno (L-BFGS) [Nocedal,
-        Wright - Numerical Optimization (1999)]
+      - Limited-memory Broyden Fletcher Goldfarb Shanno (L-BFGS) [Nocedal,
+        Wright - Numerical Optimization (1999)]
-      - Orthant-wise limited-memory quasi-newton (OWL-QN) [Andrew, Gao - ICML
-        2007]
-        <https://www.microsoft.com/en-us/research/publication/scalable-training-of-l1-regularized-log-linear-models/>
+      - `Orthant-wise limited-memory quasi-newton (OWL-QN)
+        [Andrew, Gao - ICML 2007]
+        <https://www.microsoft.com/en-us/research/publication/scalable-training-of-l1-regularized-log-linear-models/>`_
     """

     _coef_ = CumlArrayDescriptor()
     intercept_ = CumlArrayDescriptor()

     def __init__(self, *, loss='sigmoid', fit_intercept=True,
-                 l1_strength=0.0, l2_strength=0.0, max_iter=1000, tol=1e-3,
+                 l1_strength=0.0, l2_strength=0.0, max_iter=1000, tol=1e-4,
                  linesearch_max_iter=50, lbfgs_memory=5,
                  verbose=False, handle=None, output_type=None,
                  warm_start=False):
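
For reference, a hypothetical usage sketch of the `tol` adjustment suggested in the `qn.pyx` docstring above (the estimator and its `tol`/`max_iter` parameters are existing cuML API; the dataset and concrete values are made up):

```python
import numpy as np
import cuml
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=10_000, n_features=50, random_state=0)
X, y = X.astype(np.float32), y.astype(np.float32)

# cuML averages the loss over samples while Scikit-learn sums it, so dividing
# tol by the sample size keeps the cuML stopping criterion at least as strict.
clf = cuml.LogisticRegression(tol=1e-4 / X.shape[0], max_iter=5000)
clf.fit(X, y)
```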
