[6/N] torch.compile rollout to users (#10437)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Nov 19, 2024
1 parent fd9f124 commit 803f37e
Showing 15 changed files with 129 additions and 141 deletions.
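
The user-visible change in this step: torch.compile settings no longer come from the `VLLM_TORCH_COMPILE_LEVEL` / `VLLM_TORCH_COMPILE_CONFIG` environment variables but from an explicit `CompilationConfig`, passed through `VllmConfig`, the `compilation_config` argument of `LLM`, or the `-O` CLI flag. A minimal sketch of the new entry points, pieced together from the tests in this diff (model name and level are illustrative):

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel

# Offline usage: hand the compilation settings to the LLM constructor
# instead of exporting VLLM_TORCH_COMPILE_LEVEL before startup.
llm = LLM(
    model="google/gemma-2b",  # illustrative; taken from the TPU test below
    compilation_config=CompilationConfig(level=CompilationLevel.DYNAMO_ONCE),
)
outputs = llm.generate(["A robot may not injure a human being"],
                       SamplingParams(temperature=0.7, max_tokens=16))

# Server/CLI usage: the correctness tests below pass the same setting as
# "-O <level>" (or a JSON blob handled by CompilationConfig.from_cli).
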
5 changes: 0 additions & 5 deletions tests/compile/piecewise/piecewise_compilation_config.json

This file was deleted.

18 changes: 7 additions & 11 deletions tests/compile/piecewise/test_simple.py
@@ -2,7 +2,6 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import os

import torch
from torch import nn
@@ -11,7 +10,7 @@
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op

@@ -77,12 +76,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

def test_simple_piecewise_compile():

directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')

@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))

# clean up to avoid side effects for other tests
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]
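
With the JSON side-channel gone, the piecewise test builds its configuration directly in code and scopes it with `set_current_vllm_config`, so there is no longer any environment-variable cleanup to forget. A condensed sketch of that pattern (the op name mirrors the test's custom `silly.attention` op):

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    non_cudagraph_ops=["silly.attention"],  # ops kept out of the CUDA graphs
    cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):
    pass  # build the model here; it picks up the compilation settings
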
45 changes: 17 additions & 28 deletions tests/compile/piecewise/test_toy_llama.py
@@ -6,7 +6,6 @@
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -18,7 +17,7 @@
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_compilation_config, set_current_vllm_config
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
@@ -254,23 +253,17 @@ def run_model(llama_config,
split_attn: bool = False) -> torch.Tensor:

if use_compile:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.PIECEWISE)

compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
)
if split_attn:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
else:
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
compilation_config.non_cudagraph_ops = ["silly.attention"]
else:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )

vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
@@ -288,10 +281,6 @@ def run_model(llama_config,
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])

# manual cleanup
del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
set_compilation_config(None)

output = output.cpu()

if llama_config.tractable_init:
@@ -361,7 +350,6 @@ def test_toy_llama():

@torch.inference_mode
def benchmark():
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
from triton.testing import do_bench

# similar to llama 3.1-8B
@@ -387,15 +375,16 @@ def benchmark():

for piecewise in [False, True]:
if piecewise:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
)
else:
set_compilation_config(None)
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, )

vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
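
One subtlety in the updated benchmark: both branches compile at `CompilationLevel.PIECEWISE`; the `piecewise` flag only decides whether vLLM's CUDA-graph capture is enabled and split around the custom attention op via `non_cudagraph_ops`. A condensed sketch of that toggle (the helper name is illustrative):

from vllm.config import CompilationConfig, CompilationLevel

def make_benchmark_config(piecewise: bool) -> CompilationConfig:
    # Same compilation level either way; the piecewise variant additionally
    # enables CUDA graphs and splits them around "silly.attention".
    if piecewise:
        return CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            use_cudagraph=True,
            non_cudagraph_ops=["silly.attention"],
        )
    return CompilationConfig(level=CompilationLevel.PIECEWISE)
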
13 changes: 9 additions & 4 deletions tests/compile/test_basic_correctness.py
@@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
["-tp", str(tp_size)]

all_args: List[List[str]] = []
all_envs: List[Optional[Dict[str, str]]] = []

for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
]:
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
all_args.append(final_args + ["-O", str(level)])
all_envs.append({})

# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model, [final_args] * 2,
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close")
all_envs.clear()
all_args.clear()

for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
]:
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
all_args.append(final_args + ["-O", str(level)])
all_envs.append({})
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
# "DYNAMO_ONCE" will always use fullgraph
all_envs[-1][
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore

compare_all_settings(model, [final_args] * 3, all_envs, method=method)
compare_all_settings(model, all_args * 3, all_envs, method=method)
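
With the environment knob removed, the correctness matrix now carries the compilation level in the server arguments themselves via `-O`, and the env dicts stay empty except for the fullgraph-capture override. A small sketch of the before/after shape of the inputs to `compare_all_settings` (values are illustrative):

from vllm.config import CompilationLevel

levels = [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE]

# Before: one env dict per setting, identical args for every run.
old_args = [["--enforce-eager"]] * len(levels)
old_envs = [{"VLLM_TORCH_COMPILE_LEVEL": str(level)} for level in levels]

# After: empty env dicts; the level travels with the CLI args as "-O <level>".
new_args = [["--enforce-eager", "-O", str(level)] for level in levels]
new_envs = [{} for _ in levels]
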
4 changes: 2 additions & 2 deletions tests/compile/utils.py
@@ -4,7 +4,7 @@

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationLevel
from vllm.config import CompilationConfig, CompilationLevel
from vllm.platforms import current_platform

TEST_MODELS = [
@@ -65,7 +65,6 @@ def check_full_graph_support(model,
optimization_level,
tp_size=1):
# make sure these models can be captured in full graph mode
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

# The base meta llama uses too much memory.
@@ -86,6 +85,7 @@ def check_full_graph_support(model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True,
compilation_config=CompilationConfig(level=optimization_level),
**model_kwargs)

outputs = llm.generate(prompts, sampling_params)
4 changes: 1 addition & 3 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -1,4 +1,3 @@
import os
from typing import List

import pytest
@@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation):
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
level=torch_level, custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on

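The custom-ops test follows the same migration: the torch.compile level joins `custom_ops` on the config object instead of living in `os.environ`. A minimal sketch (the `+`/`-` toggles mirror the `env` strings the test is parametrized with; the op name here is illustrative):

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    custom_ops=["none", "+rms_norm"],  # disable custom ops, then re-enable one
))
with set_current_vllm_config(vllm_config):
    pass  # CustomOp.default_on() and per-op enablement now read this config
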
47 changes: 35 additions & 12 deletions tests/tpu/test_compilation.py
@@ -1,24 +1,47 @@
import glob
import os
import runpy
import tempfile

import depyf

from vllm.config import CompilationLevel

# disable custom dispatcher, let Dynamo take over
# all the control
os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)
from vllm.config import CompilationConfig, CompilationLevel

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
cur_dir = os.path.dirname(__file__)
parent_dir = os.path.dirname(cur_dir)
root_dir = os.path.dirname(parent_dir)
example_file = os.path.join(root_dir, "examples",
"offline_inference_tpu.py")
runpy.run_path(example_file)
from vllm import LLM, SamplingParams

prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction, allow a human being to come to harm.",
" what is essential is invisible to the eye.",
" but in rising every time we fall.",
]
N = 1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params = SamplingParams(temperature=0.7,
top_p=1.0,
n=N,
max_tokens=16)

# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.

# disable custom dispatcher, let Dynamo take over
# all the control
llm = LLM(model="google/gemma-2b",
enforce_eager=True,
compilation_config=CompilationConfig(
level=CompilationLevel.DYNAMO_AS_IS))
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
assert generated_text.startswith(answer)

compiled_code = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
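
The depyf-based check itself is unchanged in spirit: whatever runs inside `depyf.prepare_debug(temp_dir)` has its Dynamo-transformed source dumped into that directory, and the test then asserts on the dumped `__transformed_code*.py` files. A stripped-down version of the pattern (`run_workload` is a placeholder for the LLM generation above):

import glob
import os
import tempfile

import depyf

def run_workload():
    ...  # placeholder for the torch.compile'd work, e.g. llm.generate(...)

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
    run_workload()

# The transformed source ends up as __transformed_code*.py files under temp_dir.
compiled_code = sorted(
    glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
print(compiled_code)
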
10 changes: 6 additions & 4 deletions tests/tpu/test_custom_dispatcher.py
@@ -13,7 +13,9 @@
def test_custom_dispatcher():
compare_two_settings(
"google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})
arg1=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_ONCE)],
arg2=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_AS_IS)],
env1={},
env2={})
43 changes: 20 additions & 23 deletions vllm/config.py
@@ -2174,8 +2174,14 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr

@classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
return CompilationConfig.model_validate_json(cli_value)

def model_post_init(self, __context: Any) -> None:
self.level = envs.VLLM_TORCH_COMPILE_LEVEL

count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
@@ -2249,26 +2255,6 @@ def init_during_runtime(self):
"inductor_specialize_for_cudagraph_no_more_than is None")
self.compile_sizes = self.inductor_compile_sizes

@staticmethod
def select_and_init_config() -> "CompilationConfig":
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
if config_path is not None:
with open(config_path) as json_file:
config = CompilationConfig.model_validate_json(
json_file.read())
else:
from vllm.plugins import get_compilation_config
predefined_config = get_compilation_config()
config = predefined_config if predefined_config is not None else (
CompilationConfig())

return config


@dataclass
class VllmConfig:
@@ -2354,8 +2340,19 @@ def __post_init__(self):
self.model_config, self.load_config)

if self.compilation_config is None:
self.compilation_config = CompilationConfig.select_and_init_config(
)
self.compilation_config = CompilationConfig()
if envs.VLLM_USE_V1:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
self.compilation_config.non_cudagraph_ops = [
"vllm.unified_v1_flash_attention"
]
self.compilation_config.use_inductor = True
self.compilation_config.enable_fusion = False

current_platform.check_and_update_config(self)
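
The new `CompilationConfig.from_cli` is what backs the `-O` flag used in the tests above: a bare digit selects a level, anything else is parsed as a JSON-encoded `CompilationConfig`. A quick sketch of both forms (field values are illustrative):

from vllm.config import CompilationConfig

# Bare level: "-O 3" becomes CompilationConfig(level=3).
cfg = CompilationConfig.from_cli("3")
assert cfg.level == 3

# Full JSON: any CompilationConfig field can be set from the CLI string.
cfg = CompilationConfig.from_cli(
    '{"level": 3, "use_cudagraph": true, '
    '"non_cudagraph_ops": ["silly.attention"]}')
assert cfg.use_cudagraph and cfg.non_cudagraph_ops == ["silly.attention"]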

(The remaining changed files are not shown in this view.)