diff --git a/cpp/src/glm/qn/qn_solvers.cuh b/cpp/src/glm/qn/qn_solvers.cuh
index 23b1f7eb03..f98873bf5a 100644
--- a/cpp/src/glm/qn/qn_solvers.cuh
+++ b/cpp/src/glm/qn/qn_solvers.cuh
@@ -108,13 +108,11 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T> &param,
   // Evaluate function and compute gradient
   fx = f(x, grad, dev_scalar, stream);
-  T xnorm = nrm2(x, dev_scalar, stream);
-  T gnorm = nrm2(grad, dev_scalar, stream);
 
   if (param.past > 0) fx_hist[0] = fx;
 
   // Early exit if the initial x is already a minimizer
-  if (gnorm <= param.epsilon * std::max(xnorm, T(1.0))) {
+  if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
     CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
     return OPT_SUCCESS;
   }
 
@@ -255,13 +253,10 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T> &param, Function &f,
   // pseudo.assign_binary(x, grad, pseudo_grad);
   update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);
 
-  T xnorm = nrm2(x, dev_scalar, stream);
-  T gnorm = nrm2(pseudo, dev_scalar, stream);
-
   if (param.past > 0) fx_hist[0] = fx;
 
   // Early exit if the initial x is already a minimizer
-  if (gnorm <= param.epsilon * std::max(xnorm, T(1.0))) {
+  if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
     CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
     return OPT_SUCCESS;
   }
diff --git a/cpp/src/glm/qn/qn_util.cuh b/cpp/src/glm/qn/qn_util.cuh
index 4546ef736e..2721af417b 100644
--- a/cpp/src/glm/qn/qn_util.cuh
+++ b/cpp/src/glm/qn/qn_util.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -120,21 +120,23 @@ inline bool check_convergence(const LBFGSParam<T> &param, const int k, const T fx,
                               SimpleVec<T> &x, SimpleVec<T> &grad, std::vector<T> &fx_hist,
                               T *dev_scalar, cudaStream_t stream) {
-  // New x norm and gradient norm
-  T xnorm = nrm2(x, dev_scalar, stream);
-  T gnorm = nrm2(grad, dev_scalar, stream);
-
-  CUML_LOG_DEBUG("%04d: f(x)=%.8f conv.crit=%.8f (gnorm=%.8f, xnorm=%.8f)", k,
-                 fx, gnorm / std::max(T(1), xnorm), gnorm, xnorm);
+  // Gradient norm is now in Linf to match the reference implementation
+  // (originally it was L2-norm)
+  T gnorm = nrmMax(grad, dev_scalar, stream);
+  // Positive scale factor for the stop condition
+  T fmag = std::max(fx, param.epsilon);
+
+  CUML_LOG_DEBUG("%04d: f(x)=%.8f conv.crit=%.8f (gnorm=%.8f, fmag=%.8f)", k,
+                 fx, gnorm / fmag, gnorm, fmag);
 
   // Convergence test -- gradient
-  if (gnorm <= param.epsilon * std::max(xnorm, T(1.0))) {
+  if (gnorm <= param.epsilon * fmag) {
     CUML_LOG_DEBUG("Converged after %d iterations: f(x)=%.6f", k, fx);
     return true;
   }
   // Convergence test -- objective function value
   if (param.past > 0) {
     if (k >= param.past &&
-        std::abs((fx_hist[k % param.past] - fx) / fx) < param.delta) {
+        std::abs(fx_hist[k % param.past] - fx) <= param.delta * fmag) {
       CUML_LOG_DEBUG("Insufficient change in objective value");
       return true;
     }
diff --git a/cpp/src/glm/qn/simple_mat.cuh b/cpp/src/glm/qn/simple_mat.cuh
index 308d5039c3..225d2cf022 100644
--- a/cpp/src/glm/qn/simple_mat.cuh
+++ b/cpp/src/glm/qn/simple_mat.cuh
@@ -239,6 +239,17 @@ inline T squaredNorm(const SimpleVec<T> &u, T *tmp_dev, cudaStream_t stream) {
   return dot(u, u, tmp_dev, stream);
 }
 
+template <typename T>
+inline T nrmMax(const SimpleVec<T> &u, T *tmp_dev, cudaStream_t stream) {
+  auto f = [] __device__(const T x) { return raft::myAbs(x); };
+  auto r = [] __device__(const T x, const T y) { return raft::myMax(x, y); };
+  raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data);
+  T tmp_host;
+  raft::update_host(&tmp_host, tmp_dev, 1, stream);
+  cudaStreamSynchronize(stream);
+  return tmp_host;
+}
+
 template <typename T>
 inline T nrm2(const SimpleVec<T> &u, T *tmp_dev, cudaStream_t stream) {
   return raft::mySqrt(squaredNorm(u, tmp_dev, stream));
diff --git a/python/cuml/linear_model/logistic_regression.pyx b/python/cuml/linear_model/logistic_regression.pyx
index c439fe3d9e..41642c98ba 100644
--- a/python/cuml/linear_model/logistic_regression.pyx
+++ b/python/cuml/linear_model/logistic_regression.pyx
@@ -56,10 +56,10 @@ class LogisticRegression(Base,
     algorithms. Even though it is presented as a single option, this solver
     resolves to two different algorithms underneath:
 
-      - Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
-      regularization
+        - Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
+          regularization
 
-      - Limited Memory BFGS (L-BFGS) otherwise.
+        - Limited Memory BFGS (L-BFGS) otherwise.
 
     Note that, just like in Scikit-learn, the bias will not be regularized.
@@ -119,12 +119,17 @@ class LogisticRegression(Base,
         If 'elasticnet' is selected, OWL-QN will be used if l1_ratio > 0,
         otherwise L-BFGS will be used.
     tol: float (default = 1e-4)
-        The training process will stop if current_loss > previous_loss - tol
+        Tolerance for stopping criteria.
+        The exact stopping conditions depend on the chosen solver.
+        Check the solver's documentation for more details:
+
+          * :class:`Quasi-Newton (L-BFGS/OWL-QN) <cuml.QN>`
+
     C: float (default = 1.0)
-        Inverse of regularization strength; must be a positive float.
+        Inverse of regularization strength; must be a positive float.
     fit_intercept: boolean (default = True)
-        If True, the model tries to correct for the global mean of y.
-        If False, the model expects that you have centered the data.
+        If True, the model tries to correct for the global mean of y.
+        If False, the model expects that you have centered the data.
     class_weight: None
         Custom class weighs are currently not supported.
     class_weight: dict or 'balanced', default=None
@@ -181,7 +186,7 @@ class LogisticRegression(Base,
     coefficients and predictions of the model, similar to using different
     solvers in Scikit-learn.
 
-    For additional information, see Scikit-learn's LogistRegression
+    For additional information, see `Scikit-learn's LogisticRegression <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html>`_.
     """
diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx
index f224c9811f..64c5d2e09a 100644
--- a/python/cuml/solvers/qn.pyx
+++ b/python/cuml/solvers/qn.pyx
@@ -132,10 +132,10 @@ class QN(Base,
     Two algorithms are implemented underneath cuML's QN class, and which one
     is executed depends on the following rule:
 
-      * Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
-      regularization
+        * Orthant-Wise Limited Memory Quasi-Newton (OWL-QN) if there is l1
+          regularization
 
-      * Limited Memory BFGS (L-BFGS) otherwise.
+        * Limited Memory BFGS (L-BFGS) otherwise.
 
     cuML's QN class can take array-like objects, either in host as NumPy
     arrays or in device (as Numba or __cuda_array_interface__ compliant).
@@ -206,8 +206,25 @@ class QN(Base,
         will not be regularized.
     max_iter: int (default = 1000)
         Maximum number of iterations taken for the solvers to converge.
-    tol: float (default = 1e-3)
-        The training process will stop if current_loss > previous_loss - tol
+    tol: float (default = 1e-4)
+        The training process will stop if
+
+        `norm(current_loss_grad, inf) <= tol * max(current_loss, tol)`.
+
+        This differs slightly from the `gtol`-controlled stopping condition in
+        `scipy.optimize.minimize(method='L-BFGS-B')
+        <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html>`_:
+
+        `norm(current_loss_projected_grad, inf) <= gtol`.
+
+        Note, `sklearn.LogisticRegression()
+        <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html>`_
+        uses the sum of softmax/logistic loss over the input data, whereas cuML
+        uses the average. As a result, Scikit-learn's loss is usually
+        `sample_size` times larger than cuML's.
+        To account for the differences you may divide the `tol` by the sample
+        size; this would ensure that the cuML solver does not stop earlier than
+        the Scikit-learn solver.
     linesearch_max_iter: int (default = 50)
         Max number of linesearch iterations per outer iteration of the
         algorithm.
@@ -245,19 +262,19 @@ class QN(Base,
     ------
     This class contains implementations of two popular Quasi-Newton methods:
 
-      - Limited-memory Broyden Fletcher Goldfarb Shanno (L-BFGS) [Nocedal,
-      Wright - Numerical Optimization (1999)]
+        - Limited-memory Broyden Fletcher Goldfarb Shanno (L-BFGS) [Nocedal,
+          Wright - Numerical Optimization (1999)]
 
-      - Orthant-wise limited-memory quasi-newton (OWL-QN) [Andrew, Gao - ICML
-      2007]
-
+        - `Orthant-wise limited-memory quasi-newton (OWL-QN)
+          [Andrew, Gao - ICML 2007]
+          `_
     """
 
     _coef_ = CumlArrayDescriptor()
     intercept_ = CumlArrayDescriptor()
 
     def __init__(self, *, loss='sigmoid', fit_intercept=True,
-                 l1_strength=0.0, l2_strength=0.0, max_iter=1000, tol=1e-3,
+                 l1_strength=0.0, l2_strength=0.0, max_iter=1000, tol=1e-4,
                  linesearch_max_iter=50, lbfgs_memory=5, verbose=False,
                  handle=None, output_type=None, warm_start=False):
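For reference, the stopping rule documented in the new `tol` docstring above can be written out as a small NumPy sketch. This is not part of the cuML API; the helper name `qn_converged` and its signature are made up for illustration, but the logic mirrors `check_convergence` in `qn_util.cuh` after this change.

```python
import numpy as np


def qn_converged(fx, grad, fx_hist, k, tol=1e-4, delta=0.0, past=0):
    """Hypothetical helper illustrating the updated QN stopping rule."""
    # Positive scale factor, as in the C++ code: fmag = max(f(x), epsilon)
    fmag = max(fx, tol)
    # Gradient measured in the L-infinity norm, like nrmMax() on the device
    gnorm = np.linalg.norm(grad, ord=np.inf)
    # Gradient criterion: norm(current_loss_grad, inf) <= tol * max(current_loss, tol)
    if gnorm <= tol * fmag:
        return True
    # Objective-change criterion: compare f(x) with its value `past` iterations ago
    if past > 0 and k >= past and abs(fx_hist[k % past] - fx) <= delta * fmag:
        return True
    return False


# Example: with f(x) = 0.3 and gradient [1e-5, -2e-5], the L-inf norm is 2e-5,
# which is below tol * max(f(x), tol) = 1e-4 * 0.3 = 3e-5, so the solver stops.
print(qn_converged(0.3, np.array([1e-5, -2e-5]), [], 0))  # True
```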