
[Misc] Add a wrapper for torch.inference_mode #6618

Merged: 6 commits into main from inference-mode on Jul 22, 2024

Conversation

WoosukKwon (Collaborator)

torch.inference_mode is not supported by some hardware backends such as TPU. To address this, this PR introduces a wrapper class that falls back to torch.no_grad for the unsupported backends.
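
A minimal sketch of the idea, written here as a small factory function rather than the wrapper class the PR actually adds; the backend check is a hypothetical placeholder, not vLLM's real detection logic:

import torch

def _backend_supports_inference_mode() -> bool:
    # Hypothetical placeholder; the real check would inspect the active vLLM backend.
    return True

def inference_mode(mode: bool = True):
    # Return torch.inference_mode when the backend supports it, otherwise
    # fall back to torch.no_grad. Both work as context managers and as
    # decorators, so callers get a uniform API.
    if _backend_supports_inference_mode():
        return torch.inference_mode(mode=mode)
    # torch.no_grad takes no arguments, so `mode` is ignored on the fallback path.
    return torch.no_grad()

It can then be used either as a context manager (with inference_mode(): ...) or as a decorator on a model-execution method.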

WoosukKwon added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jul 21, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

youkaichao (Member)

Is it possible to put it into https://github.com/vllm-project/vllm/tree/main/vllm/platforms?

e.g.

class CudaPlatform(Platform):
    inference_mode = torch.inference_mode

class TpuPlatform(Platform):
    inference_mode = torch.no_grad

WoosukKwon (Collaborator, Author)

@youkaichao Sorry I don't get your point. What will the API look like in your proposal?

Orthogonally, torch.inference_mode and torch.no_grad are slightly different. For example, inference_mode takes a mode: bool = True argument, while no_grad takes no arguments. A wrapper class is needed to cover this difference.
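
For reference, a small illustration of the signature difference (both calls below are standard PyTorch APIs):

import torch

# torch.inference_mode accepts a mode flag:
with torch.inference_mode(mode=True):
    y = torch.ones(2) * 2

# torch.no_grad takes no arguments at all:
with torch.no_grad():
    y = torch.ones(2) * 2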

youkaichao (Member)

The usage will be a unified decorator:

from vllm.platforms import current_platform

class ModelRunner:

    @current_platform.inference_mode
    def execute_model(self, *args, **kwargs):  # signature elided in the original snippet
        ...

It dispatches to torch.inference_mode on GPU or torch.no_grad on TPU.

We don't need to mimic the full functionality of torch.inference_mode, i.e. we don't need to have a mode argument.

This is how we unify the heterogeneity of hardware platforms.

WoosukKwon (Collaborator, Author)

@youkaichao Got it. Thanks for the explanation.

We don't need to mimic the full functionality of torch.inference_mode, i.e. we don't need to have a mode argument.

You're right. I simplified the code based on your suggestion. PTAL.

@current_platform.inference_mode

I also agree that this could be a better idea. However, because inference_mode is used in worker_base and model_runner_base, your approach will require registering all hardware platforms (CPU, openvino, neuron, tpu, gaudi, xpu, etc.) to current_platform. While we should do this eventually, I'd like to defer it until the hardware backends are more settled. For now, I think it makes sense to put inference_mode under utils, so that this does not block other PRs.

youkaichao (Member)

your approach will require registering all hardware platforms (CPU, openvino, neuron, tpu, gaudi, xpu, etc.)

Not necessarily; we are doing it step by step. For example, I only implemented get_device_capability in the cuda and rocm platforms, without touching the rest of the code.

For the problem you mentioned, we can have another platform called UnspecifiedPlatform to hold inference_mode with a default value of torch.inference_mode. Then CPU, openvino, neuron, gaudi, and xpu can all benefit from it, and only TPU needs a specialized code path to dispatch to torch.no_grad.

current_platform = None
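
A rough sketch of that layout, reusing the Platform and TpuPlatform names from the earlier snippet plus the proposed UnspecifiedPlatform; the class-attribute style shown here is an assumption:

import torch

class Platform:
    pass

class UnspecifiedPlatform(Platform):
    # Default for backends that support torch.inference_mode
    # (CPU, openvino, neuron, gaudi, xpu, ...).
    inference_mode = torch.inference_mode

class TpuPlatform(Platform):
    # TPU-specific code path: fall back to torch.no_grad.
    inference_mode = torch.no_grad

# Instead of current_platform = None, unrecognized backends could default to:
current_platform = UnspecifiedPlatform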

I can do it in a follow-up PR if you don't have the bandwidth.

I think starting the code in a unified way is better than having to gather the code together later.

WoosukKwon (Collaborator, Author)

@youkaichao I see. Updated the PR. PTAL.

youkaichao (Member) left a comment


Thanks for addressing my comments! Let's see if the tests pass smoothly 👍

WoosukKwon (Collaborator, Author)

@youkaichao Thanks for your feedback! Let me merge the PR as it passes the tests.

WoosukKwon merged commit 42de2ce into main on Jul 22, 2024
72 checks passed
WoosukKwon deleted the inference-mode branch on July 22, 2024 at 01:43
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
gnpinkert pushed a commit to gnpinkert/vllm that referenced this pull request Jul 26, 2024
cduk pushed a commit to cduk/vllm-pascal that referenced this pull request Aug 6, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
Projects: None yet
3 participants