Some fixes for custom allreduce kernels #2760

Merged: 25 commits into vllm-project:main on Mar 22, 2024

Conversation

hanzhi713
Contributor

@hanzhi713 hanzhi713 commented Feb 5, 2024

Recently, there have been some reports of stuck generation or garbage text when custom allreduce is enabled. While I didn't manage to reproduce any issues on A30 and A100, I did find some potentially unsafe synchronization, which I attempt to fix here.

  1. When using the signal flag, GPUs write to different bytes of the same 8-byte signal. Although the writes are strong writes, they are not considered morally strong under the CUDA memory model because they do not overlap completely. Hence, they still constitute data races.
  2. When using 2-stage allreduce or half-butterfly allreduce, a __threadfence_system or a release-acquire pattern is needed to guarantee the visibility of other devices' writes on the current device; this is missing from the current implementation. (A hedged sketch of such a signaling pattern is shown right after this list.)
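
For illustration, here is a minimal sketch of such a signaling pattern, assuming hypothetical names (Signal, cross_device_barrier). It is not the kernel in csrc/custom_all_reduce.cuh, only the shape of the fix: each rank writes its own full, aligned uint32, and system-scope fences provide the release/acquire ordering.

#include <cstdint>

// One full, naturally aligned 32-bit word per writing rank, so writes from
// different GPUs never partially overlap (i.e., they remain morally strong).
struct Signal {
  volatile uint32_t flags[8];
};

__device__ void cross_device_barrier(Signal* self, Signal* peers[],
                                     int rank, int world_size, uint32_t flag) {
  __syncthreads();           // all threads of this block finished their data writes
  __threadfence_system();    // release: order those writes before the flag write,
                             // at system scope (visible to other devices)
  if (threadIdx.x < world_size) {
    peers[threadIdx.x]->flags[rank] = flag;        // raise this rank's flag on peer threadIdx.x
    while (self->flags[threadIdx.x] != flag) { }   // spin until peer threadIdx.x raised its flag here
  }
  __threadfence_system();    // acquire: data the peers wrote before their flags is now visible
  __syncthreads();           // no thread reads peer buffers until every spin has finished
}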

Related issues:
#2788 (garbage output when upgrading vllm from 0.2.7 -> 0.3.0)
#2742 (garbage output. Solved when disable_custom_all_reduce=True)
#2731 (one person reported that disable_custom_all_reduce=True solves generation hanging)

In this PR, I made the following changes:

  1. Use 8 uint32s instead of 8 bytes per signal per device, so that each device writes a full, naturally aligned word.
  2. Simplify synchronization by only syncing blocks of the same index across GPUs. Some changes ensure that each thread's reads only depend on writes from the thread with the same id on other devices.
  3. Add a __threadfence_system to guarantee visibility of other devices' writes when using 2-stage or half-butterfly allreduce. Note that this adds a few microseconds of overhead.
  4. Remove support for more than two PCIe-only GPUs because the performance improvement is small.
  5. Add additional P2P checks to avoid buggy driver/hardware P2P support (a hedged sketch of such a check follows this list). Might fix "Mixtral GPTQ with TP=2 not generating output" (#2728).
  6. Add a check for the device count when running the P2P test. Should fix "Distributed inference on multi machine (error Invalid peer device id)" (#2795).
  7. Disable custom allreduce by default by setting disable_custom_all_reduce to True. Users can explicitly opt in by setting it to False.
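
As an illustration of the checks in items 5 and 6, here is a hedged host-side sketch using only standard CUDA runtime calls; the function name is hypothetical and this is not vLLM's actual helper.

#include <cuda_runtime.h>

// Hypothetical helper: return true only if P2P access is reported between
// every pair of ranks actually in use. A sketch, not vLLM's code.
bool p2p_usable(int world_size) {
  int device_count = 0;
  if (cudaGetDeviceCount(&device_count) != cudaSuccess) return false;
  // The visible device count can exceed world_size when only the first few
  // GPUs are used; never probe a peer id beyond what CUDA reports, which
  // would fail with an invalid-peer-device error.
  if (world_size > device_count) return false;
  for (int i = 0; i < world_size; ++i) {
    for (int j = 0; j < world_size; ++j) {
      if (i == j) continue;
      int can_access = 0;
      if (cudaDeviceCanAccessPeer(&can_access, i, j) != cudaSuccess || !can_access) {
        return false;  // fall back to NCCL instead of the custom kernel
      }
    }
  }
  return true;
}

A capability report alone does not catch every buggy driver/hardware combination, which is why item 5 speaks of additional checks; the sketch only shows the basic gate.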

@hanzhi713 hanzhi713 changed the title Safer sync Fix unsafe synchronization for custom allreduce kernels Feb 7, 2024
@hanzhi713 hanzhi713 changed the title Fix unsafe synchronization for custom allreduce kernels [WIP] Fix unsafe synchronization for custom allreduce kernels Feb 7, 2024
@hanzhi713 hanzhi713 changed the title [WIP] Fix unsafe synchronization for custom allreduce kernels Fix unsafe synchronization for custom allreduce kernels Feb 8, 2024
@hanzhi713
Contributor Author

hanzhi713 commented Feb 8, 2024

@WoosukKwon Did any of you manage to reproduce custom allreduce getting stuck or generating garbage output?

@hanzhi713 hanzhi713 changed the title Fix unsafe synchronization for custom allreduce kernels [WIP] Fix unsafe synchronization for custom allreduce kernels Feb 9, 2024
@NikolaBorisov
Contributor

To reproduce, I run vLLM on 4x A100 80G SXM with CodeLlama 70B. I send some requests like this:

for i in {0..100}; do curl "http://localhost:8000/v1/chat/completions" -H "Content-Type: application/json" -d '{
  "model": "codellama/CodeLlama-70b-Instruct-hf",
  "messages": [
    {"role": "user", "content": "Hello!"}
  ],
  "max_tokens":100
  }' & sleep 2; done

It usually gets stuck before request 50 for me.

@hanzhi713
Contributor Author


I did not manage to reproduce this when enforce_eager=True. Custom allreduce is running fine, I think.

@hanzhi713
Contributor Author

@WoosukKwon This PR is ready for review.

@hanzhi713 hanzhi713 changed the title [WIP] Fix unsafe synchronization for custom allreduce kernels Some fixes for custom allreduce kernels Feb 15, 2024
@WoosukKwon
Collaborator

@hanzhi713 Sorry for the delays in the review. I will review the PR this weekend and make sure this is included in v0.3.4.

@tdene

tdene commented Mar 1, 2024

Actually.

@hanzhi713 have you tested this PR with a MoE model like Mixtral?
When using this PR merged on top of 5255d99, I'm seeing:

custom_all_reduce.py:239] Registering 2 cuda graph addresses
[...]
File "/workspace/vllm/vllm/worker/worker.py", line 160, in warm_up_model 
  self.model_runner.capture_model(self.gpu_cache)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
   return func(*args, **kwargs)
File "/workspace/vllm/vllm/worker/model_runner.py", line 725, in capture_model
  graph_runner.capture(
[...]
File "/workspace/vllm/vllm/model_executor/models/mixtral.py", line 130, in forward
  final_hidden_states = fused_moe(hidden_states,
File "/workspace/vllm/vllm/model_executor/layers/fused_moe.py", line 276, in fused_moe
   gating_output.float(),  # TODO(woosuk): Optimize this.
RuntimeError: CUDA error: invalid device function
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@hanzhi713
Contributor Author

hanzhi713 commented Mar 3, 2024

@tdene Hmm maybe the buffer isn't placed on the correct cuda device. Let me check tomorrow.

@hanzhi713
Contributor Author

@tdene I didn't manage to reproduce this error

@hanzhi713
Contributor Author

@tdene I merged this branch with the latest main. Can you rerun your test? Also, I find it strange to see Registering 2 cuda graph addresses because normally with cuda graph enabled, we should see at least a few hundred addresses, if not thousands.

@hanzhi713
Contributor Author

@zhuohan123 Looks like @WoosukKwon is too busy. Can you help me get a different reviewer for this PR?

Collaborator

@WoosukKwon WoosukKwon left a comment


@hanzhi713 Apologies for the very late review. I had no bandwidth recently.

The PR looks good to me, though I still don't understand the cause of the bug. I left some minor comments. Please take a look at them.

vllm/config.py (outdated review comments, resolved)
vllm/engine/arg_utils.py (outdated review comments, resolved)
Comment on lines +145 to +146
# note: num dev can be larger than world_size if we're only using
# first few GPUs
Collaborator


This is the case when more than one node (host) is used for TP, right?

Contributor Author


No. This check is for the case where users have N GPUs but are only using the first M GPUs, where M < N.
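
To make this concrete, a minimal host-side sketch with illustrative values only (not vLLM's code):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int num_dev = 0;
  cudaGetDeviceCount(&num_dev);  // e.g. 8 on an 8-GPU host
  int world_size = 2;            // e.g. TP over only the first 2 GPUs
  // num_dev > world_size is legal and common; the P2P probe must be limited
  // to ranks [0, world_size), not [0, num_dev).
  printf("visible devices: %d, world size: %d\n", num_dev, world_size);
  return 0;
}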

csrc/custom_all_reduce_test.cu (outdated review comments, resolved)
@WoosukKwon
Collaborator

WoosukKwon commented Mar 21, 2024

@hanzhi713 Thanks! Let me do extra tests on the PR and merge.

@WoosukKwon
Collaborator

@hanzhi713 I will merge this as I found this worked well on my 4 A100-80GB machine. Thanks for the fix!

@WoosukKwon WoosukKwon merged commit f721096 into vllm-project:main Mar 22, 2024
31 checks passed
@WoosukKwon
Collaborator

@hanzhi713 Actually, we recently found that TRT-LLM's custom all reduce kernel is extremely simple. Do you have an idea why it can be much simpler than this implementation? What do you think about using TRT-LLM's kernel?

@hanzhi713
Contributor Author


It looks to me like the implementation has about the same complexity as mine. What makes you think it looks simpler?

@garycaokai

In my test with 4 non-NVLink-capable GPUs, custom allreduce has a 20% performance improvement at batch size 1, despite this comment in the code:
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.

@garycaokai

@hanzhi713 can you make the 4 non-NVLink-capable GPU case available as an option?

@hanzhi713
Contributor Author

@garycaokai Curious about your setup. What GPUs are you using? Are they all connected to a PCIe switch or are they connected to CPU directly?

Yes I can provide that as an option if I find some time to work on this.

@hanzhi713
Contributor Author

@garycaokai Also, how did you measure the performance improvement? Is it a latency benchmark? And what is the configuration for the benchmark?

@garycaokai

@garycaokai Curious about your setup. What GPUs are you using? Are they all connected to a PCIe switch or are they connected to CPU directly?

Yes I can provide that as an option if I find some time to work on this.
72B int4 model, tp=4, batch size 1: decode speed goes from 20 token/s to 26 token/s.
8x A30 GPUs.
#nvidia-smi topo -m
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X PIX PIX PIX SYS SYS SYS SYS NODE NODE 0-11,48-59 0 N/A
GPU1 PIX X PIX PIX SYS SYS SYS SYS NODE NODE 0-11,48-59 0 N/A
GPU2 PIX PIX X PIX SYS SYS SYS SYS NODE NODE 0-11,48-59 0 N/A
GPU3 PIX PIX PIX X SYS SYS SYS SYS NODE NODE 0-11,48-59 0 N/A
GPU4 SYS SYS SYS SYS X PIX PIX PIX SYS SYS 24-35,72-83 2 N/A
GPU5 SYS SYS SYS SYS PIX X PIX PIX SYS SYS 24-35,72-83 2 N/A
GPU6 SYS SYS SYS SYS PIX PIX X PIX SYS SYS 24-35,72-83 2 N/A
GPU7 SYS SYS SYS SYS PIX PIX PIX X SYS SYS 24-35,72-83 2 N/A
NIC0 NODE NODE NODE NODE SYS SYS SYS SYS X PIX
NIC1 NODE NODE NODE NODE SYS SYS SYS SYS PIX X

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

NIC Legend:

NIC0: mlx5_0
NIC1: mlx5_1

@hanzhi713
Contributor Author

hanzhi713 commented Apr 7, 2024

I see, this makes a lot of sense. Usually A30 machines have GPUs connected to the CPU directly, and CPUs are often terrible PCIe switches. My implementation relies on PCIe P2P. However, in your case a PCIe switch connects each group of 4 GPUs. Given the much better switching performance, my implementation may work and provide performance improvements.

Let me see if I can find machines of this topology
