
Training more than one epoch #914

Open
peregilk opened this issue Sep 24, 2024 · 4 comments
@peregilk

@aireenmei Referring you here because I think this issue touches on #571, where you write:

I did not implement the auto restart because some users may not want their model to see repetitive data. I can add the multi-epoch support to our backlog. Meanwhile it should be straightforward to change the shard update logic here: https://github.com/google/maxtext/blob/main/MaxText/input_pipeline/_input_pipeline_utils.py#L105
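
As I understand it, the suggested change amounts to letting each host's shard index wrap around once the host has consumed its last assigned shard, so it starts another epoch instead of running out of data. A minimal sketch of that idea (not the actual MaxText code; the function and variable names are illustrative):

```
# Hypothetical sketch of wrap-around shard selection for multi-epoch support.
# This is NOT the actual MaxText implementation; names are illustrative.

def next_shard(current_shard: int, num_hosts: int, num_shards: int) -> int:
    """Advance a host to its next shard, wrapping around for another epoch.

    Each host steps through shards with a stride of num_hosts. Taking the
    result modulo num_shards lets the host loop back to its first shard
    instead of running out of data (at the cost of seeing repeated examples).
    """
    return (current_shard + num_hosts) % num_shards


# With 256 shards and 64 hosts, host 3 visits shards 3 -> 67 -> 131 -> 195,
# then wraps back to 3 instead of running out (matching the log below).
shard = 3
for _ in range(5):
    print(shard)
    shard = next_shard(shard, num_hosts=64, num_shards=256)
```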

The behaviour now seems to have changed a bit, and it might be even more confusing. I am not sure exactly what has changed in the code here.

What I am trying to do is switch datasets during training, here from step 160k. This is a fairly small special-task dataset, and I am studying its effect. The dataset has 256 shards, and one epoch is roughly 350 steps.

Here is what is happening, with comments:

```
# Perfectly normal. Switching to next shard. Weights and loss are fine
Updating host 3 dataset 0, was on shard 3
New shard is 67
completed step: 160086, seconds: 4.090, TFLOP/s/device: 116.004, Tokens/s/device: 2003.075, total_weights: 2080798, loss: 1.113

# Still normal
Updating host 3 dataset 0, was on shard 67
New shard is 131
completed step: 160177, seconds: 4.090, TFLOP/s/device: 115.995, Tokens/s/device: 2002.925, total_weights: 2078579, loss: 1.072

# Still normal
Updating host 3 dataset 0, was on shard 131
New shard is 195
completed step: 160268, seconds: 4.090, TFLOP/s/device: 115.989, Tokens/s/device: 2002.811, total_weights: 2079952, loss: 1.049

# Here things are starting to go south. The host starts generating all-0 paddings
completed step: 160359, seconds: 4.090, TFLOP/s/device: 116.001, Tokens/s/device: 2003.031, total_weights: 2077782, loss: 1.036

# Runs for a while, but then the total_weights start dropping, and the loss starts to drop
completed step: 160367, seconds: 4.091, TFLOP/s/device: 115.971, Tokens/s/device: 2002.507, total_weights: 2034296, loss: 1.030
completed step: 160368, seconds: 4.090, TFLOP/s/device: 116.002, Tokens/s/device: 2003.040, total_weights: 1860858, loss: 1.028
completed step: 160369, seconds: 4.090, TFLOP/s/device: 115.995, Tokens/s/device: 2002.928, total_weights: 1207504, loss: 1.038
completed step: 160370, seconds: 4.090, TFLOP/s/device: 115.991, Tokens/s/device: 2002.854, total_weights: 616193, loss: 1.038
completed step: 160371, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.734, total_weights: 184994, loss: 1.037
completed step: 160372, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.739, total_weights: 46490, loss: 1.058
completed step: 160373, seconds: 4.091, TFLOP/s/device: 115.976, Tokens/s/device: 2002.600, total_weights: 32596, loss: 0.989
completed step: 160374, seconds: 4.091, TFLOP/s/device: 115.978, Tokens/s/device: 2002.634, total_weights: 32491, loss: 1.041

# A bit later
completed step: 160460, seconds: 4.090, TFLOP/s/device: 115.987, Tokens/s/device: 2002.787, total_weights: 32673, loss: 0.980
completed step: 160461, seconds: 4.091, TFLOP/s/device: 115.970, Tokens/s/device: 2002.484, total_weights: 32503, loss: 1.043
completed step: 160462, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.736, total_weights: 1904, loss: 1.068
completed step: 160463, seconds: 4.091, TFLOP/s/device: 115.966, Tokens/s/device: 2002.420, total_weights: 0, loss: 0.000
completed step: 160464, seconds: 4.090, TFLOP/s/device: 115.990, Tokens/s/device: 2002.845, total_weights: 0, loss: 0.000

```

This behaviour is a bit unpredictable, especially since some shards can be smaller, and it is hard to know when the first host will run out of shards. Running out of shards seems to hurt the model.

What is your advice here?
@aireenmei
Collaborator

Hi @peregilk, the new behavior is documented here: https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#huggingface-pipeline-in-multihost.

@peregilk
Author

@aireenmei Thanks a lot for the explanation. I thought the drop in weights and loss was hurting the model and was wondering why it did not show up in my evaluations. Now it makes total sense. Thanks.

@peregilk
Author

peregilk commented Oct 3, 2024

@aireenmei Just a couple of follow-ups, attached to this thread since they are related. I followed the instructions on the page above and discovered two minor issues:

  • eval_interval needs to be set in the config file; it is not accepted from the command line.
  • I also need to set eval_steps > 0, even though the config comment says it is for debugging only. Evaluation crashes if eval_steps is not set. (See the example below.)
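
For anyone running into the same thing, the workaround is to put both values in the YAML config passed to train.py rather than on the command line. A minimal sketch (the numbers are only examples; the exact config file depends on your setup):

```
# In the MaxText YAML config used for the run (values are illustrative).
eval_interval: 1000   # run evaluation every 1000 training steps
eval_steps: 50        # must be > 0, otherwise evaluation crashes
```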

@aireenmei
Collaborator

Thanks for reporting. Yes, setting eval_steps is recommended; it is no longer for debugging only. I'll update that.
