From 5036bb15fe608a97aac68fae0af80fab2d03bdb1 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 14 Sep 2024 13:07:16 -0700
Subject: [PATCH] [torch.compile] add a flag to disable custom op (#8488)

---
 tests/compile/test_full_graph.py | 3 ++-
 vllm/envs.py                     | 5 +++++
 vllm/model_executor/custom_op.py | 5 +++++
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 43905082b7caf..5452ce6be8110 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -6,7 +6,8 @@
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
 def test_full_graph(model):
     # make sure these models can be captured in full graph mode
-    os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
+        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
     from vllm import LLM, SamplingParams
 
     prompts = [
diff --git a/vllm/envs.py b/vllm/envs.py
index 0dd24c76d0f0b..fbb4fc7489eab 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -203,6 +203,11 @@ def get_default_config_root():
     (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
      ("true", "1")),
 
+    # Internal flag to control whether we use custom op,
+    # or use the native pytorch implementation
+    "VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
+    lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
+
     # Internal flag to enable Dynamo fullgraph capture
     "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
     lambda: bool(
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index 2356570af141a..9ebecb4c44dbf 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -1,5 +1,6 @@
 import torch.nn as nn
 
+import vllm.envs as envs
 from vllm.platforms import current_platform
 from vllm.utils import is_cpu, is_hip, is_xpu
 
@@ -60,6 +61,10 @@ def forward_npu(self, *args, **kwargs):
     def dispatch_forward(self):
         # NOTE(woosuk): Here we assume that vLLM was built for only one
         # specific backend. Currently, we do not support dynamic dispatching.
+
+        if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
+            return self.forward_native
+
         if is_hip():
             return self.forward_hip
         elif is_cpu():
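
Note (not part of the patch): a minimal sketch of how the new flag might be
exercised end to end. The model name and sampling settings below are
placeholders; the key point is that envs.py reads the variable through a
lambda on attribute access, so it only needs to be set before the first
forward pass triggers dispatch_forward().

    # usage_sketch.py -- hypothetical example, not part of this patch
    import os

    # With this set, dispatch_forward() short-circuits to forward_native,
    # so every custom op falls back to its plain PyTorch implementation.
    os.environ["VLLM_TEST_COMPILE_NO_CUSTOM_OPS"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)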