Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tolerate QN linesearch failures when it's harmless #3791

Merged
140 changes: 100 additions & 40 deletions cpp/src/glm/qn/qn_solvers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,68 @@ inline size_t owlqn_workspace_size(const LBFGSParam<T>& param, const int n)
return lbfgs_workspace_size(param, n) + vec_size;
}

template <typename T>
inline bool update_and_check(const char* solver,
const LBFGSParam<T>& param,
int iter,
LINE_SEARCH_RETCODE lsret,
T& fx,
T& fxp,
ML::SimpleVec<T>& x,
ML::SimpleVec<T>& xp,
ML::SimpleVec<T>& grad,
ML::SimpleVec<T>& gradp,
std::vector<T>& fx_hist,
T* dev_scalar,
OPT_RETCODE& outcode,
cudaStream_t stream)
{
bool stop = false;
bool converged = false;
bool isLsValid = !isnan(fx) && !isinf(fx);
// Linesearch may fail to converge, but still come closer to the solution;
// if that is not the case, let `check_convergence` ("insufficient change")
// below terminate the loop.
bool isLsNonCritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
bool isLsInDoubt = isLsValid && fx <= fxp + param.ftol && isLsNonCritical;
achirkin marked this conversation as resolved.
Show resolved Hide resolved
bool isLsSuccess = lsret == LS_SUCCESS || isLsInDoubt;

// if the target is at least finite, we can check the convergence
if (isLsValid)
converged = check_convergence(param, iter, fx, x, grad, fx_hist, dev_scalar, stream);

if (!isLsSuccess && !converged) {
CUML_LOG_WARN("%s line search failed (code %d)", solver, lsret);
outcode = OPT_LS_FAILED;
stop = true;
} else if (!isLsValid) {
CUML_LOG_ERROR("%s error fx=%f at iteration %d", solver, fx, iter);
outcode = OPT_NUMERIC_ERROR;
stop = true;
} else if (converged) {
CUML_LOG_DEBUG("%s converged", solver);
outcode = OPT_SUCCESS;
stop = true;
} else if (isLsInDoubt && fx + param.ftol >= fxp) {
CUML_LOG_WARN(
"%s stopped, because the line search failed to advance (step delta = %f); "
"perhaps, the convergence criteria are too strict?..",
achirkin marked this conversation as resolved.
Show resolved Hide resolved
solver,
fx - fxp);
outcode = OPT_LS_FAILED;
stop = true;
}

// if lineseach wasn't successful, undo the update.
if (!isLsSuccess || !isLsValid) {
fx = fxp;
x.copy_async(xp, stream);
grad.copy_async(gradp, stream);
}

return stop;
}

template <typename T, typename Function>
inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
Function& f, // function to minimize
Expand Down Expand Up @@ -131,35 +193,32 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
*k = 1;
int end = 0;
int n_vec = 0; // number of vector updates made in lbfgs_search_dir
OPT_RETCODE retcode;
LINE_SEARCH_RETCODE lsret;
for (; *k <= param.max_iterations; (*k)++) {
// Save the curent x and gradient
xp.copy_async(x, stream);
gradp.copy_async(grad, stream);
fxp = fx;

// Line search to update x, fx and gradient
LINE_SEARCH_RETCODE lsret =
ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);

bool isLsSuccess = lsret == LS_SUCCESS;
CUML_LOG_TRACE("Iteration %d, fx=%f", *k, fx);
achirkin marked this conversation as resolved.
Show resolved Hide resolved

if (!isLsSuccess || isnan(fx) || isinf(fx)) {
fx = fxp;
x.copy_async(xp, stream);
grad.copy_async(gradp, stream);
if (!isLsSuccess) {
CUML_LOG_ERROR("L-BFGS line search failed");
return OPT_LS_FAILED;
}
CUML_LOG_ERROR("L-BFGS error fx=%f at iteration %d", fx, *k);
return OPT_NUMERIC_ERROR;
}

if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
CUML_LOG_DEBUG("L-BFGS converged");
return OPT_SUCCESS;
}
lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);

if (update_and_check("L-BFGS",
param,
*k,
lsret,
fx,
fxp,
x,
xp,
grad,
gradp,
fx_hist,
dev_scalar,
retcode,
stream))
return retcode;

// Update s and y
// s_{k+1} = x_{k+1} - x_k
Expand Down Expand Up @@ -282,37 +341,38 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,

int end = 0;
int n_vec = 0; // number of vector updates made in lbfgs_search_dir
OPT_RETCODE retcode;
LINE_SEARCH_RETCODE lsret;
for ((*k) = 1; (*k) <= param.max_iterations; (*k)++) {
// Save the curent x and gradient
xp.copy_async(x, stream);
gradp.copy_async(grad, stream);
fxp = fx;

// Projected line search to update x, fx and gradient
LINE_SEARCH_RETCODE lsret = ls_backtrack_projected(
lsret = ls_backtrack_projected(
param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream);

bool isLsSuccess = lsret == LS_SUCCESS;
if (!isLsSuccess || isnan(fx) || isinf(fx)) {
fx = fxp;
x.copy_async(xp, stream);
grad.copy_async(gradp, stream);
if (!isLsSuccess) {
CUML_LOG_ERROR("QWL-QN line search failed");
return OPT_LS_FAILED;
}
CUML_LOG_ERROR("OWL-QN error fx=%f at iteration %d", fx, *k);
return OPT_NUMERIC_ERROR;
}
if (update_and_check("QWL-QN",
param,
*k,
lsret,
fx,
fxp,
x,
xp,
grad,
gradp,
fx_hist,
dev_scalar,
retcode,
stream))
return retcode;

// recompute pseudo
// pseudo.assign_binary(x, grad, pseudo_grad);
update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);

if (check_convergence(param, *k, fx, x, pseudo, fx_hist, dev_scalar, stream)) {
CUML_LOG_DEBUG("OWL-QN converged");
return OPT_SUCCESS;
}

// Update s and y - We should only do this if there is no skipping condition

col_ref(S, svec, end);
Expand Down