Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RLHF slowdown in attention multi steps extend_step. #849

Merged
merged 1 commit into from
Nov 19, 2024

Conversation

ds-hwang
Copy link
Contributor

Fix RLHF slowdown in attention multi steps extend_step.

jax.lax.dynamic_update_slice_in_dim is generally faster than advanced indexing,
but an unusual slowdown was observed, with RLHF sampling taking up to 3 hours
per run. TODO: Investigate and fix it.

For your information, in #831, I experimented
with both dynamic_update_slice and advanced indexing on TPUv4 and chose the
faster option. It's also known that dynamic_update_slice performs better when
copying contiguous memory. This is a very surprising case.

Advanced Indexing

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         7.16 ms        0.623 ms          492
QkvLinearExtendStepBenchmark/2048/16/4096/1         8.52 ms        0.624 ms          561
QkvLinearExtendStepBenchmark/2048/16/32768/1        34.6 ms         1.64 ms           78
QkvLinearExtendStepBenchmark/2048/16/4096/8         63.6 ms         1.74 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/64         276 ms         2.40 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/512       2541 ms         81.6 ms            1

dynamic_update_slice

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         1.70 ms        0.513 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/1         3.40 ms        0.519 ms         1174
QkvLinearExtendStepBenchmark/2048/16/32768/1        20.1 ms        0.930 ms          404
QkvLinearExtendStepBenchmark/2048/16/4096/8         3.68 ms        0.524 ms         1139
QkvLinearExtendStepBenchmark/2048/16/4096/64        3.74 ms        0.532 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/512       2530 ms         80.4 ms            1

jax.lax.dynamic_update_slice_in_dim is generally faster than advanced indexing,
but an unusual slowdown was observed, with RLHF sampling taking up to 3 hours
per run. Investigate and fix it.
https://a1350286.slack.com/archives/C03HJAYC7JA/p1731998432387409?thread_ts=1731968765.840839&cid=C03HJAYC7JA

For your information, in
https://github.pie.apple.com/foundation-models/axlearn/pull/894, I experimented
with both dynamic_update_slice and advanced indexing on TPUv4 and chose the
faster option. It's also known that dynamic_update_slice performs better when
copying contiguous memory. This is a very surprising case.

Advanced Indexing
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         7.16 ms        0.623 ms          492
QkvLinearExtendStepBenchmark/2048/16/4096/1         8.52 ms        0.624 ms          561
QkvLinearExtendStepBenchmark/2048/16/32768/1        34.6 ms         1.64 ms           78
QkvLinearExtendStepBenchmark/2048/16/4096/8         63.6 ms         1.74 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/64         276 ms         2.40 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/512       2541 ms         81.6 ms            1

dynamic_update_slice
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         1.70 ms        0.513 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/1         3.40 ms        0.519 ms         1174
QkvLinearExtendStepBenchmark/2048/16/32768/1        20.1 ms        0.930 ms          404
QkvLinearExtendStepBenchmark/2048/16/4096/8         3.68 ms        0.524 ms         1139
QkvLinearExtendStepBenchmark/2048/16/4096/64        3.74 ms        0.532 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/512       2530 ms         80.4 ms            1
Copy link
Contributor

@ruomingp ruomingp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this fix the slowdown?

@ds-hwang
Copy link
Contributor Author

ds-hwang commented Nov 19, 2024

Does this fix the slowdown?

Yes :) More details in 918
Thank you for review!

@ds-hwang ds-hwang added this pull request to the merge queue Nov 19, 2024
Merged via the queue into apple:main with commit 2803b36 Nov 19, 2024
10 checks passed
@ds-hwang ds-hwang deleted the mult_bug_fix branch November 19, 2024 17:58
ds-hwang added a commit to ds-hwang/axlearn that referenced this pull request Dec 2, 2024
`k_proj` is not properly set sharding hints, so QKVLinear.extend_step cannot
create next `cached_key` with proper hints.
This causes OOM for diffusion model, because the code cannot know the local
batch size.
     Shape: f32[1024,2048,8,128]{3,2,1,0:T(8,128)}
     Unpadded size: 8.00G

To fix it, copy `cached_key.sharding` to `k_proj.sharding`, as `cached_key`
sharding is properly set up.

In addition, this is the reason of RLHF slowdown, so revert the workaround change.
apple#849
qdavid1 pushed a commit to qdavid1/axlearn that referenced this pull request Dec 11, 2024
jax.lax.dynamic_update_slice_in_dim is generally faster than advanced indexing,
but an unusual slowdown was observed, with RLHF sampling taking up to 3 hours
per run. Investigate and fix it.
https://a1350286.slack.com/archives/C03HJAYC7JA/p1731998432387409?thread_ts=1731968765.840839&cid=C03HJAYC7JA

For your information, in
https://github.pie.apple.com/foundation-models/axlearn/pull/894, I experimented
with both dynamic_update_slice and advanced indexing on TPUv4 and chose the
faster option. It's also known that dynamic_update_slice performs better when
copying contiguous memory. This is a very surprising case.

Advanced Indexing
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         7.16 ms        0.623 ms          492
QkvLinearExtendStepBenchmark/2048/16/4096/1         8.52 ms        0.624 ms          561
QkvLinearExtendStepBenchmark/2048/16/32768/1        34.6 ms         1.64 ms           78
QkvLinearExtendStepBenchmark/2048/16/4096/8         63.6 ms         1.74 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/64         276 ms         2.40 ms           81
QkvLinearExtendStepBenchmark/2048/16/4096/512       2541 ms         81.6 ms            1

dynamic_update_slice
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         1.70 ms        0.513 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/1         3.40 ms        0.519 ms         1174
QkvLinearExtendStepBenchmark/2048/16/32768/1        20.1 ms        0.930 ms          404
QkvLinearExtendStepBenchmark/2048/16/4096/8         3.68 ms        0.524 ms         1139
QkvLinearExtendStepBenchmark/2048/16/4096/64        3.74 ms        0.532 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/512       2530 ms         80.4 ms            1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants