
Upgrade to torch==2.2.1 #2804

Closed
hmellor wants to merge 22 commits from the pytorch-2.2.0-upgrade branch

Conversation

@hmellor (Collaborator) commented Feb 7, 2024

Closes #2738
Closes #2722

- `torch==2.1.1` -> `torch==2.2.1`
- `xformers==0.0.23.post1` -> `xformers==0.0.25`
- ROCm not updated because no `torch==2.2` containers have been published yet

@hmellor (Collaborator, Author) commented Feb 7, 2024

The failing tests all fail with errors similar to:

__________________ ERROR collecting tests/kernels/test_moe.py __________________
ImportError while importing test module '/vllm-workspace/tests/kernels/test_moe.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/usr/lib/python3.8/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
kernels/test_moe.py:8: in <module>
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
/usr/local/lib/python3.8/dist-packages/transformers/models/mixtral/modeling_mixtral.py:58: in <module>
    from flash_attn import flash_attn_func, flash_attn_varlen_func
/usr/local/lib/python3.8/dist-packages/flash_attn/__init__.py:3: in <module>
    from flash_attn.flash_attn_interface import (
/usr/local/lib/python3.8/dist-packages/flash_attn/flash_attn_interface.py:10: in <module>
    import flash_attn_2_cuda as flash_attn_cuda
E   ImportError: /usr/local/lib/python3.8/dist-packages/flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE

We're already using the latest versions of transformers and flash_attn, so I'm not sure where to go from here.
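
For anyone hitting the same thing, a rough sketch of how one might confirm the ABI mismatch behind an undefined-symbol error like the one above (paths are illustrative and based on the CI container layout):

# Demangle the missing symbol from the traceback to see which ATen API the wheel expects
echo '_ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE' | c++filt

# Check whether the installed libtorch actually exports that symbol
TORCH_LIB=$(python3 -c "import torch, os; print(os.path.join(os.path.dirname(torch.__file__), 'lib'))")
nm -D "$TORCH_LIB/libtorch_cpu.so" | c++filt | grep 'at::_ops::zeros::call'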

@Qubitium (Contributor) commented Feb 8, 2024

We have seen similar errors and have come to the conclusion that any package containing C/C++ extensions can randomly break when a base package such as torch is updated via pip. Normally, when torch is updated to 2.2.0, the dependent packages such as xformers/flash-attn should be recompiled against it, but they are not.

Our usual course of fix is (a rough shell sketch follows the list):

  1. pip uninstall torch and re-install torch standalone, manually, avoiding requirements.txt
  2. pip install/upgrade flash-attn manually
  3. if step 2 still doesn't work, git clone flash-attn and run python setup.py install manually
  4. if steps 2 and 3 both fail, delete the C++ extension caches used by Python
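
Roughly, in shell form (package versions and paths here are illustrative, not prescriptive):

# 1. Remove torch and reinstall it manually, outside of requirements.txt
pip uninstall -y torch
pip install torch==2.2.1

# 2. Reinstall flash-attn so it resolves/builds against the new torch
pip install --upgrade --no-build-isolation flash-attn

# 3. If the wheel is still broken, build flash-attn from source
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention && python setup.py install

# 4. If that still fails, clear the cached builds and JIT-compiled extensions
pip cache purge
rm -rf ~/.cache/torch_extensions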

The pip packaging system is, in my view, a huge minefield when it comes to packages that contain JIT-compiled or precompiled C/C++ code.

@simon-mo mentioned this pull request Feb 14, 2024
@hmellor (Collaborator, Author) commented Feb 15, 2024

After merging the latest changes from master, the import errors are gone and we are now only seeing OOM errors in the models tests (https://buildkite.com/vllm/ci/builds/1309#018dac6d-9c04-48d6-af20-53c46ad2b31a):

models/test_mistral.py::test_models[128-bfloat16-mistralai/Mistral-7B-Instruct-v0.1] FAILED
models/test_models.py::test_models[128-float-facebook/opt-125m] PASSED
models/test_models.py::test_models[128-float-meta-llama/Llama-2-7b-hf] FAILED
models/test_models.py::test_models[128-float-mistralai/Mistral-7B-v0.1] FAILED
models/test_models.py::test_models[128-float-Deci/DeciLM-7b] FAILED
models/test_models.py::test_models[128-float-tiiuae/falcon-7b] FAILED
models/test_models.py::test_models[128-float-gpt2] PASSED
models/test_models.py::test_models[128-float-bigcode/tiny_starcoder_py] PASSED
models/test_models.py::test_models[128-float-EleutherAI/gpt-j-6b] FAILED
models/test_models.py::test_models[128-float-EleutherAI/pythia-70m] PASSED
models/test_models.py::test_models[128-float-bigscience/bloom-560m] PASSED
models/test_models.py::test_models[128-float-mosaicml/mpt-7b] FAILED
models/test_models.py::test_models[128-float-microsoft/phi-2] PASSED
models/test_models.py::test_models[128-float-stabilityai/stablelm-3b-4e1t] PASSED

However, these seem to be treated as soft fails by the CI.

cc @WoosukKwon @simon-mo

@stas00 (Contributor) commented Feb 21, 2024

What's blocking this PR from being merged?

GCP A3 instances don't work with torch<2.2, so we can't use vLLM.

Thank you!

@hmellor (Collaborator, Author) commented Feb 21, 2024

Hi @stas00, I've just been waiting for a review.

I've just merged main into this branch again because it went stale while waiting for review.

@WoosukKwon if you get a moment, could I get a review please? AMD still hasn't published any PyTorch 2.2.0 containers.

@stas00 (Contributor) commented Feb 21, 2024

Thank you for leading this effort, @hmellor

Why does it have to be == and not >=? If you use torch>=2.1.2 it kills two birds with one stone: those who want 2.1.2 (ROCm) and those who need 2.2 (GCP/A3) can both use it.

@hmellor (Collaborator, Author) commented Feb 21, 2024

I've not used ROCm myself, but I think it's used with the container specified in https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm, which comes with a specific version of PyTorch already installed; that's why requirements-rocm.txt doesn't include torch.

@stas00 (Contributor) commented Feb 21, 2024

Yes, but what I was trying to say is that by relaxing the restriction this project and its users will have a lot less maintenance burden, since torch==2.2.1 will be released soon, and so on. And == makes it very difficult to install hundreds of other packages which may not work with pinned environments.

@hmellor (Collaborator, Author) commented Feb 21, 2024

I see your point; I'll wait for @WoosukKwon to comment in case he objects.

@WoosukKwon self-requested a review February 21, 2024 23:47
Resolved review threads: requirements.txt (outdated), pyproject.toml (outdated)
@WoosukKwon (Collaborator) commented:

> Yes, but what I was trying to say is that by relaxing the restriction this project and its users will have a lot less maintenance burden, since torch==2.2.1 will be released soon, and so on. And == makes it very difficult to install hundreds of other packages which may not work with pinned environments.

@stas00 Good point. Originally, we used >=. However, we found that this broke vLLM whenever PyTorch was upgraded, because vLLM's pre-compiled binaries were often not compatible with the new version of PyTorch.

@stas00 (Contributor) commented Feb 21, 2024

Thank you for clarifying, @WoosukKwon. There are several projects in this boat, and they constantly have to be asked to build wheels for new versions of PyTorch because of that.

I think at the very least you shouldn't require the .patch versions to match, i.e. torch==2.1.* should suffice: patch releases (2.1.1, 2.1.2) are usually bug-fix-only releases and should be binary compatible with the X.Y.0 versions.

So how do we move forward then? Do you plan to make a torch==2.2.*-compatible binary wheel, and should we open a request? That would mean pip install on its own won't work any longer if a specific torch version is needed; you'd have to supply an --index-url argument with the specific build, e.g. the same as PyTorch does:

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

But of course, by default the latest build should work fine without --index-url.

@hijkzzz commented Feb 22, 2024

Can you change torch==2.1.1 -> torch>=2.2.0? Pinned versions are not friendly to other packages.

@WoosukKwon (Collaborator) commented:

@stas00 Thanks for the advice!

> I think at the very least you shouldn't require the .patch versions to match, i.e. torch==2.1.* should suffice: patch releases (2.1.1, 2.1.2) are usually bug-fix-only releases and should be binary compatible with the X.Y.0 versions.

Could you let us know where you found this information? If this is guaranteed, I think we can be a bit more flexible, e.g. torch>=2.2.0,<2.3.0.
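
Purely for illustration (this is not what the PR currently pins), that relaxed constraint would look like:

# Hypothetical relaxed pin: accept any torch 2.2.x patch release, but nothing from 2.3
pip install 'torch>=2.2.0,<2.3.0'
# the equivalent requirements.txt line would read: torch >= 2.2.0, < 2.3.0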

@stas00 (Contributor) commented Feb 22, 2024

I'm double-checking with the PyTorch developers here: https://pytorch.slack.com/archives/C3PDTEV8E/p1708624187379559 and will follow up once I get a yay or nay from there.

@WoosukKwon (Collaborator) commented:

@hmellor could you rebase the PR with the latest main branch (which includes #2982)? This will fix the CI failures.

@mgoin (Collaborator) commented Feb 22, 2024

Hi @stas00, I don't believe it is fair to assume that there is binary compatibility between PyTorch patch versions. I investigated this myself a few weeks ago and found it not to be true. In addition, there is this issue response from a PyTorch developer saying this is not a priority: pytorch/pytorch#88301 (comment)

> In general, there are no guarantee for binary compatibility between even the patch versions (as adding a new virtual function to base class will offset vptr of all inherited classes).

@stas00 (Contributor) commented Feb 22, 2024

Thank you for finding that reply, @mgoin. My question on the PyTorch Slack got the same answer: we don't know; it might be compatible or it might not be, and it's neither validated nor expected to be.

So I stand corrected: my suggestion that one could use torch==2.1.* is not a safe one, even though it might work for some patch releases. It would probably depend on which symbols vLLM relies on.
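
As an aside, one rough way to check which torch symbols a compiled extension actually imports (the extension path below is purely illustrative):

# List the undefined (imported) dynamic symbols of a compiled extension, demangle them,
# and keep only the ATen ones that must be provided by the installed libtorch
nm -D --undefined-only /path/to/site-packages/vllm/_C.cpython-38-x86_64-linux-gnu.so | c++filt | grep 'at::'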

@hmellor (Collaborator, Author) commented Mar 12, 2024

Sure, done

@stas00 (Contributor) commented Mar 12, 2024

Thanks a lot, Harry; this branch can be built from source again.

Resolved review thread: requirements.txt (outdated)
@youkaichao (Member) commented:

PyTorch 2.2.1 does not work either; I tested it today. (Please just upgrade xformers to 0.0.25.)

@hmellor closed this Mar 19, 2024
@hmellor deleted the pytorch-2.2.0-upgrade branch March 19, 2024 15:29
@hmellor restored the pytorch-2.2.0-upgrade branch March 19, 2024 15:29
@hmellor reopened this Mar 19, 2024
@hmellor changed the title from "Upgrade to torch==2.2.0" to "Upgrade to torch==2.2.1" Mar 19, 2024
@tanguofu commented Apr 1, 2024

Is there a merge plan?

@hmellor (Collaborator, Author) commented Apr 4, 2024

I'm not sure why the Docker container is failing to build.

Unhelpfully, the failing command `VLLM_USE_PRECOMPILED=1 pip install . --verbose` works on my machine 🙃.

@youkaichao (Member) commented:

@hmellor there are subtle issues in upgrading to PyTorch 2.2, which requires a self-managed NCCL version. I have done it in #3805, and it should be merged soon. Thank you for your contribution!

@stas00 (Contributor) commented Apr 4, 2024

> @hmellor there are subtle issues in upgrading to PyTorch 2.2, which requires a self-managed NCCL version. I have done it in #3805, and it should be merged soon. Thank you for your contribution!

Which is a problem in its own way: GCP A3 (H100) instances require NCCL 2.19.3+ for their custom TCPX networking to function, and it can't work with the lower NCCL version you proposed. So if you force that NCCL version you will cut out GCP A3 users. To be exact, it would only impact use cases with more than one node, since a single node doesn't require TCPX.

@hmellor (Collaborator, Author) commented Apr 4, 2024

@youkaichao based on your issue NVIDIA/nccl#1234, am I right in saying that the issue is that torch 2.2 defaults to a newer version of NCCL which uses more memory?

@youkaichao (Member) commented:

> @youkaichao based on your issue NVIDIA/nccl#1234, am I right in saying that the issue is that torch 2.2 defaults to a newer version of NCCL which uses more memory?

Yes. So we need to pin an NCCL version ourselves.
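
For context, a quick way to check which NCCL build a given torch wheel ships with (illustrative commands, not something this PR adds):

# Print the NCCL version the installed torch was built against
python3 -c "import torch; print(torch.cuda.nccl.version())"
# List any NCCL runtime packages pip pulled in alongside torch
pip list | grep -i nccl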

@youkaichao (Member) commented:

Closed via #3805.

@youkaichao closed this Apr 4, 2024
@hmellor deleted the pytorch-2.2.0-upgrade branch April 4, 2024 17:57
@hmellor mentioned this pull request May 31, 2024
Successfully merging this pull request may close these issues: "Will we update to pytorch v2.2.0?" and "Request support for torch2.2".