Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm][Bugfix] Fixed several bugs related to rccl path and attention selector logic #3699

Merged
merged 7 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,6 @@ RUN cd /app \
&& cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

CMD ["/bin/bash"]
2 changes: 1 addition & 1 deletion requirements-rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ninja # For faster builds.
typing-extensions>=4.8.0
starlette
psutil
ray >= 2.9
ray == 2.9.3
sentencepiece # Required for LLaMA tokenizer.
numpy
tokenizers>=0.15.0
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ def _check_use_naive_attention() -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
has_flash_attn = importlib.util.find_spec("flash_attn") is None
if not has_flash_attn:
use_naive_attention = importlib.util.find_spec("flash_attn") is None
if use_naive_attention:
logger.warning("flash_attn is not installed. Using naive attention. "
"This will take significantly more GPU memory.")
return True
Expand Down
40 changes: 40 additions & 0 deletions vllm/model_executor/parallel_utils/find_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
import subprocess

import logging
import os

logger = logging.getLogger(__name__)
import re

def get_library_path(library_name):
# Robust way to find the library path from torch installation
# Hard coding a library parth is error prone
try:
torch_dir = os.path.dirname(torch.__file__)
torch_path = os.path.join(torch_dir, "lib", "libtorch.so")

result = subprocess.run(['ldd', '-v', '-r', '-d', torch_path], capture_output=True, text=True)
if result.returncode == 0:
output_lines = result.stdout.split("\n")
for line in output_lines:
if library_name in line:
match = re.search(r'=>\s*(\S+)', line)
if match:
library_path = match.group(1)
return library_path
else:
logger.error(f"PyTorch is not installed properly. {result.stderr}")
except Exception as e:
logger.error(f"Error finding library path: {e}")
return None

# you can test this
if __name__ == "__main__":

# this works for librccl.so, librccl.so.1, etc
rccl_path = get_library_path("librccl.so")
if rccl_path:
print(f"location is {rccl_path}")
else:
print("librccl.so not found")
11 changes: 9 additions & 2 deletions vllm/model_executor/parallel_utils/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from vllm.model_executor.parallel_utils.find_lib import get_library_path
from vllm.utils import is_hip

logger = logging.getLogger(__name__)

so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

if is_hip():
# a robust way to get the path of librccl, no matter it is librccl.so, or librccl.so.1
so_file = get_library_path("librccl.so")
else:
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")

# manually load the nccl library
if so_file:
Expand All @@ -41,7 +48,7 @@
if torch.version.cuda is not None:
so_file = "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.2"
so_file = "librccl.so.1"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at https://rocm.docs.amd.com/projects/rccl/en/latest/api.html , and it says the current version is 2.18.3 . Quite strange that the library name is librccl.so.1 .

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is why I am not assuming what the suffix is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you talk to rccl team why this is the case? If they keep librccl.so.1 that would also be fine, but just please don't be too random.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial test with the current head is that, it does not work for ROCm. There are a bunch of other issues in addition to the ones described in this pull request.

We have tested using cupy and verified that it worked for the hipgraph path with our in-development newer ROCm.

However, this does not work for us.

Another thing, is that, will it be possible we can still opt in using cupy for all-reduce? Can it be abstracted so that people can choose use cupy, nccl, or, whatever?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as how rccl so file name and its version definition: I found information ROCm/rccl repo. Links below:

https://github.com/ROCm/rccl/blob/2f6d59e2e651914d9d6e51b2b702b9a9ac0ea99d/makefiles/version.mk#L2
and
https://github.com/ROCm/rccl/blob/2f6d59e2e651914d9d6e51b2b702b9a9ac0ea99d/CMakeLists.txt#L669C1-L669C19

Hope this answers your question. Let's take a step back, we want to solve the problem of cudagraph mode.
My understanding is that below are possible ways :

  • cupy
  • user-defined nccl/rccl
  • custom all reduce
  • pytorch native all-reduce

How we can easily choose one over the other and what is our long-term plan?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cupy is deprecated and removed now, because we got many bug report with regard to cupy .

pytorch native all-reduce is not available in cudagraph mode, because it usually contains some additional check that will fail graph capture.

Going forward, we will focus on the pynccl wrapper as the first choice, and custom all reduce as a backup plan (it is disabled by default because of instability).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@youkaichao Our users need the fixes for the other part like the one related to naive attention, since now it becomes the default for those users and it was quite slow.
I need to simplify this PR so that it will be merged quickly

else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.debug(f"Loading nccl from library {so_file}")
Expand Down
Loading