[dask] Early stopping #3952
Conversation
Thanks so much for this!! I'll give it a thorough review tonight or tomorrow
Ok, I took a first look through this. I think the approach you took to knitting the parts together makes sense, and thanks for dealing with the added complexity of the fact that every worker has to have at least a little bit of evaluation data!
I still need to test this and go through it again more carefully, but I left an initial round of suggestions for your consideration.
We also do some weird stuff in LightGBM's docs for the Dask estimators, and right now all of the eval_* things are hidden in those docs. I described this in #3950 (review). If you see that and are brave enough, I'd appreciate it if you could update the docs in this PR as well. But if you look and it seems too complicated for the amount of time you have available to contribute, no problem at all and I can do it in a later PR.
Just want to remind you that you will not be able to see the final result at RTD because this PR is from a fork repo.
I understand. The instructions at https://github.com/microsoft/LightGBM/blob/master/docs/README.rst#build have worked well for me in testing the docs locally.
Suggestions applied as commits (Co-authored-by: James Lamb <[email protected]>):
- _train_part model.fit args to lines
- _train_part model.fit args to lines, pt2
- _train_part model.fit args to lines, pt3
- dask_model.fit args to lines
- use is instead of id()
Hey yeah, I think I can handle opening up the docs for the eval* params. I mean, the current kind of surgery on the docstrings just eliminates a whole section of parameters, so I need to open them back up, but probably still keep...
Hey, sorry for the delay in communication. The code and tests are basically how I want them, but I've been testing this out over the last week via pytest and in a JupyterLab notebook. When I …

(Note: the screenshot of the worker status was taken at a different point in time than the task call stack; the worker addresses are the same as those in the task reports.) After googling "Dask task hanging," I saw the note about oversubscribed threads, so I attempted a fix where I set … If it's any use, both the … Totally open to any/all advice or insights here!
Man sorry it took SO LONG to get back to you. I've finished the known issues I wanted to get done for the next release (#3872 ), so now I can fully focus on this.
I cloned your branch and ran some tests locally, and I was able to exactly reproduce the behavior you thoroughly documented in #3952 (comment). Thanks for describing that so well!
I just focused on the regression tests with a single eval set, since that was enough to reproduce the issue.
Testing code:

```python
import lightgbm as lgb
import dask.array as da
import numpy as np
from dask.array.utils import assert_eq
from dask.distributed import Client, LocalCluster
from sklearn.datasets import make_regression


def _create_data(objective, n_samples=100, n_features=None, centers=2, output='array', chunk_size=50, random_weights=True):
    n_features = n_features if n_features else 100
    X, y = make_regression(n_samples=n_samples, n_features=n_features, random_state=42)
    rnd = np.random.RandomState(42)
    weights = rnd.random(X.shape[0]) * 0.01
    if not random_weights:
        weights = np.ones([X.shape[0]])
    dX = da.from_array(X, (chunk_size, X.shape[1]))
    dy = da.from_array(y, chunk_size)
    dw = da.from_array(weights, chunk_size)
    return X, y, weights, dX, dy, dw


def _r2_score(dy_true, dy_pred):
    numerator = ((dy_true - dy_pred) ** 2).sum(axis=0, dtype=np.float64)
    denominator = ((dy_true - dy_pred.mean(axis=0)) ** 2).sum(axis=0, dtype=np.float64)
    return (1 - numerator / denominator).compute()


def _accuracy_score(dy_true, dy_pred):
    return da.average(dy_true == dy_pred).compute()


cluster = LocalCluster(n_workers=2)
client = Client(cluster)

task = 'regression'
eval_sizes = [0.9]
eval_names_prefix = 'specified'

client.restart()

# use larger number of samples to prevent faux early stopping whereby
# boosting stops on accident because each worker has few data points and achieves 0 loss.
n_samples = 1000
n_features = 10
n_eval_sets = len(eval_sizes)
early_stopping_rounds = 1
eval_set = []
eval_sample_weight = []
eval_names = [eval_names_prefix + f'_{i}' for i in range(len(eval_sizes))]

X, y, w, dX, dy, dw = _create_data(
    n_samples=n_samples,
    n_features=n_features,
    objective=task,
    output='array',
    chunk_size=10,
    random_weights=False
)
dg = None
eval_at = None
eval_group = None
model_factory = lgb.DaskLGBMRegressor
eval_metrics = ['rmse']

for eval_size in eval_sizes:
    _, _, _, dX_e, dy_e, dw_e = _create_data(
        n_samples=max(10, int(n_samples * eval_size)),
        n_features=n_features,
        objective=task,
        output='array',
        chunk_size=10,
        random_weights=False
    )
    eval_set.append((dX_e, dy_e))
    eval_sample_weight.append(dw_e)

full_trees = 100
params = {
    "random_state": 42,
    "n_estimators": full_trees,
    "num_leaves": 31,
    "first_metric_only": True
}
dask_model = model_factory(
    client=client,
    **params
)
dask_model = dask_model.fit(
    dX,
    dy,
    group=dg,
    eval_set=eval_set,
    eval_names=eval_names,
    eval_sample_weight=eval_sample_weight,
    eval_group=eval_group,
    eval_metric=eval_metrics,
    early_stopping_rounds=early_stopping_rounds,
    eval_at=eval_at,
    verbose=True
)
fitted_trees = dask_model.booster_.num_trees()
assert fitted_trees < full_trees
assert dask_model.best_iteration_ < full_trees

# be sure that model still produces decent output.
p1 = dask_model.predict(dX)
p1_r2 = _r2_score(dy, p1)
msg = f'r2 score of predictions with actuals was <= 0.8 ({p1_r2})'
assert p1_r2 > 0.8, msg

# check that evals_result contains expected eval_set names when provided.
n_rounds_tried = dask_model.best_iteration_ + early_stopping_rounds
evals_result = dask_model.evals_result_
assert len(evals_result) == n_eval_sets
evals_result_names = list(evals_result.keys())
if eval_names:
    assert all(x in eval_names for x in evals_result_names)

# check that evals_result names default to "training" or "valid_xx" without eval_names.
for evals_result_name in evals_result_names:
    if not eval_names:
        assert evals_result_name.startswith('training') or evals_result_name.startswith('valid')

    # check that eval_metric(s) are contained in evals_result dicts.
    for i, metric in enumerate(eval_metrics):
        assert metric in evals_result[evals_result_name]

        # len of each eval_metric should be number of fitted trees + early_stopping_rounds.
        assert len(evals_result[evals_result_name][metric]) == n_rounds_tried

        # stopping decision should have been made based on the best score of the first of eval_metrics.
        if i == 0:
            best_score = dask_model.best_score_[evals_result_name][metric]
            best_iter_zero_indexed = dask_model.best_iteration_ - 1
            assert_eq(best_score, min(evals_result[evals_result_name][metric]), atol=0.03)
            assert abs(best_iter_zero_indexed - np.argmin(evals_result[evals_result_name][metric])) <= early_stopping_rounds
```
A Theory
I have a theory about what's going on. I think it might be the case that once early stopping has been detected, workers are shutting down without carefully coordinating the shutdown. That could create a race condition where the order in which they shut down matters. I suspect this because you were getting socket recv error: 104.
That comes from LightGBM/src/network/socket_wrapper.hpp, line 286 at e5c3f7e:

```cpp
Log::Fatal("Socket recv error, code: %d", GetLastError());
```

and the error code itself comes from line 168 in the same file:

```cpp
return errno;
```
UPDATE (3/13): found a better list of all error codes at https://www-numi.fnal.gov/offline_software/srt_public_context/WebDocs/Errors/unix_system_errors.html
This small program confirms that error code 104 corresponds to ECONNRESET:

```cpp
#include <iostream>
#include <cerrno>
#include <cstring>
#include <cstdio>

int main()
{
    std::cout << "message: " << std::strerror(ECONNRESET) << '\n';
    printf("error code: %d\n", ECONNRESET);
}
```

Output:

```
message: Connection reset by peer
error code: 104
```
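For intuition on how an uncoordinated shutdown could surface this way, here is a small self-contained Python sketch (completely separate from LightGBM's networking code, purely illustrative): one end of a TCP connection closes abruptly, the way a worker process dying mid-training might, and the peer's recv() fails with "Connection reset by peer" (errno 104 on Linux).

```python
import errno
import socket
import struct
import threading


def abrupt_server(ready_event, port_box):
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind(("127.0.0.1", 0))
    srv.listen(1)
    port_box.append(srv.getsockname()[1])
    ready_event.set()
    conn, _ = srv.accept()
    # SO_LINGER with timeout 0 makes close() send an RST instead of a graceful FIN,
    # simulating a worker that shuts down without coordinating with its peers.
    conn.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0))
    conn.close()
    srv.close()


ready, port_box = threading.Event(), []
t = threading.Thread(target=abrupt_server, args=(ready, port_box))
t.start()
ready.wait()

peer = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
peer.connect(("127.0.0.1", port_box[0]))
try:
    peer.recv(1024)  # the RST from the abrupt close lands here
except ConnectionResetError as err:
    # on Linux: True [Errno 104] Connection reset by peer
    print(err.errno == errno.ECONNRESET, err)
finally:
    peer.close()
    t.join()
```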
I think this type of problem could explain all three behaviors you're seeing:
- sometimes it works
- sometimes training fails with this "connection reset by peer"
- sometimes it hangs forever
We could try to figure this out by looking at how mmlspark handles early stopping, and by @-ing in other maintainers here. We could also try to see if the problem shows up in non-Dask distributed training, like with the LightGBM CLI, by picking up #3841.
HOWEVER, before we do that, I have a request.
A Request
Can you open a separate PR with only the non-early-stopping eval changes? Being able to use evaluation sets is valuable even without early stopping, because it allows you to get a deeper understanding of how the training process is progressing. The test_eval_set_without_early_stopping tests passed every time I ran them, so I'm fairly confident that they could be separated out.

```shell
pytest test_dask.py::test_eval_set_without_early_stopping
```
Doing the evaluation sets separately would also make this PR smaller and make it a bit easier for us to focus in on the source of instability uncovered by your tests.
Thank you SO SO SO much for all of your hard work on this. Sorry the PR has dragged on for a while. Now that the next release is almost ready, I can be a lot more responsive and help move this forward.
Referring to this part of the diff:

```python
else:
    # when a worker receives no eval_set while other workers have eval data, causes LightGBMExceptions.
    if evals_provided:
        local_worker_address = get_worker().address
```
Looking at this again, I have a proposal that I think might simplify this part. Instead of having each worker process check "hey, did I get any eval data?", could we move this check up into _train()?

Since _train() already computes the result of client.who_has(parts), I think it should be possible to look at those results and check right away whether any worker is going to end up not getting eval data.

I think that's preferable because it means you get the error on the client before a bunch of training tasks start on the different workers. That avoids lightgbm needing to know how to manage the failure of one worker process. I think you've seen so far that a single worker erroring out is not handled very gracefully right now, and can result in those "socket recv error: 104" types of errors.
Ha yeah, I've had the same thought - "why am I waiting to do this here?" I was just too lazy to pry into the worker-parts map and go check before launching the client.submit calls. This is a good call and should be done prior to worker dispatch.
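The client-side check being discussed above could look roughly like the following. This is a hypothetical, cluster-free sketch: worker_map, the part dicts, and the 'eval_set' key are illustrative stand-ins built from who_has-style output, not lightgbm.dask's actual internals.

```python
from typing import Dict, List


def check_eval_parts_coverage(worker_map: Dict[str, List[dict]]) -> None:
    """Fail fast if eval data was provided but some worker would get none of it."""
    workers_with_eval = {
        addr for addr, parts in worker_map.items()
        if any('eval_set' in part for part in parts)
    }
    if workers_with_eval and workers_with_eval != set(worker_map):
        missing = sorted(set(worker_map) - workers_with_eval)
        raise RuntimeError(
            f"eval_set was provided, but these workers received no eval parts: {missing}. "
            "Repartition the eval data so that every worker gets at least one chunk."
        )


# toy example: two workers, only the first one was assigned eval chunks
worker_map = {
    'tcp://127.0.0.1:40001': [{'data': 'X0', 'label': 'y0', 'eval_set': 'e0'}],
    'tcp://127.0.0.1:40002': [{'data': 'X1', 'label': 'y1'}],
}
try:
    check_eval_parts_coverage(worker_map)
except RuntimeError as err:
    print(err)  # raised on the client, before any training tasks are dispatched
```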
Hey @jameslamb no problem, seriously this was worth the wait. Thanks for all of your effort in actually going and reproducing the issue! The worst part about bugs is I never really know if it's just my bad code, if there really is a bug, or if I'm just going nuts. Yeah, I totally agree with you on both fronts - this is probably a deeper issue than it makes sense to tackle in this PR, or even before the lightgbm.dask release, and what you're proposing about just tackling eval_sets makes sense. Totally happy to make a simpler PR to get eval_sets out the door but not early stopping. Questions about that PR: …
Yeah, I would accept a PR that adds …

I have a preference for not leaving TODO comments in the code, since they can become out of date as things get refactored and since they're not as visible in the backlog as issues. Whether you choose to update this branch or close this PR and create a new one is totally up to you! If it was me, I'd do this: …

I would do that so you could always come back to this PR and easily look at the diff to see where the early stopping changes fit in. But up to you!
Awesome, thanks for these tips, on it
Hey, I'm currently fixing merge conflicts and noticed #3950, so I'm going to address eval_init_score in the upcoming PR for eval_sets as well.
Sounds good!
Hey, I just opened up #4101, so I'm going to close this out! Thanks for all the patience, and sorry eval_set/early stopping couldn't make it into the first lightgbm.dask release!
Attempt to address #3712 - lightgbm.dask support for early stopping. Implemented this to work with multiple eval sets (i.e. multiple (X, y) pairs), sample weights, and group lengths, and implemented it so that when an individual eval_set, eval_sample_weights, or eval_group is the same as (data, label), sample_weights, or group, the latter is reused instead of having to recompute the training set/weights/group lengths. This is all that's going on, making little mini eval sets out of delayed parts in a consistent manner: …

Note that our test cases actually uncovered a small data issue - when one worker isn't distributed any eval_set data (e.g. because the validation set has fewer parts than workers, or because the data distribution is very unbalanced), LightGBM throws an exception (because other workers do have eval_set data). This is why I added the RuntimeError check for whether a worker has not received any eval_set parts when _train has been provided an eval_set.
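To picture what "making little mini eval sets out of delayed parts" means, here is a loose, cluster-free sketch of the idea. The function name, arguments, and part layout are hypothetical illustrations, not the PR's actual code: each worker stitches whatever eval chunks landed on it into one local eval set, and reuses its local training data when the eval set is literally the training data.

```python
import numpy as np


def build_local_eval_set(local_eval_parts, local_X, local_y, eval_is_training_data):
    """local_eval_parts: the (X_chunk, y_chunk) pairs that landed on this worker."""
    if eval_is_training_data:
        # eval_set is literally (data, label): reuse the local training pieces
        # rather than recomputing and concatenating them a second time.
        return [(local_X, local_y)]
    if not local_eval_parts:
        # no eval chunks on this worker while other workers have some -> LightGBM
        # would fail during distributed training, hence the RuntimeError described above.
        raise RuntimeError("worker received no eval_set parts although eval_set was provided")
    X_e = np.concatenate([x for x, _ in local_eval_parts], axis=0)
    y_e = np.concatenate([y for _, y in local_eval_parts], axis=0)
    return [(X_e, y_e)]


# toy usage: two eval chunks assigned to this worker get stitched into one mini eval set
rng = np.random.RandomState(42)
local_X, local_y = rng.rand(20, 5), rng.rand(20)
chunks = [(rng.rand(10, 5), rng.rand(10)), (rng.rand(10, 5), rng.rand(10))]
mini_eval_set = build_local_eval_set(chunks, local_X, local_y, eval_is_training_data=False)
print(mini_eval_set[0][0].shape)  # (20, 5)
```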